import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.reward
import ecole.typing
import hydra
from omegaconf import DictConfig
from data import BipartiteDatasetRL
from torch_geometric.loader import DataLoader
from model import GNNActor, GNNCritic
import random
import torch
import torch.nn.functional as F
import numpy as np
import os
from replay import Replay, Transition
from torchrl.modules import MaskedCategorical
from utils import create_folder, to_state, to_tensor
import matplotlib.pyplot as plt


def compute_discounted_returns(rewards, gamma):
    """Discounted return-to-go for every timestep of a single episode."""
    rewards = rewards.float()
    discounts = gamma ** torch.arange(len(rewards)).float()
    discounted_rewards = rewards * discounts
    # reversed cumulative sum gives the suffix sums of the discounted rewards
    discounted_returns = torch.flip(torch.cumsum(torch.flip(discounted_rewards, [0]), dim=0), [0])
    return discounted_returns / discounts


## DOES NOT TRAIN WELL - needs fixing ##
def train_actor_critic(cfg: DictConfig,
                       actor: torch.nn.Module,
                       critic: torch.nn.Module,
                       actor_opt: torch.optim.Optimizer,
                       critic_opt: torch.optim.Optimizer,
                       replay_buffer):
    transitions = replay_buffer.memory
    nt = len(transitions)
    batch_size = 128

    # Gather transition information into tensors
    batch = Transition(*zip(*transitions))
    discounted_return_batch = compute_discounted_returns(torch.stack(batch.reward), cfg.training.gamma)

    dataset = BipartiteDatasetRL(batch.state, discounted_return_batch, batch.action_set, batch.action)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    actor.train()
    critic.train()
    total_actor_loss = 0
    total_critic_loss = 0
    for sub_batch in dataloader:
        # Need to fix rewards
        graph, discounted_return, action = sub_batch

        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )

        # batch_indices = graph.batch  # shape [num_nodes_in_batch], indicates graph indices
        # num_nodes_per_graph = torch.bincount(batch_indices).tolist()
        # All transitions come from the same instance, so every graph in the batch
        # has the same number of variable nodes.
        num_nodes_per_graph = graph.num_graphs*[int(graph.x_v.shape[0]/graph.num_graphs)]
        logits_per_graph = torch.split(logits, num_nodes_per_graph)

        # Node offsets to adjust indices
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()

        # Split candidates per graph and adjust indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

        log_probs = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]

            # Create mask for valid actions
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True

            # Create action distribution
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)

        log_probs = torch.stack(log_probs)

        # state_values = critic(
        #     graph.x_c,
        #     graph.edge_index,
        #     graph.edge_features,
        #     graph.x_v
        # )
        # state_values_per_graph = torch.split(state_values, num_nodes_per_graph)
        # outputs = []
        # for i in range(len(state_values_per_graph)):
        #     state_values_i = state_values_per_graph[i]
        #     candidates_i = candidates_per_graph[i]
        #     candidate_state_values_i = state_values_i[candidates_i]
        #     output_i = candidate_state_values_i.mean()
        #     outputs.append(output_i)
        # output = torch.stack(outputs)

        # advantage = discounted_return - output.detach()
        advantage = discounted_return
        if advantage.numel() != 1:
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        # entropy = -torch.sum(torch.exp(log_probs)*log_probs)
        actor_loss = -torch.mean(log_probs * advantage)  # - 0.01*entropy
        actor_opt.zero_grad()
        actor_loss.backward()
        actor_opt.step()

        # critic_loss = torch.nn.functional.mse_loss(output, discounted_return)
        # critic_opt.zero_grad()
        # critic_loss.backward()
        # critic_opt.step()

        # total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

    return total_actor_loss/len(dataloader), 0.0
    # return total_actor_loss/len(dataloader), total_critic_loss/len(dataloader)


def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over a single episode.

    `values` must contain one extra bootstrap entry for the state after the last
    transition. The return-to-go recursion assumes the buffer holds one episode.
    """
    advantages = torch.zeros_like(rewards)
    returns_to_go = torch.zeros_like(rewards)
    last_gae = 0  # GAE for the final timestep in an episode
    last_return = 0  # return
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]  # td-error
        advantages[t] = last_gae = delta + gamma * lam * (1 - dones[t]) * last_gae  # gae advantage
        returns_to_go[t] = last_return = rewards[t] + gamma*last_return
    return advantages, returns_to_go
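

# Illustrative only (added sketch, not called anywhere in this script): a tiny
# hand-checked episode for compute_discounted_returns and compute_gae. With
# gamma = lam = 1 and a zero terminal bootstrap value, GAE reduces to the
# Monte-Carlo return-to-go minus the value baseline.
def _gae_sanity_check():
    rewards = torch.tensor([1.0, 1.0])
    dones = torch.tensor([0.0, 1.0])
    values = torch.tensor([0.5, 0.5, 0.0])  # one bootstrap entry more than rewards

    returns = compute_discounted_returns(rewards, gamma=1.0)
    advantages, returns_to_go = compute_gae(rewards, values, dones, gamma=1.0, lam=1.0)

    assert torch.allclose(returns, torch.tensor([2.0, 1.0]))
    assert torch.allclose(returns_to_go, returns)
    assert torch.allclose(advantages, returns_to_go - values[:-1])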


def train_GAE(cfg: DictConfig,
              actor: torch.nn.Module,
              critic_main: torch.nn.Module,
              critic_target: torch.nn.Module,
              actor_opt: torch.optim.Optimizer,
              critic_opt: torch.optim.Optimizer,
              replay_buffer):
    actor.train()
    critic_main.train()

    transitions = replay_buffer.memory
    batch_size1 = len(transitions)
    batch_size2 = 16
    batch = Transition(*zip(*transitions))

    dataset = BipartiteDatasetRL(*batch)
    dataloader1 = DataLoader(dataset, batch_size=batch_size1, shuffle=False,
                             follow_batch=['x_c', 'x_v', 'candidates'])

    # First pass: a single batch containing the whole episode, used to compute GAE targets
    for SASR in dataloader1:
        graph, action, reward, done = SASR
        with torch.no_grad():
            # get state values of the batch from the target critic
            values = critic_target(
                graph.x_c,
                graph.edge_index,
                graph.edge_features,
                graph.x_v,
                graph.x_c_batch,
                graph.x_v_batch,
                graph.candidates
            )
        # append a zero bootstrap value for the terminal state
        values = torch.cat((values, torch.tensor([0])), dim=0)
        advantages, returns = compute_gae(reward, values, done)
        if len(advantages) <= 1:
            return torch.nan, torch.nan
        # advantages = (advantages - advantages.mean()) / (1e-8 + advantages.std())

    # Second pass: mini-batch updates of the actor and the main critic
    dataloader2 = DataLoader(dataset, batch_size=batch_size2, shuffle=False,
                             follow_batch=['x_c', 'x_v', 'candidates'])
    batch_idx = 0
    tau = 0.005
    total_actor_loss, total_critic_loss = 0, 0
    for SASR in dataloader2:
        graph, action, reward, done = SASR
        batch_size_i = len(action)
        batch_start = batch_idx
        batch_end = batch_start + batch_size_i

        actor_opt.zero_grad()
        critic_opt.zero_grad()

        # formatting for indexing in batched data; all transitions come from the same
        # instance, so every graph has the same number of variable nodes
        num_nodes_per_graph = graph.num_graphs*[int(graph.x_v.shape[0]/graph.num_graphs)]
        node_offsets = [0] + np.cumsum(num_nodes_per_graph[:-1]).tolist()

        # Split candidates per graph and adjust indices
        candidate_counts = graph.nb_candidates.tolist()
        candidates_per_graph = []
        for i, num_candidates in enumerate(candidate_counts):
            start = sum(candidate_counts[:i])
            end = start + num_candidates
            candidates_i = graph.candidates[start:end] - node_offsets[i]
            candidates_per_graph.append(candidates_i)

        # Estimate policy gradient
        logits = actor(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v
        )
        logits_per_graph = torch.split(logits, num_nodes_per_graph)
        log_probs = []
        for i in range(len(logits_per_graph)):
            logits_i = logits_per_graph[i]
            candidates_i = candidates_per_graph[i]
            action_i = action[i]

            # Create mask for valid actions
            mask_i = torch.zeros_like(logits_i, dtype=torch.bool)
            mask_i[candidates_i] = True

            # Create action distribution
            action_dist = MaskedCategorical(logits=logits_i, mask=mask_i)
            log_prob_i = action_dist.log_prob(action_i)
            log_probs.append(log_prob_i)

        log_probs = torch.stack(log_probs)
        actor_loss = -torch.sum(log_probs * advantages[batch_start:batch_end])

        # get state values of the batch from the main critic
        values = critic_main(
            graph.x_c,
            graph.edge_index,
            graph.edge_features,
            graph.x_v,
            graph.x_c_batch,
            graph.x_v_batch,
            graph.candidates
        )
        critic_loss = torch.nn.functional.huber_loss(values, returns[batch_start:batch_end])

        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=1.0)
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(critic_main.parameters(), max_norm=1.0)
        actor_opt.step()
        critic_opt.step()

        total_critic_loss += critic_loss.item()
        total_actor_loss += actor_loss.item()

        # Polyak-average the target critic toward the main critic
        for target_param, main_param in zip(critic_target.parameters(), critic_main.parameters()):
            target_param.data.copy_(tau * main_param.data + (1 - tau) * target_param.data)

        batch_idx += batch_size_i

    return total_actor_loss/len(dataloader2), total_critic_loss/len(dataloader2)
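

# Illustrative only (added sketch, not used by the training code): how the candidate
# mask restricts the branching policy to valid variables. With uniform logits and two
# candidates, each candidate gets probability 0.5, so its log-probability is ln(0.5).
# Assumes torchrl's MaskedCategorical renormalizes over the unmasked entries.
def _masked_policy_example():
    logits = torch.zeros(5)            # one score per variable node
    candidates = torch.tensor([1, 3])  # branching candidates for this node
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask[candidates] = True

    dist = MaskedCategorical(logits=logits, mask=mask)
    action = dist.sample()
    assert action.item() in {1, 3}     # only candidates can be sampled
    # log-probability of a candidate under the uniform masked distribution
    assert torch.isclose(dist.log_prob(torch.tensor(1)), torch.log(torch.tensor(0.5)), atol=1e-4)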


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # print(OmegaConf.to_yaml(cfg))
    torch.manual_seed(cfg.training.seed)
    random.seed(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    generator = torch.Generator()
    generator.manual_seed(cfg.training.seed)

    # define problem instance generator
    if cfg.problem == "set_cover":
        gen = ecole.instance.SetCoverGenerator(
            n_rows=cfg.n_rows,
            n_cols=cfg.n_cols,
            density=cfg.density,
        )
        gen.seed(cfg.training.seed)

    # define the observation function
    observation_functions = (
        ecole.observation.NodeBipartite()
    )

    # scip parameters used in the paper
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 900,  # 15 minute time limit
        "limits/nodes": 1000,
        "lp/threads": 4
    }

    # define the environment
    env = ecole.environment.Branching(
        observation_function=observation_functions,
        reward_function=-1*ecole.reward.DualIntegral(),
        scip_params=scip_parameters,
        information_function=None
    )
    env.seed(cfg.training.seed)  # seed environment

    # define actor and critic networks
    actor = GNNActor()
    critic_main = GNNCritic()
    critic_target = GNNCritic()

    def initialize_weights(m):
        if type(m) == torch.nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    actor.apply(initialize_weights)
    critic_main.apply(initialize_weights)
    critic_target.apply(initialize_weights)

    # define the optimizers
    actor_opt = torch.optim.Adam(actor.parameters(), lr=cfg.training.lr)
    critic_opt = torch.optim.Adam(critic_main.parameters(), lr=cfg.training.lr)

    fig, (ax1, ax2) = plt.subplots(1, 2)
    counts = []
    actor_loss_arr = []
    critic_loss_arr = []
    for episode_i in range(cfg.training.n_episodes):
        instance = next(gen)

        actor_loss_epoch, critic_loss_epoch = 0, 0
        valid_epochs = 5
        for rollout_i in range(5):
            observation, action_set, _, done, info = env.reset(instance)
            if done:
                break

            replay_buffer = Replay(cfg.max_buffer_size)
            expected_return = 0
            while not done:
                # Convert the numpy ecole observation to tensors
                state_tensor = to_state(observation, cfg.device)
                action_set_tensor = torch.tensor(action_set, dtype=torch.int32)

                m = state_tensor[0].shape[0]   # number of constraint nodes
                n = state_tensor[-1].shape[0]  # number of variable nodes
                complexity = (n + m)/max(n, m)

                # pass the state to the actor and build the valid-action mask
                logits = actor(*state_tensor)
                mask = torch.zeros_like(logits, dtype=torch.bool)
                mask[action_set] = True

                # sample from the action distribution
                action_dist = MaskedCategorical(logits=logits, mask=mask)
                action = action_dist.sample()

                # take the action and move to the next state
                next_observation, next_action_set, reward, done, _ = env.step(action.item())
                reward = reward / complexity  # normalize reward by a simple problem-size factor
                expected_return += reward
                reward_tensor = to_tensor(reward)
                done_tensor = torch.tensor(done, dtype=torch.int32)

                # record in replay buffer
                replay_buffer.push(
                    state_tensor,       # current state
                    action,             # action taken (stored as a tensor)
                    reward_tensor,      # reward received
                    action_set_tensor,  # action set
                    done_tensor         # marks when the episode is finished
                )

                if done and len(replay_buffer) == 1:
                    valid_epochs -= 1

                # Update current observation
                observation = next_observation
                action_set = next_action_set

            if done and len(replay_buffer) > 1:
                # Train actor and critic networks using the replay buffer
                actor_loss, critic_loss = train_GAE(cfg, actor, critic_main, critic_target,
                                                    actor_opt, critic_opt, replay_buffer)

                print(f"episode: {episode_i}, actor loss: {actor_loss:>.4f}, "
                      f"critic loss: {critic_loss:>.4f}, return: {expected_return}")
                actor_loss_arr.append(actor_loss)
                critic_loss_arr.append(critic_loss)

                ax1.clear()
                ax2.clear()
                ax1.plot(actor_loss_arr, label='Actor Loss')
                ax1.legend()
                ax1.grid()
                ax2.plot(critic_loss_arr, label='Critic Loss')
                ax2.legend()
                ax2.grid()
                plt.pause(0.5)

    plt.show()

    path = os.path.join(os.getcwd(), "models")
    name = "learn2branch-set_cover-actor_critic-gae.pt"
    create_folder(path)
    torch.save(actor.state_dict(), os.path.join(path, name))
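

# Illustrative only (added sketch, hypothetical helper that nothing in this script calls):
# one way the saved actor checkpoint could be reloaded and used greedily at evaluation
# time, picking the highest-scoring branching candidate instead of sampling. Assumes an
# ecole Branching environment and the same `to_state` preprocessing used in main().
def _evaluate_greedy(actor_path, env, instance, device="cpu"):
    actor = GNNActor()
    actor.load_state_dict(torch.load(actor_path, map_location=device))
    actor.eval()

    observation, action_set, _, done, _ = env.reset(instance)
    with torch.no_grad():
        while not done:
            state_tensor = to_state(observation, device)
            logits = actor(*state_tensor)
            # restrict the scores to the current branching candidates
            candidate_logits = logits[action_set]
            best_action = int(action_set[int(torch.argmax(candidate_logits))])
            observation, action_set, _, done, _ = env.step(best_action)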


if __name__ == "__main__":
    main()