from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
from torchrl.modules import MaskedCategorical
import torch.distributions as D
import numpy as np

# BipartiteDatasetRL is assumed to live alongside Transition and GraphDataset in the
# training utils module; adjust the import if it is defined elsewhere.
from submissions.dual.training.utils import Transition, GraphDataset, BipartiteDatasetRL


def compute_gae(rewards, values, dones, gamma=1.0, lam=0.95):
    # Append a zero terminal value so values[t + 1] is defined for the last timestep.
    values = torch.cat([values, torch.zeros(1, device=values.device)], dim=0)
    advantages = torch.zeros_like(rewards, device='cpu')
    returns_to_go = torch.zeros_like(rewards, device='cpu')
    last_gae = 0      # GAE for the final timestep in an episode
    last_return = 0   # return accumulated from the end of the episode
    for t in reversed(range(len(rewards))):
        # TD error; the (1 - dones[t]) factor zeroes the bootstrap at episode boundaries.
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
        # GAE advantage.
        advantages[t] = last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae
        # Discounted return-to-go, also reset at episode boundaries.
        returns_to_go[t] = last_return = rewards[t] + gamma * (1 - dones[t]) * last_return
    # returns_to_go = advantages + values[:-1]
    return advantages, returns_to_go


def compute_discounted_returns(rewards, gamma=0.99):
    rewards = rewards.float()
    discounts = gamma ** torch.arange(len(rewards)).float()
    discounted_rewards = rewards * discounts
    # Reverse cumulative sum gives the discounted reward-to-go at every timestep.
    discounted_returns = torch.flip(torch.cumsum(torch.flip(discounted_rewards, [0]), dim=0), [0])
    return discounted_returns / discounts


def decayed_learning_rate(step):
    # Exponentially decayed coefficient, used below as the entropy-bonus weight in GAE().
    initial_learning_rate = 0.05
    decay_rate = 0.1
    decay_steps = 100
    return initial_learning_rate * (decay_rate ** (step / decay_steps))
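

# Illustrative sanity check, not part of the training pipeline: a tiny hand-built
# trajectory showing what compute_gae and compute_discounted_returns produce. With
# gamma = lam = 1 and a terminal done flag, the GAE advantages reduce to the
# reward-to-go minus the value baseline, which makes both helpers easy to eyeball.
# The helper name below is ours.
def _example_return_sanity_check():
    rewards = torch.tensor([1.0, 0.0, 2.0])
    values = torch.tensor([0.5, 0.5, 0.5])
    dones = torch.tensor([0.0, 0.0, 1.0])
    advantages, returns_to_go = compute_gae(rewards, values, dones, gamma=1.0, lam=1.0)
    discounted = compute_discounted_returns(rewards, gamma=1.0)
    # Reward-to-go is [3, 2, 2]; advantages are reward-to-go minus the 0.5 baseline.
    assert torch.allclose(returns_to_go, torch.tensor([3.0, 2.0, 2.0]))
    assert torch.allclose(advantages, torch.tensor([2.5, 1.5, 1.5]))
    assert torch.allclose(discounted, torch.tensor([3.0, 2.0, 2.0]))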


def GAE(actor: torch.nn.Module, critic_main: torch.nn.Module, critic_target: torch.nn.Module,
        actor_opt: torch.optim.Optimizer, critic_opt: torch.optim.Optimizer, replay_buffer, step):
    actor.train()
    critic_main.train()

    transitions = replay_buffer.memory
    batch_size1 = len(transitions)   # one full-buffer batch for the advantage estimates
    batch_size2 = 16                 # mini-batch size for the gradient updates
    batch = Transition(*zip(*transitions))
    dataset = BipartiteDatasetRL(*batch)

    dataloader1 = DataLoader(dataset, batch_size=batch_size1, shuffle=False,
                             follow_batch=['x_c', 'x_v', 'candidates'])
    # dataloader1 yields a single batch covering the whole buffer.
    graph, action, reward, done, _ = next(iter(dataloader1))
    # reward = (reward - reward.mean()) / (1e-8 + reward.std())

    with torch.no_grad():
        # State values for every transition in the buffer.
        values = critic_main(
            graph.x_c, graph.edge_index, graph.edge_features, graph.x_v,
            graph.x_c_batch, graph.x_v_batch, graph.candidates
        )

    # compute_gae appends the terminal bootstrap value itself.
    advantages, returns = compute_gae(reward, values, done)
    advantages = (advantages - advantages.mean()) / (1e-8 + advantages.std())

    indices = np.random.permutation(len(dataset))
    shuffled_dataset = torch.utils.data.Subset(dataset, indices)
    dataloader2 = DataLoader(shuffled_dataset, batch_size=batch_size2, shuffle=False,
                             follow_batch=['x_c', 'x_v', 'candidates'])

    batch_idx = 0
    # tau = 0.005
    total_actor_loss, total_critic_loss = 0, 0
    mean_entropy = []
    mean_kl_div = []

    for batch_i, SASR in enumerate(dataloader2):
        graph, action, reward, done, batch_old_log_probs = SASR
        batch_size_i = len(action)
        batch_start = batch_idx
        batch_end = batch_start + batch_size_i

        actor_opt.zero_grad()
        critic_opt.zero_grad()

        # Formatting for indexing in batched data; assumes every graph in the batch
        # has the same number of variable nodes.
        num_nodes_per_graph = graph.num_graphs * [int(graph.x_v.shape[0] / graph.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()

        # Split candidates per graph and shift them back to per-graph node indices.
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

        # Estimate the policy gradient.
        logits = actor(graph.x_c, graph.edge_index, graph.edge_features, graph.x_v)
        logits_per_graph = torch.split(logits, num_nodes_per_graph)

        log_probs = []
        entropy_arr = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]

            # Mask out non-candidate variables so only valid actions get probability mass.
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True

            # Action distribution restricted to the candidate set.
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)
            entropy_arr.append(action_dist.entropy())
        log_probs = torch.stack(log_probs)

        # Entropy of the policy over this mini-batch (stored as a sum for later averaging).
        ent = torch.stack(entropy_arr).mean()
        mean_entropy.append(ent.detach().item() * len(entropy_arr))

        # Approximate KL divergence between the behaviour policy and the updated policy.
        approx_kl = (torch.exp(batch_old_log_probs) * (batch_old_log_probs - log_probs.detach())).sum().item()
        mean_kl_div.append(approx_kl)

        batch_advantages = advantages[shuffled_dataset.indices[batch_start:batch_end]]

        eps = 0.3
        # PPO-style clipped surrogate (currently disabled):
        # ratios = torch.exp(log_probs - batch_old_log_probs)
        # surr1 = ratios * batch_advantages
        # surr2 = torch.clamp(ratios, 1 - eps, 1 + eps) * batch_advantages
        # actor_loss = -torch.min(surr1, surr2).mean() - decayed_learning_rate(step) * ent
        actor_loss = -torch.mean(log_probs * batch_advantages) - decayed_learning_rate(step) * ent
        # actor_loss = -torch.mean(log_probs * advantages[batch_start:batch_end])
        # actor_loss = -torch.sum(log_probs * advantages[shuffled_dataset.indices[batch_start:batch_end]])

        # State values for this mini-batch.
        values = critic_main(
            graph.x_c, graph.edge_index, graph.edge_features, graph.x_v,
            graph.x_c_batch, graph.x_v_batch, graph.candidates
        )
        # critic_loss = F.mse_loss(values, returns[batch_start:batch_end])
        critic_loss = F.mse_loss(values, returns[shuffled_dataset.indices[batch_start:batch_end]])

        actor_loss.backward()
        critic_loss.backward()
        actor_opt.step()
        critic_opt.step()

        total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

        # Soft update of the target critic (currently disabled):
        # for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
        #     target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)

        batch_idx += batch_size_i

    return (total_actor_loss / len(dataloader2), total_critic_loss / len(dataloader2),
            advantages.mean().item(), sum(mean_entropy) / batch_size1, sum(mean_kl_div) / batch_size1)
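

# Illustrative reference, not wired into the updates in this module: the clipped
# surrogate objective that the commented-out lines in GAE() and the ratio clipping in
# off_PAC() both follow. The helper name is ours; it only assumes log-probabilities of
# the taken actions under the current and behaviour policies plus precomputed advantages.
def _clipped_surrogate_loss(log_probs, old_log_probs, advantages, eps_clip=0.2):
    # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space for stability.
    ratios = torch.exp(log_probs - old_log_probs)
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    # The elementwise minimum keeps the update pessimistic with respect to the clipping.
    return -torch.min(surr1, surr2).mean()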


def REINFORCE(actor: torch.nn.Module, critic_main: torch.nn.Module,
              actor_opt: torch.optim.Optimizer, critic_opt: torch.optim.Optimizer, replay_buffer):
    actor.train()

    transitions = replay_buffer.memory
    batch_size = 64
    batch = Transition(*zip(*transitions))
    dataset = BipartiteDatasetRL(*batch)

    rewards = torch.tensor(batch.reward)
    returns = compute_discounted_returns(rewards)

    indices = np.random.permutation(len(dataset))
    shuffled_dataset = torch.utils.data.Subset(dataset, indices)
    dataloader = DataLoader(shuffled_dataset, batch_size=batch_size, shuffle=False,
                            follow_batch=['x_c', 'x_v', 'candidates'])

    batch_idx = 0
    # tau = 0.005
    total_actor_loss, total_critic_loss = 0, 0
    mean_entropy = []
    mean_kl_div = []
    mean_advantages = []

    for batch_i, SASR in enumerate(dataloader):
        graph, action, reward, done, batch_old_log_probs = SASR
        # Skip a degenerate final mini-batch: the advantage normalisation below needs
        # at least two samples to compute a standard deviation.
        if len(action) <= 1:
            continue
        batch_size_i = len(action)
        batch_start = batch_idx
        batch_end = batch_start + batch_size_i

        actor_opt.zero_grad()
        critic_opt.zero_grad()

        # Formatting for indexing in batched data; assumes every graph in the batch
        # has the same number of variable nodes.
        num_nodes_per_graph = graph.num_graphs * [int(graph.x_v.shape[0] / graph.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()

        # Split candidates per graph and shift them back to per-graph node indices.
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

        # Estimate the policy gradient.
        logits = actor(graph.x_c, graph.edge_index, graph.edge_features, graph.x_v)
        logits_per_graph = torch.split(logits, num_nodes_per_graph)

        log_probs = []
        entropy_arr = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]

            # Mask out non-candidate variables so only valid actions get probability mass.
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True

            # Action distribution restricted to the candidate set.
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)
            entropy_arr.append(action_dist.entropy())
        log_probs = torch.stack(log_probs)

        # Entropy of the policy over this mini-batch (stored as a sum for later averaging).
        ent = torch.stack(entropy_arr).mean()
        mean_entropy.append(ent.detach().item() * len(entropy_arr))

        # Approximate KL divergence between the behaviour policy and the updated policy.
        approx_kl = (torch.exp(batch_old_log_probs) * (batch_old_log_probs - log_probs.detach())).sum().item()
        mean_kl_div.append(approx_kl)

        batch_returns = returns[shuffled_dataset.indices[batch_start:batch_end]]

        eps = 0.3
        # PPO-style clipped surrogate (currently disabled):
        # ratios = torch.exp(log_probs - batch_old_log_probs)
        # surr1 = ratios * batch_advantages
        # surr2 = torch.clamp(ratios, 1 - eps, 1 + eps) * batch_advantages
        # actor_loss = -torch.min(surr1, surr2).mean() - decayed_learning_rate(step) * ent
        # actor_loss = -torch.mean(log_probs * advantages[batch_start:batch_end])
        # actor_loss = -torch.sum(log_probs * advantages[shuffled_dataset.indices[batch_start:batch_end]])

        # State values of the mini-batch, used as a baseline.
        values = critic_main(
            graph.x_c, graph.edge_index, graph.edge_features, graph.x_v,
            graph.x_c_batch, graph.x_v_batch, graph.candidates
        )
        batch_advantages = batch_returns - values.detach()
        batch_advantages = (batch_advantages - batch_advantages.mean()) / (1e-8 + batch_advantages.std())
        mean_advantages.append(batch_advantages.sum().item())

        actor_loss = -torch.mean(log_probs * batch_advantages)
        # critic_loss = F.mse_loss(values, returns[batch_start:batch_end])
        critic_loss = F.mse_loss(values, batch_returns)

        actor_loss.backward()
        critic_loss.backward()
        actor_opt.step()
        critic_opt.step()

        total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

        # Soft update of the target critic (currently disabled):
        # for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
        #     target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)

        batch_idx += batch_size_i

    return (total_actor_loss / len(dataloader), total_critic_loss / len(dataloader),
            np.sum(mean_advantages) / len(transitions),
            sum(mean_entropy) / len(transitions), sum(mean_kl_div) / len(transitions))


def off_PAC(actor, critic, target_critic, actor_optimizer, critic_optimizer, buffer, tau=0.01):
    batch = Transition(*zip(*buffer.memory))
    dataset = GraphDataset(*batch)
    dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False,
                            follow_batch=['constraint_features', 'variable_features', 'candidates'])

    # A single batch covering the whole buffer.
    replay = next(iter(dataloader))
    states, log_probs, rewards, dones = replay

    with torch.no_grad():
        target_values = target_critic(
            states.constraint_features,
            states.edge_index,
            states.edge_attr,
            states.variable_features,
            states.variable_features_batch,
        )

    # Compute advantages
    advantages, returns_to_go = compute_gae(rewards, target_values.detach(), dones)
    advantages = (advantages - advantages.mean()) / (1e-8 + advantages.std())
    # returns_to_go = (returns_to_go - returns_to_go.mean()) / (1e-8 + returns_to_go.std())

    # Formatting for indexing in batched data; assumes every graph in the batch
    # has the same number of variable nodes.
    num_nodes_per_graph = states.num_graphs * [int(states.variable_features.shape[0] / states.num_graphs)]
    node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()

    # Split candidates per graph and shift them back to per-graph node indices.
    candidate_counts = states.nb_candidates.tolist()
    candidates_per_graph = []
    for i, num_candidates in enumerate(candidate_counts):
        start = sum(candidate_counts[:i])
        end = start + num_candidates
        candidates_i = states.candidates[start:end] - node_offsets[i]
        candidates_per_graph.append(candidates_i)

    learn_logits = actor(
        states.constraint_features,
        states.edge_index,
        states.edge_attr,
        states.variable_features,
    )
    learn_logits = torch.stack(torch.split(learn_logits, num_nodes_per_graph))

    # Log-probabilities of the stored actions under the current (learner) policy.
    learn_log_probs = []
    for branch in range(len(candidates_per_graph)):
        learn_logits_i = learn_logits[branch, :]
        candidates_i = candidates_per_graph[branch]
        learn_logits_i = learn_logits_i[candidates_i]
        learn_dist = D.Categorical(logits=learn_logits_i)
        action_i = states.candidate_choices[branch]
        learn_log_probs.append(learn_dist.log_prob(action_i))
    learn_log_probs = torch.stack(learn_log_probs)

    # Compute importance sampling weights and the clipped surrogate actor loss.
    rho = torch.exp(learn_log_probs - log_probs)
    eps_clip = 0.25
    surr1 = rho * advantages
    surr2 = torch.clamp(rho, 1 - eps_clip, 1 + eps_clip) * advantages
    actor_loss = -torch.min(surr1, surr2).mean()

    # Compute critic loss
    values = critic(
        states.constraint_features,
        states.edge_index,
        states.edge_attr,
        states.variable_features,
        states.variable_features_batch,
    )
    # critic_loss = F.mse_loss(values, returns_to_go)
    critic_loss = F.huber_loss(values, returns_to_go)

    # Compute actor loss
    # actor_loss = -torch.mean(rho * learn_log_probs * advantages)

    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()

    # Update actor and critic jointly
    actor_loss.backward()
    critic_loss.backward()
    # torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=1.0)
    actor_optimizer.step()
    critic_optimizer.step()

    # Soft (Polyak) update of the target critic.
    for target_param, param in zip(target_critic.parameters(), critic.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

    return actor_loss.item(), critic_loss.item()
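

# Illustrative sketch, not part of the training code: the Polyak averaging step at the
# end of off_PAC applied to two throwaway linear layers, so the effect of tau can be
# checked in isolation. The helper name below is ours.
def _example_soft_target_update(tau=0.01):
    critic = torch.nn.Linear(4, 1)
    target_critic = torch.nn.Linear(4, 1)
    before = [p.detach().clone() for p in target_critic.parameters()]
    for target_param, param in zip(target_critic.parameters(), critic.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)
    # Each target parameter has moved a fraction tau of the way toward the online critic.
    for b, t, p in zip(before, target_critic.parameters(), critic.parameters()):
        assert torch.allclose(t.detach(), tau * p.detach() + (1.0 - tau) * b)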