"""Train a GNN branching policy with REINFORCE in Ecole/SCIP environments, logging metrics to Neptune."""

import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.reward
import ecole.typing
import matplotlib.pyplot as plt
import torch
import numpy as np
from torchrl.modules import MaskedCategorical
from torch.distributions import Categorical
from replay import Replay, Transition
from environments import Branching as Environment
from rewards import TimeLimitDualIntegral as BoundIntegral
from agents.model import GNNActor, GNNCritic
from agents.dual import ObservationFunction
from algorithms import REINFORCE, GAE
import pathlib
import json
import os
import random
import time
import neptune
import argparse


def main() -> None:
    parser = argparse.ArgumentParser()
    debug = False
    device = torch.device('cpu')

    if not debug:
        run = neptune.init_run(
            project="571-project/learn-to-branch",
        )

    args = parser.parse_args()

    time_limit = 5 * 60
    memory_limit = 8 * 1024  # 8GB
    replay_size = int(1e3)
    num_epochs = 100

    args.problem = 'set_cover'
    args.task = 'dual'
    args.folder = 'train'

    if args.problem == 'item_placement':
        instances_path = pathlib.Path(f"../../instances/1_item_placement/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/1_item_placement.csv")
    elif args.problem == 'load_balancing':
        instances_path = pathlib.Path(f"../../instances/2_load_balancing/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/2_load_balancing.csv")
    elif args.problem == 'anonymous':
        instances_path = pathlib.Path(f"../../instances/3_anonymous/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/3_anonymous.csv")
    else:
        # NOTE: no instance folder is configured for other problems (including the
        # hard-coded 'set_cover'); fail explicitly instead of hitting a NameError below.
        raise ValueError(f"No instance folder configured for problem '{args.problem}'")

    print(f"Processing instances from {instances_path.resolve()}")
    instance_files = list(instances_path.glob('*.mps.gz'))

    # define actor and critic networks
    actor = GNNActor().to(device)
    critic_main = GNNCritic().to(device)

    for seed, instance in enumerate(instance_files):
        observation_function = ecole.observation.NodeBipartite()
        integral_function = BoundIntegral()

        env = Environment(
            time_limit=time_limit,
            observation_function=observation_function,
            reward_function=-integral_function,  # negated integral (minimization)
            scip_params={'limits/memory': memory_limit},
        )

        # seed everything for reproducibility
        env.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        actor_opt = torch.optim.Adam(actor.parameters(), lr=1e-4)
        critic_opt = torch.optim.Adam(critic_main.parameters(), lr=1e-4, weight_decay=0.01)

        # Set up the reward function parameters for that instance.
        # This assumes each <instance>.mps.gz ships with an <instance>.json metadata
        # file containing 'primal_bound' and 'dual_bound' entries.
        with open(instance.with_name(instance.stem).with_suffix('.json')) as f:
            instance_info = json.load(f)
        initial_primal_bound = instance_info['primal_bound']
        initial_dual_bound = instance_info['dual_bound']
        objective_offset = 0

        integral_function.set_parameters(
            initial_primal_bound=initial_primal_bound,
            initial_dual_bound=initial_dual_bound,
            objective_offset=objective_offset)

        print()
        print(f"Instance {instance.name}")
        print(f" seed: {seed}")
        print(f" initial primal bound: {initial_primal_bound}")
        print(f" initial dual bound: {initial_dual_bound}")
        print(f" objective offset: {objective_offset}")

        for epoch_i in range(num_epochs):
            # reset the environment
            observation, action_set, reward, done, info = env.reset(
                str(instance), objective_limit=initial_primal_bound)
            cumulated_reward = 0  # discard the initial reward
            replay_buffer = Replay(replay_size)

            # loop over the environment
            while not done:
                # drop two variable-feature columns (delete the higher index first
                # so the second index is still valid after the first deletion)
                variable_features = observation.variable_features
                variable_features = np.delete(variable_features, 14, axis=1)
                variable_features = np.delete(variable_features, 13, axis=1)

                observation_tuple = (
                    torch.from_numpy(observation.row_features.astype(np.float32)).to(device),
                    torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(device),
                    torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1).to(device),
                    torch.from_numpy(variable_features.astype(np.float32)).to(device),
                )
                action_set_tensor = torch.tensor(action_set, dtype=torch.int32, device=device)

                # pass the state to the actor, mask invalid actions, and build the policy distribution
                with torch.no_grad():
                    logits = actor(*observation_tuple)
                    mask = torch.zeros_like(logits, dtype=torch.bool)
                    mask[action_set] = True
                    if mask.all():
                        action_dist = Categorical(logits=logits)
                    else:
                        action_dist = MaskedCategorical(logits=logits, mask=mask)

                    # sample from the action distribution
                    action = action_dist.sample()
                    old_log_probs = action_dist.log_prob(action).detach()

                observation, action_set, reward, done, info = env.step(action.item())
                cumulated_reward += reward

                done_tensor = torch.tensor(done, dtype=torch.int32, device=device)
                reward_tensor = torch.tensor(reward, dtype=torch.float32, device=device)

                replay_buffer.push(
                    observation_tuple,   # current state
                    action,              # action taken (stored as a tensor)
                    reward_tensor,       # reward received
                    action_set_tensor,   # action set
                    done_tensor,         # marks when the episode is finished
                    old_log_probs,
                )

            if len(replay_buffer) > 1:
                actor_loss, critic_loss, advantages, entropy_values, kl_div = REINFORCE(
                    actor, critic_main, None, actor_opt, critic_opt, replay_buffer, epoch_i)

                if not debug:
                    run["Loss/Actor"].append(actor_loss)
                    run["Loss/Critic"].append(critic_loss)
                    run["Info/KL_divergence"].append(kl_div)
                    run["Info/Entropy"].append(entropy_values)
                    run["Info/Advantages"].append(advantages)

                print(f"seed: {seed}, actor loss: {actor_loss:>.4f}, critic loss: {critic_loss:>.4f}, return: {cumulated_reward}")

            # actor_scheduler.step()
            # critic_scheduler.step()

            # checkpoint the actor after every epoch
            path = os.path.join(os.getcwd(), "models")
            os.makedirs(path, exist_ok=True)
            name = "learn2branch-set_cover-250-250-actor_critic-gae.pt"
            torch.save(actor.state_dict(), os.path.join(path, name))

    if not debug:
        run.stop()


if __name__ == "__main__":
    main()