"""Pre-train an actor-critic branching policy off-policy: an expert policy generates
transitions, which are stored in a replay buffer and used for off_PAC updates."""

import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.reward
import ecole.typing
import matplotlib.pyplot as plt
import torch
import numpy as np
from torchrl.modules import MaskedCategorical
from torch.distributions import Categorical

import sys
sys.path.append('.')

from common.environments import Branching as Environment
from common.rewards import TimeLimitPrimalIntegral as BoundIntegral
from submissions.dual.model import GNNPolicy as Actor
from submissions.dual.model import GNNCritic as Critic
from submissions.dual.model import PreNormLayer, BipartiteGraphConvolution
from algorithms import off_PAC
from submissions.dual.training.utils import ReplayBuffer, Transition
# from baseline.dual.train_files.utilities import BipartiteNodeData
from submissions.dual.training.utils import BipartiteNodeData

import random
import neptune
import argparse
import pathlib
import json
import os


def format_observation(sample_observation, sample_action, sample_action_set, record_action=True):
    """Convert an ecole NodeBipartite observation into a BipartiteNodeData graph."""
    constraint_features = sample_observation.row_features
    variable_features = sample_observation.variable_features
    edge_indices = sample_observation.edge_features.indices
    edge_features = sample_observation.edge_features.values

    # drop variable-feature columns 13 and 14 (higher index deleted first so positions stay valid)
    variable_features = np.delete(variable_features, 14, axis=1)
    variable_features = np.delete(variable_features, 13, axis=1)

    constraint_features = torch.FloatTensor(constraint_features)
    edge_indices = torch.LongTensor(edge_indices.astype(np.int32))
    edge_features = torch.FloatTensor(np.expand_dims(edge_features, axis=-1))
    variable_features = torch.FloatTensor(variable_features)
    candidates = torch.LongTensor(np.array(sample_action_set, dtype=np.int32))

    if record_action:
        candidate_choice = torch.where(candidates == sample_action)[0][0]  # action index relative to candidates
    else:
        candidate_choice = torch.tensor(-1, dtype=torch.long)

    graph = BipartiteNodeData(constraint_features, edge_indices, edge_features,
                              variable_features, candidates, candidate_choice)
    graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]
    return graph


def main() -> None:
    parser = argparse.ArgumentParser()

    debug = False

    path = os.path.join(os.getcwd(), "submissions/dual/agents/trained_models/set_cover")
    name_actor = "actor-off_policy-pre_train-nodes.pt"
    name_critic = "critic-off_policy-pre_train-nodes.pt"
    device = torch.device('cpu')

    if not debug:
        run = neptune.init_run(
            project="571-project/learn-to-branch",
        )

    args = parser.parse_args()

    time_limit = 15 * 60
    memory_limit = 8796093022207  # maximum

    args.problem = 'set_cover'
    args.task = 'dual'
    args.folder = 'train'
    args.expert = 'GCN'

    seed = 3
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if args.problem == 'set_cover':
        expert_file = 'submissions/dual/agents/trained_models/set_cover/best_params.pkl'
        instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=500, density=0.05)
        instances.seed(seed)
    elif args.problem == 'item_placement':
        expert_file = 'submissions/dual/agents/trained_models/item_placement/best_params.pkl'
        instances_path = pathlib.Path("instances/1_item_placement/train/")
        instances = list(instances_path.glob('*.mps.gz'))
        # instances = ecole.instance.CombinatorialAuctionGenerator(n_items=100, n_bids=250)
        # instances.seed(seed)
    elif args.problem == 'capacitated_facility_location':
        expert_file = 'submissions/dual/agents/trained_models/capacitated_facility_location/best_params.pkl'
        instances = ecole.instance.CapacitatedFacilityLocationGenerator(n_customers=100, n_facilities=50)
        instances.seed(seed)

    # define actor, critic, and target critic networks and load pre-trained weights
    actor = Actor().to(device)
    critic = Critic().to(device)
    target_critic = Critic().to(device)

    if args.expert == 'GCN':
        print()
        print("Loading expert policy...")

        # Load expert
        from submissions.dual.agents.baseline import Policy
        expert = Policy(args.problem)
        observation_function = ecole.observation.NodeBipartite()

        print()
        print("Loading actor parameters...")
        actor.load_state_dict(torch.load(expert_file, weights_only=True))

        # freeze graph layers
        for name, module in actor.named_modules():
            if isinstance(module, BipartiteGraphConvolution):
                print(f"Freezing parameters of BipartiteGraphConvolution in module: {name}")
                for param in module.parameters():
                    param.requires_grad = False

        print()
        print("Loading critic parameters...")

        # Pre-load expert weights into the critic (only parameters with matching names and shapes)
        gcnn_pretrained_state_dict = torch.load(expert_file, map_location="cpu")
        critic_state_dict = critic.state_dict()
        filtered_state_dict = {}
        for name, param in gcnn_pretrained_state_dict.items():
            if name in critic_state_dict and param.shape == critic_state_dict[name].shape:
                filtered_state_dict[name] = param
        critic_state_dict.update(filtered_state_dict)
        critic.load_state_dict(critic_state_dict)

        # target critic
        target_critic_state_dict = target_critic.state_dict()
        filtered_state_dict_target = {}
        for name, param in gcnn_pretrained_state_dict.items():
            if name in target_critic_state_dict and param.shape == target_critic_state_dict[name].shape:
                filtered_state_dict_target[name] = param
        target_critic_state_dict.update(filtered_state_dict_target)
        target_critic.load_state_dict(target_critic_state_dict)

        # freeze graph layers
        for name, module in critic.named_modules():
            if isinstance(module, BipartiteGraphConvolution):
                print(f"Freezing parameters of BipartiteGraphConvolution in module: {name}")
                for param in module.parameters():
                    param.requires_grad = False

        for name, module in target_critic.named_modules():
            if isinstance(module, BipartiteGraphConvolution):
                print(f"Freezing parameters of BipartiteGraphConvolution in module: {name}")
                for param in module.parameters():
                    param.requires_grad = False

    elif args.expert == 'strong_branch':
        from submissions.dual.agents.strong_branch import Policy
        expert = Policy(args.problem)
        observation_function = (
            ecole.observation.NodeBipartite(),
            ecole.observation.StrongBranchingScores()
        )
    else:
        raise NotImplementedError

    # sanity check: the critic and target critic must not share parameter storage
    for p1, p2 in zip(critic.parameters(), target_critic.parameters()):
        assert (p1.storage().data_ptr() != p2.storage().data_ptr()), "Weights are shared!"
    # define optimizers for actor and critic (only over parameters that require gradients)
    actor_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, actor.parameters()), lr=1e-4, amsgrad=True)
    critic_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, critic.parameters()), lr=1e-4, amsgrad=True)

    for episode, instance in enumerate(instances):

        if args.problem == 'item_placement':
            with open(instance.with_name(instance.stem).with_suffix('.json')) as f:
                instance_info = json.load(f)
            # set up the reward function parameters for that instance
            initial_primal_bound = instance_info["primal_bound"]
            initial_dual_bound = instance_info["dual_bound"]
        else:
            # set up the reward function parameters for that instance
            initial_primal_bound = instance.primal_bound
            initial_dual_bound = instance.dual_bound
        objective_offset = 0

        # integral_function = BoundIntegral()
        integral_function = ecole.reward.NNodes()

        env = Environment(
            time_limit=time_limit,
            observation_function=observation_function,
            reward_function=-integral_function,  # negated node count (minimization)
            scip_params={'limits/memory': memory_limit},
        )
        env.seed(seed)

        # integral_function.set_parameters(
        #     initial_primal_bound=initial_primal_bound,
        #     initial_dual_bound=initial_dual_bound,
        #     objective_offset=objective_offset)

        timestep_buffer = ReplayBuffer(capacity=int(1e3))

        # reset the environment
        if args.problem == 'item_placement':
            observation, action_set, reward, done, info = env.reset(str(instance), objective_limit=initial_primal_bound)
        else:
            observation, action_set, reward, done, info = env.reset(instance, objective_limit=initial_primal_bound)
        cumulated_reward = 0  # discard initial reward

        # loop over the environment: let the expert act and record its transitions
        while not done:
            action, log_prob = expert(action_set, observation)
            next_observation, next_action_set, reward, done, info = env.step(action.item())
            cumulated_reward += reward

            state = format_observation(observation, action, action_set)
            experience = (
                state,
                torch.tensor(log_prob, dtype=torch.float32),
                torch.tensor(reward, dtype=torch.float32),
                torch.tensor(done, dtype=torch.long),
            )
            timestep_buffer.push(*experience)

            observation = next_observation
            action_set = next_action_set

        if len(timestep_buffer) <= 1:
            print("episode was too short, skipping...")
            continue

        actor_loss, critic_loss = off_PAC(actor, critic, target_critic,
                                          actor_optimizer, critic_optimizer, timestep_buffer)

        run["Loss/Actor"].append(actor_loss)
        run["Loss/Critic"].append(critic_loss)
        print(f"Run: {episode}, Actor loss: {actor_loss}, Critic loss: {critic_loss}")

        if episode % 10 == 0:
            torch.save(actor.state_dict(), os.path.join(path, name_actor))
            torch.save(critic.state_dict(), os.path.join(path, name_critic))
            print("saving...")

        if episode == 99:
            print('Finished training')
            torch.save(actor.state_dict(), os.path.join(path, name_actor))
            torch.save(critic.state_dict(), os.path.join(path, name_critic))
            print("saving...")
            break


if __name__ == "__main__":
    main()
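
# ---------------------------------------------------------------------------
# Usage note: a minimal, hedged sketch of reloading a saved actor checkpoint
# for greedy evaluation. It is kept commented out because it is not part of
# the training run, and because the forward signature of GNNPolicy and the
# attribute names of BipartiteNodeData (constraint_features, edge_index,
# edge_attr, variable_features) are assumptions that may differ in this
# repository; adjust them to match the actual model definitions.
#
#   actor = Actor()
#   actor.load_state_dict(torch.load(
#       "submissions/dual/agents/trained_models/set_cover/actor-off_policy-pre_train-nodes.pt",
#       map_location="cpu"))
#   actor.eval()
#   with torch.no_grad():
#       graph = format_observation(observation, sample_action=None,
#                                  sample_action_set=action_set, record_action=False)
#       logits = actor(graph.constraint_features, graph.edge_index,
#                      graph.edge_attr, graph.variable_features)
#       action = action_set[logits[action_set].argmax()]
# ---------------------------------------------------------------------------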