import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.reward
import ecole.typing
import matplotlib.pyplot as plt
import torch
import numpy as np
from torchrl.modules import MaskedCategorical
from torch.distributions import Categorical
import sys
sys.path.append('.')
from common.environments import Branching as Environment
from common.rewards import TimeLimitPrimalIntegral as BoundIntegral
from submissions.dual.model import GNNPolicy as Actor
from submissions.dual.model import GNNCritic as Critic
from submissions.dual.model import PreNormLayer, BipartiteGraphConvolution
from algorithms import off_PAC
from submissions.dual.training.utils import ReplayBuffer, Transition
# from baseline.dual.train_files.utilities import BipartiteNodeData
from submissions.dual.training.utils import BipartiteNodeData
import random
import neptune
import argparse
import pathlib
import json
def format_observation(sample_observation, sample_action, sample_action_set, record_action=True):
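    """Convert an ecole NodeBipartite observation into a BipartiteNodeData graph.

    The branching candidates in `sample_action_set` become the candidate tensor;
    when `record_action` is True, `sample_action` is stored as an index into that
    candidate set, otherwise a placeholder index of -1 is used.
    """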
constraint_features = sample_observation.row_features
variable_features = sample_observation.variable_features
edge_indices = sample_observation.edge_features.indices
edge_features = sample_observation.edge_features.values
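    # Remove variable-feature columns 14 and 13 (higher index first, so the second
    # deletion still refers to the original column). This is assumed to match the
    # input width expected by the pre-trained GNN weights.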
variable_features = np.delete(variable_features, 14, axis=1)
variable_features = np.delete(variable_features, 13, axis=1)
constraint_features = torch.FloatTensor(constraint_features)
edge_indices = torch.LongTensor(edge_indices.astype(np.int32))
edge_features = torch.FloatTensor(np.expand_dims(edge_features, axis=-1))
variable_features = torch.FloatTensor(variable_features)
candidates = torch.LongTensor(np.array(sample_action_set, dtype=np.int32))
if record_action:
candidate_choice = torch.where(candidates == sample_action)[0][0] # action index relative to candidates
else:
candidate_choice = torch.tensor(-1, dtype=torch.long)
graph = BipartiteNodeData(constraint_features, edge_indices, edge_features, variable_features,
candidates, candidate_choice)
graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]
return graph
def main() -> None:
parser = argparse.ArgumentParser()
debug = False
import os
path = os.path.join(os.getcwd(), "submissions/dual/agents/trained_models/set_cover")
name_actor = "actor-off_policy-pre_train-nodes.pt"
name_critic = "critic-off_policy-pre_train-nodes.pt"
device = torch.device('cpu')
    run = None
    if not debug:
        run = neptune.init_run(
            project="571-project/learn-to-branch",
        )
args = parser.parse_args()
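    # No CLI options are defined yet; the settings below are hard-coded and the
    # parser is kept as a stub for future flags.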
time_limit = 15*60
    memory_limit = 8796093022207  # maximum value accepted by SCIP's limits/memory parameter
args.problem = 'set_cover'
args.task = 'dual'
args.folder = 'train'
args.expert = 'GCN'
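    # Seed Python, NumPy, and PyTorch RNGs for reproducibility.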
seed = 3
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
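    # Select the instance source: generated on the fly (set cover, facility
    # location) or read from disk (item placement).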
if args.problem == 'set_cover':
expert_file = 'submissions/dual/agents/trained_models/set_cover/best_params.pkl'
instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=500, density=0.05)
instances.seed(seed)
elif args.problem == 'item_placement':
expert_file = 'submissions/dual/agents/trained_models/item_placement/best_params.pkl'
instances_path = pathlib.Path(f"instances/1_item_placement/train/")
instances = list(instances_path.glob('*.mps.gz'))
# instances = ecole.instance.CombinatorialAuctionGenerator(n_items=100, n_bids=250)
# instances.seed(seed)
elif args.problem == 'capacitated_facility_location':
expert_file = 'submissions/dual/agents/trained_models/capacitated_facility_location/best_params.pkl'
instances = ecole.instance.CapacitatedFacilityLocationGenerator(n_customers=100, n_facilities=50)
instances.seed(seed)
    # define actor, critic, and target critic networks and load pre-trained weights
actor = Actor().to(device)
critic = Critic().to(device)
target_critic = Critic().to(device)
if args.expert == 'GCN':
print()
print("Loading expert policy...")
# Load expert
from submissions.dual.agents.baseline import Policy
expert = Policy(args.problem)
observation_function = ecole.observation.NodeBipartite()
print()
print("Loading actor parameters...")
        actor.load_state_dict(torch.load(expert_file, map_location=device, weights_only=True))
# freeze graph layers
for name, module in actor.named_modules():
if isinstance(module, BipartiteGraphConvolution):
print(f"Freezing parameters of BipartiteGraphConvolution in module: {name}")
for param in module.parameters():
param.requires_grad = False
print()
print("Loading critic parameters...")
# Pre-load expert into critic
gcnn_pretrained_state_dict = torch.load(expert_file, map_location="cpu")
critic_state_dict = critic.state_dict()
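        # Copy only the pre-trained parameters whose names and shapes match the
        # critic's state dict; layers specific to the critic head keep their
        # fresh initialisation.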
filtered_state_dict = {}
for name, param in gcnn_pretrained_state_dict.items():
if name in critic_state_dict and param.shape == critic_state_dict[name].shape:
filtered_state_dict[name] = param
critic_state_dict.update(filtered_state_dict)
critic.load_state_dict(critic_state_dict)
        # Initialise the target critic from the same filtered pre-trained weights,
        # so that it starts identical to the critic.
target_critic_state_dict = target_critic.state_dict()
filtered_state_dict_target = {}
for name, param in gcnn_pretrained_state_dict.items():
if name in target_critic_state_dict and param.shape == target_critic_state_dict[name].shape:
filtered_state_dict_target[name] = param
target_critic_state_dict.update(filtered_state_dict_target)
target_critic.load_state_dict(target_critic_state_dict)
# freeze graph layers
for name, module in critic.named_modules():
if isinstance(module, BipartiteGraphConvolution):
print(f"Freezing parameters of BipartiteGraphConvolution in module: {name}")
for param in module.parameters():
param.requires_grad = False
for name, module in target_critic.named_modules():
if isinstance(module, BipartiteGraphConvolution):
print(f"Freezing parameters of BipartiteGraphConvolution in module: {name}")
for param in module.parameters():
param.requires_grad = False
elif args.expert == 'strong_branch':
from submissions.dual.agents.strong_branch import Policy
expert = Policy(args.problem)
observation_function = (
ecole.observation.NodeBipartite(),
ecole.observation.StrongBranchingScores()
)
else:
raise NotImplementedError
    # Sanity check: the critic and target critic must not share parameter storage.
    for p1, p2 in zip(critic.parameters(), target_critic.parameters()):
        assert p1.data_ptr() != p2.data_ptr(), "Weights are shared!"
# define optimizer for actor
actor_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, actor.parameters()), lr=1e-4, amsgrad=True)
critic_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, critic.parameters()), lr=1e-4, amsgrad=True)
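    # Training loop: one episode per instance. Transitions gathered from the
    # expert rollout are stored in a replay buffer and used for an off-policy
    # actor-critic update at the end of each episode.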
for episode, instance in enumerate(instances):
if args.problem == 'item_placement':
with open(instance.with_name(instance.stem).with_suffix('.json')) as f:
instance_info = json.load(f)
# set up the reward function parameters for that instance
initial_primal_bound = instance_info["primal_bound"]
initial_dual_bound = instance_info["dual_bound"]
else:
# set up the reward function parameters for that instance
initial_primal_bound = instance.primal_bound
initial_dual_bound = instance.dual_bound
objective_offset = 0
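        # Per-step reward: negative number of processed nodes. The bound-integral
        # reward below is kept for reference but currently disabled.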
# integral_function = BoundIntegral()
integral_function = ecole.reward.NNodes()
env = Environment(
time_limit=time_limit,
observation_function=observation_function,
            reward_function=-integral_function,  # negated node count (fewer nodes is better)
scip_params={'limits/memory': memory_limit},
)
env.seed(seed)
# integral_function.set_parameters(
# initial_primal_bound=initial_primal_bound,
# initial_dual_bound=initial_dual_bound,
# objective_offset=objective_offset)
timestep_buffer = ReplayBuffer(capacity=int(1e3))
# reset the environment
if args.problem == 'item_placement':
observation, action_set, reward, done, info = env.reset(str(instance), objective_limit=initial_primal_bound)
else:
observation, action_set, reward, done, info = env.reset(instance, objective_limit=initial_primal_bound)
cumulated_reward = 0 # discard initial reward
        # Roll out one episode following the expert policy, storing each
        # transition (state, expert log-prob, reward, done) in the replay buffer.
while not done:
action, log_prob = expert(action_set, observation)
next_observation, next_action_set, reward, done, info = env.step(action.item())
cumulated_reward += reward
state = format_observation(observation, action, action_set)
experience = (
state,
torch.tensor(log_prob, dtype=torch.float32),
torch.tensor(reward, dtype=torch.float32),
torch.tensor(done, dtype=torch.long),
)
timestep_buffer.push(*experience)
observation = next_observation
action_set = next_action_set
if len(timestep_buffer) <= 1:
print("episode was too short, skipping...")
continue
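        # One off-policy actor-critic (off-PAC) update over the stored episode.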
actor_loss, critic_loss = off_PAC(actor, critic, target_critic, actor_optimizer, critic_optimizer, timestep_buffer)
run["Loss/Actor"].append(actor_loss)
run["Loss/Critic"].append(critic_loss)
print(f"Run:{episode}, Actor loss: {actor_loss}, Critic loss: {critic_loss}")
if episode % 10 == 0:
torch.save(actor.state_dict(), os.path.join(path, name_actor))
torch.save(critic.state_dict(), os.path.join(path, name_critic))
print(f"saving...")
if episode == 99:
print('Finished training')
torch.save(actor.state_dict(), os.path.join(path, name_actor))
torch.save(critic.state_dict(), os.path.join(path, name_critic))
print(f"saving...")
break
if __name__ == "__main__":
main()