from data import BipartiteDatasetRL
from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
from torchrl.modules import MaskedCategorical
from replay import Transition
import numpy as np


def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation. `values` must contain one extra entry
    (the bootstrap value of the state after the final transition)."""
    rewards = rewards.float()
    advantages = torch.zeros_like(rewards)
    returns_to_go = torch.zeros_like(rewards)
    last_gae = 0      # GAE for the final timestep in an episode
    last_return = 0   # return accumulated so far (working backwards)
    for t in reversed(range(len(rewards))):
        # TD error
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
        # GAE advantage
        advantages[t] = last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae
        returns_to_go[t] = last_return = rewards[t] + gamma * last_return
    return advantages, returns_to_go


def compute_discounted_returns(rewards, gamma=0.99):
    """Discounted return from each timestep to the end of the buffer."""
    rewards = rewards.float()
    discounts = gamma ** torch.arange(len(rewards)).float()
    discounted_rewards = rewards * discounts
    discounted_returns = torch.flip(torch.cumsum(torch.flip(discounted_rewards, [0]), dim=0), [0])
    return discounted_returns / discounts


def decayed_learning_rate(step):
    """Exponentially decaying coefficient, used below as the entropy-bonus weight."""
    initial_learning_rate = 0.05
    decay_rate = 0.1
    decay_steps = 100
    return initial_learning_rate * (decay_rate ** (step / decay_steps))


def GAE(actor: torch.nn.Module, critic_main: torch.nn.Module, critic_target: torch.nn.Module,
        actor_opt: torch.optim.Optimizer, critic_opt: torch.optim.Optimizer, replay_buffer, step):
    """One actor-critic update over the replay buffer using GAE advantages.
    The PPO clipped objective and the target-critic soft update are kept as
    commented-out alternatives."""
    actor.train()
    critic_main.train()
    transitions = replay_buffer.memory
    batch_size1 = len(transitions)   # full-buffer batch for advantage estimation
    batch_size2 = 16                 # mini-batch size for the gradient updates
    batch = Transition(*zip(*transitions))
    dataset = BipartiteDatasetRL(*batch)

    # Single full-batch pass to compute state values and advantages for all transitions.
    dataloader1 = DataLoader(dataset, batch_size=batch_size1, shuffle=False,
                             follow_batch=['x_c', 'x_v', 'candidates'])
    for SASR in dataloader1:
        graph, action, reward, done, _ = SASR
        # reward = (reward - reward.mean()) / (1e-8 + reward.std())
        with torch.no_grad():
            # get state values of the batch
            values = critic_main(
                graph.x_c,
                graph.edge_index,
                graph.edge_features,
                graph.x_v,
                graph.x_c_batch,
                graph.x_v_batch,
                graph.candidates
            )
        # append a bootstrap value of 0 for the state after the final transition
        values = torch.cat((values, torch.zeros(1, device=values.device)), dim=0)
        advantages, returns = compute_gae(reward, values, done)
        advantages = (advantages - advantages.mean()) / (1e-8 + advantages.std())

    # Shuffle the transitions and iterate over them in mini-batches.
    indices = np.random.permutation(len(dataset))
    shuffled_dataset = torch.utils.data.Subset(dataset, indices)
    dataloader2 = DataLoader(shuffled_dataset, batch_size=batch_size2, shuffle=False,
                             follow_batch=['x_c', 'x_v', 'candidates'])
    batch_idx = 0
    # tau = 0.005
    total_actor_loss, total_critic_loss = 0, 0
    mean_entropy = []
    mean_kl_div = []
    for batch_i, SASR in enumerate(dataloader2):
        graph, action, reward, done, batch_old_log_probs = SASR
        batch_size_i = len(action)
        batch_start = batch_idx
        batch_end = batch_start + batch_size_i

        actor_opt.zero_grad()
        critic_opt.zero_grad()

        # Formatting for indexing in batched data: assumes every instance has
        # the same number of variable nodes.
        num_nodes_per_graph = graph.num_graphs * [int(graph.x_v.shape[0] / graph.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()

        # Split candidates per graph and shift them to per-graph indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

        # Estimate policy gradient
        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )
        logits_per_graph = torch.split(logits, num_nodes_per_graph)
        log_probs = []
        entropy_arr = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]

            # Create mask for valid actions (only candidate variables may be selected)
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True

            # Create action distribution over the masked logits
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)
            entropy_arr.append(action_dist.entropy())

        log_probs = torch.stack(log_probs)

        # Entropy of the policy, weighted by batch size so the final average is per transition
        ent = torch.stack(entropy_arr).mean()
        mean_entropy.append(ent.detach().item() * len(entropy_arr))

        # Approximate KL divergence between the old policy and the updated one
        approx_kl = (torch.exp(batch_old_log_probs) * (batch_old_log_probs - log_probs.detach())).sum().item()
        mean_kl_div.append(approx_kl)

        batch_advantages = advantages[shuffled_dataset.indices[batch_start:batch_end]]
        eps = 0.3  # PPO clipping range (only used by the commented-out clipped objective)
        # ratios = torch.exp(log_probs - batch_old_log_probs)
        # surr1 = ratios * batch_advantages
        # surr2 = torch.clamp(ratios, 1 - eps, 1 + eps) * batch_advantages
        # actor_loss = -torch.min(surr1, surr2).mean() - decayed_learning_rate(step) * ent
        actor_loss = -torch.mean(log_probs * batch_advantages) - decayed_learning_rate(step) * ent
        # actor_loss = -torch.mean(log_probs * advantages[batch_start:batch_end])
        # actor_loss = -torch.sum(log_probs * advantages[shuffled_dataset.indices[batch_start:batch_end]])

        # get state values of the batch
        values = critic_main(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v,
            graph.x_c_batch,
            graph.x_v_batch,
            graph.candidates
        )
        # critic_loss = F.mse_loss(values, returns[batch_start:batch_end])
        critic_loss = F.mse_loss(values, returns[shuffled_dataset.indices[batch_start:batch_end]])

        actor_loss.backward()
        critic_loss.backward()
        actor_opt.step()
        critic_opt.step()
        total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

        # Soft update of the target critic (currently disabled)
        # for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
        #     target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)

        batch_idx += batch_size_i

    return (total_actor_loss / len(dataloader2), total_critic_loss / len(dataloader2),
            advantages.mean().item(), sum(mean_entropy) / batch_size1, sum(mean_kl_div) / batch_size1)


def REINFORCE(actor: torch.nn.Module, critic_main: torch.nn.Module, critic_target: torch.nn.Module,
              actor_opt: torch.optim.Optimizer, critic_opt: torch.optim.Optimizer, replay_buffer, step):
    """REINFORCE update with a learned state-value baseline: advantages are the
    discounted returns minus the critic's value estimates, normalized per mini-batch."""
    actor.train()
    critic_main.train()
    transitions = replay_buffer.memory
    batch_size = 16
    batch = Transition(*zip(*transitions))
    dataset = BipartiteDatasetRL(*batch)
    rewards = torch.tensor(batch.reward)
    returns = compute_discounted_returns(rewards)

    # Shuffle the transitions and iterate over them in mini-batches.
    indices = np.random.permutation(len(dataset))
    shuffled_dataset = torch.utils.data.Subset(dataset, indices)
    dataloader = DataLoader(shuffled_dataset, batch_size=batch_size, shuffle=False,
                            follow_batch=['x_c', 'x_v', 'candidates'])
    batch_idx = 0
    # tau = 0.005
    total_actor_loss, total_critic_loss = 0, 0
    mean_entropy = []
    mean_kl_div = []
    mean_advantages = []
    for batch_i, SASR in enumerate(dataloader):
        graph, action, reward, done, batch_old_log_probs = SASR
        # Skip degenerate size-1 batches (per-batch normalization needs at least 2 samples)
        if len(action) <= 1:
            continue
        batch_size_i = len(action)
        batch_start = batch_idx
        batch_end = batch_start + batch_size_i

        actor_opt.zero_grad()
        critic_opt.zero_grad()

        # Formatting for indexing in batched data: assumes every instance has
        # the same number of variable nodes.
        num_nodes_per_graph = graph.num_graphs * [int(graph.x_v.shape[0] / graph.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()

        # Split candidates per graph and shift them to per-graph indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

        # Estimate policy gradient
        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )
        logits_per_graph = torch.split(logits, num_nodes_per_graph)
        log_probs = []
        entropy_arr = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]

            # Create mask for valid actions (only candidate variables may be selected)
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True

            # Create action distribution over the masked logits
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)
            entropy_arr.append(action_dist.entropy())

        log_probs = torch.stack(log_probs)

        # Entropy of the policy, weighted by batch size so the final average is per transition
        ent = torch.stack(entropy_arr).mean()
        mean_entropy.append(ent.detach().item() * len(entropy_arr))

        # Approximate KL divergence between the old policy and the updated one
        approx_kl = (torch.exp(batch_old_log_probs) * (batch_old_log_probs - log_probs.detach())).sum().item()
        mean_kl_div.append(approx_kl)

        batch_returns = returns[shuffled_dataset.indices[batch_start:batch_end]]
        eps = 0.3  # PPO clipping range (only used by the commented-out clipped objective)
        # ratios = torch.exp(log_probs - batch_old_log_probs)
        # surr1 = ratios * batch_advantages
        # surr2 = torch.clamp(ratios, 1 - eps, 1 + eps) * batch_advantages
        # actor_loss = -torch.min(surr1, surr2).mean() - decayed_learning_rate(step) * ent
        # actor_loss = -torch.mean(log_probs * advantages[batch_start:batch_end])
        # actor_loss = -torch.sum(log_probs * advantages[shuffled_dataset.indices[batch_start:batch_end]])

        # get state values of the batch
        values = critic_main(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v,
            graph.x_c_batch,
            graph.x_v_batch,
            graph.candidates
        )

        # Advantage = discounted return minus the critic baseline, normalized per mini-batch
        batch_advantages = batch_returns - values.detach()
        batch_advantages = (batch_advantages - batch_advantages.mean()) / (1e-8 + batch_advantages.std())
        mean_advantages.append(batch_advantages.sum().item())

        actor_loss = -torch.mean(log_probs * batch_advantages)
        # critic_loss = F.mse_loss(values, returns[batch_start:batch_end])
        critic_loss = F.mse_loss(values, batch_returns)

        actor_loss.backward()
        critic_loss.backward()
        actor_opt.step()
        critic_opt.step()
        total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

        # Soft update of the target critic (currently disabled)
        # for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
        #     target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)

        batch_idx += batch_size_i

    return (total_actor_loss / len(dataloader), total_critic_loss / len(dataloader),
            np.sum(mean_advantages) / len(transitions),
            sum(mean_entropy) / batch_size, sum(mean_kl_div) / batch_size)
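

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch for the return/advantage helpers above.
# Purely illustrative: the toy reward/value/done tensors below are made up
# and are not produced by the rest of this module.
if __name__ == '__main__':
    toy_rewards = torch.tensor([1.0, 0.0, 0.5, 1.0])
    toy_dones = torch.tensor([0.0, 0.0, 0.0, 1.0])
    # one extra entry: bootstrap value for the state after the last transition
    toy_values = torch.tensor([0.8, 0.6, 0.7, 0.9, 0.0])

    adv, rtg = compute_gae(toy_rewards, toy_values, toy_dones)
    disc = compute_discounted_returns(toy_rewards)
    print('GAE advantages:     ', adv)
    print('Returns-to-go:      ', rtg)
    print('Discounted returns: ', disc)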