from data import BipartiteDatasetRL
from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
from torchrl.modules import MaskedCategorical
from replay import Transition
import numpy as np
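# Policy-gradient training updates for the bipartite branching agent:
# GAE() runs an actor-critic update with GAE advantages, REINFORCE() runs
# REINFORCE with the critic as a learned baseline.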
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
# allocate on the same device as the inputs instead of hard-coding one
advantages = torch.zeros_like(rewards, dtype=torch.float32)
returns_to_go = torch.zeros_like(rewards, dtype=torch.float32)
last_gae = 0  # GAE accumulator; reset across episode boundaries via (1 - dones[t])
last_return = 0  # running discounted return-to-go
for t in reversed(range(len(rewards))):
delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t] # td-error
advantages[t] = last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae # gae advantage
returns_to_go[t] = last_return = rewards[t] + gamma * (1 - dones[t]) * last_return  # discounted return-to-go, reset at episode ends
return advantages, returns_to_go
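# Expected shapes: `rewards` and `dones` have length T, `values` has length T + 1
# (the caller appends a bootstrap value for the state after the last transition), e.g.
#   adv, rtg = compute_gae(rewards=torch.zeros(8), values=torch.zeros(9), dones=torch.zeros(8))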
def compute_discounted_returns(rewards, gamma=0.99):
rewards = rewards.float()
discounts = gamma ** torch.arange(len(rewards)).float()
discounted_rewards = rewards * discounts
discounted_returns = torch.flip(torch.cumsum(torch.flip(discounted_rewards, [0]), dim=0), [0])
return discounted_returns / discounts
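# Example: rewards [1, 1, 1] with gamma=0.99 give returns-to-go [2.9701, 1.99, 1.0].
# Dividing by `discounts` turns the reversed cumulative sum back into per-step returns;
# for very long rollouts this division by gamma**t can lose numerical precision.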
def decayed_learning_rate(step):
initial_learning_rate = 0.05
decay_rate = 0.1
decay_steps = 100
return initial_learning_rate * (decay_rate ** (step / decay_steps))
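# Despite the name, this is used as a decaying entropy-bonus coefficient in GAE()'s
# actor loss: step 0 -> 0.05, step 100 -> 0.005, step 200 -> 0.0005.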
def GAE(actor: torch.nn.Module, critic_main: torch.nn.Module, critic_target: torch.nn.Module, actor_opt: torch.optim.Optimizer, critic_opt: torch.optim.Optimizer, replay_buffer, step):
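# One update from the current replay buffer: estimate values, GAE advantages and
# returns over the whole rollout in a single pass, then run shuffled mini-batch
# updates of the actor (policy gradient + entropy bonus) and the critic (MSE to returns).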
actor.train()
critic_main.train()
transitions = replay_buffer.memory
batch_size1 = len(transitions)  # one full-buffer batch for value/advantage estimation
batch_size2 = 16  # mini-batch size for the gradient updates
batch = Transition(*zip(*transitions))
dataset = BipartiteDatasetRL(*batch)
dataloader1 = DataLoader(dataset, batch_size=batch_size1, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])
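# dataloader1 packs the entire buffer into a single batch so values, advantages
# and returns are computed over the whole rollout at once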
for SASR in dataloader1:
graph, action, reward, done, _ = SASR
# reward = (reward - reward.mean()) / (1e-8 + reward.std())
with torch.no_grad():
# get state values of batch
values = critic_main(
graph.x_c,
graph.edge_index,
graph.edge_features,
graph.x_v,
graph.x_c_batch,
graph.x_v_batch,
graph.candidates
)
# append a zero bootstrap value for the state after the final transition
values = torch.cat((values, values.new_zeros(1)), dim=0)
advantages, returns = compute_gae(reward, values, done)
raw_advantage_mean = advantages.mean().item()  # logged at the end; the normalised mean below is always ~0
advantages = (advantages - advantages.mean()) / (1e-8 + advantages.std())
indices = np.random.permutation(len(dataset))
shuffled_dataset = torch.utils.data.Subset(dataset, indices)
dataloader2 = DataLoader(shuffled_dataset, batch_size=batch_size2, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])
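# the permutation lives on the Subset, so `shuffled_dataset.indices` can be used below
# to pull the advantages/returns that match each shuffled mini-batch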
batch_idx = 0
# tau = 0.005
total_actor_loss, total_critic_loss = 0, 0
mean_entropy = []
mean_kl_div = []
for batch_i, SASR in enumerate(dataloader2):
graph, action, reward, done, batch_old_log_probs = SASR
batch_size_i = len(action)
batch_start = batch_idx
batch_end = batch_start + batch_size_i
actor_opt.zero_grad()
critic_opt.zero_grad()
# formatting for indexing in batched data
# number of variable nodes per graph, taken from the batch vector rather than
# assuming every graph in the batch has the same size
num_nodes_per_graph = torch.bincount(graph.x_v_batch, minlength=graph.num_graphs).tolist()
node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()
# Split candidates per graph and shift them to graph-local node indices
candidate_counts = graph.nb_candidates.tolist()
candidate_chunks = torch.split(graph.candidates, candidate_counts)
candidates_per_graph = [cands - offset for cands, offset in zip(candidate_chunks, node_offsets)]
# Estimate policy gradient
logits = actor(
graph.x_c,
graph.edge_index,
graph.edge_features,
graph.x_v
)
logits_per_graph = torch.split(logits, num_nodes_per_graph)
log_probs = []
entropy_arr = []
for i in range(len(logits_per_graph)):
logits_i = logits_per_graph[i]
candidates_i = candidates_per_graph[i]
action_i = action[i]
# Create mask for valid actions
mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
mask_i[candidates_i] = True
# Create action distribution
action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
log_prob_i = action_dist.log_prob(action_i)
log_probs.append(log_prob_i)
entropy_arr.append(action_dist.entropy())
log_probs = torch.stack(log_probs)
# Find entropy of policy and store
ent = torch.stack(entropy_arr).mean()
mean_entropy.append(ent.detach().item() * len(entropy_arr))
# Approximate the KL divergence between the old policy and the updated one
approx_kl = (torch.exp(batch_old_log_probs)*(batch_old_log_probs - log_probs.detach())).sum().item()
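# (each sampled action's log-ratio (old_log_prob - new_log_prob) is weighted by
# exp(old_log_prob); the unweighted mean is the more common sample-based estimate)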
mean_kl_div.append(approx_kl)
batch_advantages = advantages[shuffled_dataset.indices[batch_start:batch_end]]
eps = 0.3  # clip range for the (currently disabled) PPO-style clipped objective below
# ratios = torch.exp(log_probs - batch_old_log_probs)
# surr1 = ratios * batch_advantages
# surr2 = torch.clamp(ratios, 1 - eps, 1 + eps)*batch_advantages
# actor_loss = -torch.min(surr1, surr2).mean() - decayed_learning_rate(step)*ent
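# vanilla policy-gradient loss on the normalised GAE advantages plus an entropy bonus
# whose coefficient decays with `step` (decayed_learning_rate is reused as that coefficient)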
actor_loss = -torch.mean(log_probs*batch_advantages) - decayed_learning_rate(step)*ent
# actor_loss = -torch.mean(log_probs * advantages[batch_start:batch_end])
# actor_loss = -torch.sum(log_probs * advantages[shuffled_dataset.indices[batch_start:batch_end]])
# get state values of batch
values = critic_main(
graph.x_c,
graph.edge_index,
graph.edge_features,
graph.x_v,
graph.x_c_batch,
graph.x_v_batch,
graph.candidates
)
# critic_loss = F.mse_loss(values, returns[batch_start:batch_end])
critic_loss = F.mse_loss(values, returns[shuffled_dataset.indices[batch_start:batch_end]])
actor_loss.backward()
critic_loss.backward()
actor_opt.step()
critic_opt.step()
total_critic_loss += critic_loss.item()
total_actor_loss += actor_loss.item()
# for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
# target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)
batch_idx += batch_size_i
# losses are averaged per mini-batch, entropy/KL per transition; the advantage mean is pre-normalisation
return total_actor_loss/len(dataloader2), total_critic_loss/len(dataloader2), raw_advantage_mean, sum(mean_entropy)/batch_size1, sum(mean_kl_div)/batch_size1
def REINFORCE(actor: torch.nn.Module, critic_main: torch.nn.Module, critic_target: torch.nn.Module, actor_opt: torch.optim.Optimizer, critic_opt: torch.optim.Optimizer, replay_buffer, step):
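# REINFORCE with a learned baseline: Monte Carlo discounted returns over the buffer,
# the critic's value estimates as the baseline, and shuffled mini-batch updates of both networks.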
actor.train()
critic_main.train()
transitions = replay_buffer.memory
batch_size = 16
batch = Transition(*zip(*transitions))
dataset = BipartiteDatasetRL(*batch)
rewards = torch.tensor(batch.reward)
returns = compute_discounted_returns(rewards)
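# note: compute_discounted_returns does not reset at done flags, so these returns
# assume the buffer holds one contiguous episode/rollout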
indices = np.random.permutation(len(dataset))
shuffled_dataset = torch.utils.data.Subset(dataset, indices)
dataloader = DataLoader(shuffled_dataset, batch_size=batch_size, shuffle=False, follow_batch=['x_c', 'x_v', 'candidates'])
batch_idx = 0
# tau = 0.005
total_actor_loss, total_critic_loss = 0, 0
mean_entropy = []
mean_kl_div = []
mean_advantages = []
for batch_i, SASR in enumerate(dataloader):
graph, action, reward, done, batch_old_log_probs = SASR
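# single-transition mini-batches are skipped: the per-batch advantage normalisation
# below needs at least two samples (the std of one element is NaN)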
if len(action) <= 1:
continue
batch_size_i = len(action)
batch_start = batch_idx
batch_end = batch_start + batch_size_i
actor_opt.zero_grad()
critic_opt.zero_grad()
# formatting for indexing in batched data
# number of variable nodes per graph, taken from the batch vector rather than
# assuming every graph in the batch has the same size
num_nodes_per_graph = torch.bincount(graph.x_v_batch, minlength=graph.num_graphs).tolist()
node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()
# Split candidates per graph and shift them to graph-local node indices
candidate_counts = graph.nb_candidates.tolist()
candidate_chunks = torch.split(graph.candidates, candidate_counts)
candidates_per_graph = [cands - offset for cands, offset in zip(candidate_chunks, node_offsets)]
# Estimate policy gradient
logits = actor(
graph.x_c,
graph.edge_index,
graph.edge_features,
graph.x_v
)
logits_per_graph = torch.split(logits, num_nodes_per_graph)
log_probs = []
entropy_arr = []
for i in range(len(logits_per_graph)):
logits_i = logits_per_graph[i]
candidates_i = candidates_per_graph[i]
action_i = action[i]
# Create mask for valid actions
mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
mask_i[candidates_i] = True
# Create action distribution
action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
log_prob_i = action_dist.log_prob(action_i)
log_probs.append(log_prob_i)
entropy_arr.append(action_dist.entropy())
log_probs = torch.stack(log_probs)
# Find entropy of policy and store
ent = torch.stack(entropy_arr).mean()
mean_entropy.append(ent.detach().item() * len(entropy_arr))
# Approximate the KL divergence between the old policy and the updated one
approx_kl = (torch.exp(batch_old_log_probs)*(batch_old_log_probs - log_probs.detach())).sum().item()
mean_kl_div.append(approx_kl)
batch_returns = returns[shuffled_dataset.indices[batch_start:batch_end]]
eps = 0.3  # unused here; clip range for the commented-out PPO-style objective below
# ratios = torch.exp(log_probs - batch_old_log_probs)
# surr1 = ratios * batch_advantages
# surr2 = torch.clamp(ratios, 1 - eps, 1 + eps)*batch_advantages
# actor_loss = -torch.min(surr1, surr2).mean() - decayed_learning_rate(step)*ent
# actor_loss = -torch.mean(log_probs * advantages[batch_start:batch_end])
# actor_loss = -torch.sum(log_probs * advantages[shuffled_dataset.indices[batch_start:batch_end]])
# get state values of batch
values = critic_main(
graph.x_c,
graph.edge_index,
graph.edge_features,
graph.x_v,
graph.x_c_batch,
graph.x_v_batch,
graph.candidates
)
# baseline-corrected advantages: Monte Carlo return minus the critic's value estimate
batch_advantages = batch_returns - values.detach()
mean_advantages.append(batch_advantages.sum().item())  # log before normalisation, otherwise the sum is always ~0
batch_advantages = (batch_advantages - batch_advantages.mean()) / (1e-8 + batch_advantages.std())
actor_loss = -torch.mean(log_probs*batch_advantages)
# critic_loss = F.mse_loss(values, returns[batch_start:batch_end])
critic_loss = F.mse_loss(values, batch_returns)
actor_loss.backward()
critic_loss.backward()
actor_opt.step()
critic_opt.step()
total_critic_loss += critic_loss.item()
total_actor_loss += actor_loss.item()
# for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
# target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)
batch_idx += batch_size_i
# average entropy/KL per transition (as in the GAE update) rather than dividing the summed values by the mini-batch size
return total_actor_loss/len(dataloader), total_critic_loss/len(dataloader), np.sum(mean_advantages)/len(transitions), sum(mean_entropy)/len(transitions), sum(mean_kl_div)/len(transitions)