import argparse
import json
import os
import pathlib
import random
import time

import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.reward
import ecole.typing
import matplotlib.pyplot as plt
import neptune
import numpy as np
import torch
from torch.distributions import Categorical
from torchrl.modules import MaskedCategorical

from replay import Replay, Transition
from environments import Branching as Environment
from rewards import TimeLimitDualIntegral as BoundIntegral
from agents.model import GNNActor, GNNCritic
from agents.dual import ObservationFunction
from algorithms import REINFORCE, GAE
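
# Train a GNN branching policy with REINFORCE in an Ecole branching environment.
# For every training instance, several episodes are rolled out, transitions are
# collected in a replay buffer, and actor/critic losses plus diagnostics
# (entropy, KL divergence, advantages) are logged to Neptune unless `debug` is set.
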
def main() -> None:
    parser = argparse.ArgumentParser()
    debug = False
    device = torch.device('cpu')

    if not debug:
        run = neptune.init_run(
            project="571-project/learn-to-branch",
        )

    args = parser.parse_args()

    time_limit = 5 * 60
    memory_limit = 8 * 1024  # 8GB
    replay_size = int(1e3)
    num_epochs = 100

    args.problem = 'set_cover'
    args.task = 'dual'
    args.folder = 'train'
    if args.problem == 'item_placement':
        instances_path = pathlib.Path(f"../../instances/1_item_placement/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/1_item_placement.csv")
    elif args.problem == 'load_balancing':
        instances_path = pathlib.Path(f"../../instances/2_load_balancing/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/2_load_balancing.csv")
    elif args.problem == 'anonymous':
        instances_path = pathlib.Path(f"../../instances/3_anonymous/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/3_anonymous.csv")
    else:
        raise ValueError(f"no instance folder configured for problem {args.problem!r}")

    print(f"Processing instances from {instances_path.resolve()}")
    instance_files = list(instances_path.glob('*.mps.gz'))

    # define actor and critic networks
    actor = GNNActor().to(device)
    critic_main = GNNCritic().to(device)
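
    # Train sequentially on the instances; the environment, RNG seeds, and
    # optimizers are re-initialized for every instance.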
    for seed, instance in enumerate(instance_files):
        observation_function = ecole.observation.NodeBipartite()
        integral_function = BoundIntegral()

        env = Environment(
            time_limit=time_limit,
            observation_function=observation_function,
            reward_function=-integral_function,  # negated integral (minimization)
            scip_params={'limits/memory': memory_limit},
        )

        # seed everything for reproducibility
        env.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        actor_opt = torch.optim.Adam(actor.parameters(), lr=1e-4)
        critic_opt = torch.optim.Adam(critic_main.parameters(), lr=1e-4, weight_decay=0.01)

        # set up the reward function parameters for that instance
        # (bounds are assumed to come from the instance's companion JSON metadata file)
        with open(instance.with_name(instance.stem).with_suffix('.json')) as f:
            instance_info = json.load(f)
        initial_primal_bound = instance_info['primal_bound']
        initial_dual_bound = instance_info['dual_bound']
        objective_offset = 0

        integral_function.set_parameters(
            initial_primal_bound=initial_primal_bound,
            initial_dual_bound=initial_dual_bound,
            objective_offset=objective_offset)

        print()
        print(f"Instance {instance.name}")
        print(f"  seed: {seed}")
        print(f"  initial primal bound: {initial_primal_bound}")
        print(f"  initial dual bound: {initial_dual_bound}")
        print(f"  objective offset: {objective_offset}")
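
        # Roll out num_epochs episodes on this instance, updating the policy
        # after each episode.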
        for epoch_i in range(num_epochs):
            # reset the environment
            observation, action_set, reward, done, info = env.reset(str(instance), objective_limit=initial_primal_bound)
            cumulated_reward = 0  # discard initial reward
            replay_buffer = Replay(replay_size)

            # loop over the environment until the episode terminates
            while not done:
                # drop two variable-feature columns (indices 13 and 14) before passing the features to the actor
                variable_features = observation.variable_features
                variable_features = np.delete(variable_features, 14, axis=1)
                variable_features = np.delete(variable_features, 13, axis=1)

                observation_tuple = (
                    torch.from_numpy(observation.row_features.astype(np.float32)).to(device),
                    torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(device),
                    torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1).to(device),
                    torch.from_numpy(variable_features.astype(np.float32)).to(device),
                )
                action_set_tensor = torch.tensor(action_set, dtype=torch.int32, device=device)

                # pass the state to the actor, mask out invalid actions, and build the action distribution
                with torch.no_grad():
                    logits = actor(*observation_tuple)
                    mask = torch.zeros_like(logits, dtype=torch.bool)
                    mask[action_set] = True
                    if mask.all():
                        action_dist = Categorical(logits=logits)
                    else:
                        action_dist = MaskedCategorical(logits=logits, mask=mask)

                    # sample from the action distribution
                    action = action_dist.sample()
                    old_log_probs = action_dist.log_prob(action).detach()

                observation, action_set, reward, done, info = env.step(action.item())
                cumulated_reward += reward

                done_tensor = torch.tensor(done, dtype=torch.int32, device=device)
                reward_tensor = torch.tensor(reward, dtype=torch.float32, device=device)
                replay_buffer.push(
                    observation_tuple,   # current state
                    action,              # action taken (stored as a tensor)
                    reward_tensor,       # reward received
                    action_set_tensor,   # action set
                    done_tensor,         # marks when the episode is finished
                    old_log_probs,       # log-probability of the action under the sampling policy
                )
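
            # After the episode ends, perform one update on the collected
            # transitions (skipped if the buffer holds at most one transition).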
            if len(replay_buffer) > 1:
                actor_loss, critic_loss, advantages, entropy_values, kl_div = REINFORCE(
                    actor, critic_main, None, actor_opt, critic_opt, replay_buffer, epoch_i)
                if not debug:
                    run["Loss/Actor"].append(actor_loss)
                    run["Loss/Critic"].append(critic_loss)
                    run["Info/KL_divergence"].append(kl_div)
                    run["Info/Entropy"].append(entropy_values)
                    run["Info/Advantages"].append(advantages)
                print(f"seed: {seed}, actor loss: {actor_loss:>.4f}, critic loss: {critic_loss:>.4f}, return: {cumulated_reward}")

            # actor_scheduler.step()
            # critic_scheduler.step()

        # save the actor weights after finishing all epochs on this instance
        path = os.path.join(os.getcwd(), "models")
        os.makedirs(path, exist_ok=True)
        name = "learn2branch-set_cover-250-250-actor_critic-gae.pt"
        torch.save(actor.state_dict(), os.path.join(path, name))
    if not debug:
        run.stop()


if __name__ == "__main__":
    main()