# submissions/dual/training/train.py
import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.reward

import torch
import numpy as np

import sys

sys.path.append('.')

from common.environments import Branching as Environment
from common.rewards import TimeLimitPrimalIntegral as BoundIntegral
from submissions.dual.model import GNNPolicy as Actor
from submissions.dual.model import GNNCritic as Critic
from submissions.dual.model import BipartiteGraphConvolution

from algorithms import off_PAC
from submissions.dual.training.utils import ReplayBuffer, BipartiteNodeData

import os
import random
import neptune
import argparse
import pathlib
import json

def format_observation(sample_observation, sample_action, sample_action_set, record_action=True):
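    """Convert an ecole NodeBipartite observation into a BipartiteNodeData graph.

    When record_action is True, the chosen action is stored as an index into
    the candidate set; otherwise a placeholder index of -1 is stored.
    """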

    constraint_features = sample_observation.row_features
    variable_features = sample_observation.variable_features
    edge_indices = sample_observation.edge_features.indices
    edge_features = sample_observation.edge_features.values

    # Drop variable-feature columns 14 and 13 (highest index first, so the
    # second index stays valid). Assumption: these are extra columns in this
    # ecole version's NodeBipartite features that the pre-trained model does
    # not expect.
    variable_features = np.delete(variable_features, 14, axis=1)
    variable_features = np.delete(variable_features, 13, axis=1)

    constraint_features = torch.FloatTensor(constraint_features)
    edge_indices = torch.LongTensor(edge_indices.astype(np.int32))
    edge_features = torch.FloatTensor(np.expand_dims(edge_features, axis=-1))
    variable_features = torch.FloatTensor(variable_features)

    
    candidates = torch.LongTensor(np.array(sample_action_set, dtype=np.int32))

    if record_action:
        candidate_choice = torch.where(candidates == sample_action)[0][0]  # action index relative to candidates
    else:
        candidate_choice = torch.tensor(-1, dtype=torch.long)

    graph = BipartiteNodeData(constraint_features, edge_indices, edge_features, variable_features,
                candidates, candidate_choice)
    
    # num_nodes must be set explicitly so PyTorch Geometric can batch the
    # bipartite graph correctly.
    graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]

    return graph


def main() -> None:
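    """Train the dual-task (branching) actor-critic from expert rollouts.

    Each episode rolls out one training instance with a pre-trained expert
    policy, stores the transitions in a replay buffer, and performs a single
    off-policy actor-critic (off_PAC) update.
    """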
    parser = argparse.ArgumentParser()
    debug = False  # set to True to disable Neptune logging

    path = os.path.join(os.getcwd(), "submissions/dual/agents/trained_models/set_cover")
    name_actor = "actor-off_policy-pre_train-nodes.pt"
    name_critic = "critic-off_policy-pre_train-nodes.pt"

    device = torch.device('cpu')
    if not debug:
        run = neptune.init_run(
            project="571-project/learn-to-branch",
        )

    # No command-line options are defined yet; the configuration below is
    # attached to the (empty) parsed namespace by hand.
    args = parser.parse_args()
    time_limit = 15 * 60
    memory_limit = 8796093022207  # maximum value accepted by SCIP's limits/memory
    args.problem = 'set_cover'
    args.task = 'dual'
    args.folder = 'train'
    args.expert = 'GCN'

    seed = 3
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if args.problem == 'set_cover':
        expert_file = 'submissions/dual/agents/trained_models/set_cover/best_params.pkl'
        instances = ecole.instance.SetCoverGenerator(n_rows=500, n_cols=500, density=0.05)
        instances.seed(seed)
    elif args.problem == 'item_placement':
        expert_file = 'submissions/dual/agents/trained_models/item_placement/best_params.pkl'
        instances_path = pathlib.Path(f"instances/1_item_placement/train/")
        instances = list(instances_path.glob('*.mps.gz'))
        # instances = ecole.instance.CombinatorialAuctionGenerator(n_items=100, n_bids=250)
        # instances.seed(seed)
    elif args.problem == 'capacitated_facility_location':
        expert_file = 'submissions/dual/agents/trained_models/capacitated_facility_location/best_params.pkl'
        instances = ecole.instance.CapacitatedFacilityLocationGenerator(n_customers=100, n_facilities=50)
        instances.seed(seed)
    else:
        raise ValueError(f"unsupported problem: {args.problem}")

    # instantiate the actor, critic, and target critic networks
    actor = Actor().to(device)
    critic = Critic().to(device)
    target_critic = Critic().to(device)
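    # The target critic is a separate copy of the critic network; assumption:
    # off_PAC uses it for stable bootstrapped value targets and manages its
    # updates internally.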

    if args.expert == 'GCN':        

        print()
        print("Loading expert policy...")
        
        # Load expert
        from submissions.dual.agents.baseline import Policy
        expert = Policy(args.problem)
        observation_function = ecole.observation.NodeBipartite()
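        # The expert serves as the behaviour policy: it returns both an action
        # and its log-probability (see the rollout loop below), which off_PAC
        # presumably uses for its off-policy correction.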
        
        print()
        print("Loading actor parameters...")
        actor.load_state_dict(torch.load(expert_file, weights_only=True))
        
        # Freeze the graph-convolution layers so only the downstream layers train.
        def freeze_graph_layers(network):
            for name, module in network.named_modules():
                if isinstance(module, BipartiteGraphConvolution):
                    print(f"Freezing parameters of BipartiteGraphConvolution in module: {name}")
                    for param in module.parameters():
                        param.requires_grad = False

        freeze_graph_layers(actor)

        print()
        print("Loading critic parameters...")

        # Pre-load the expert's weights into the critic and target critic,
        # keeping only entries whose name and shape match.
        gcnn_pretrained_state_dict = torch.load(expert_file, map_location="cpu")

        def load_matching_weights(network):
            network_state_dict = network.state_dict()
            filtered_state_dict = {
                name: param
                for name, param in gcnn_pretrained_state_dict.items()
                if name in network_state_dict and param.shape == network_state_dict[name].shape
            }
            network_state_dict.update(filtered_state_dict)
            network.load_state_dict(network_state_dict)

        load_matching_weights(critic)
        load_matching_weights(target_critic)

        freeze_graph_layers(critic)
        freeze_graph_layers(target_critic)


    elif args.expert == 'strong_branch':
        from submissions.dual.agents.strong_branch import Policy
        expert = Policy(args.problem)
        observation_function = (
            ecole.observation.NodeBipartite(),
            ecole.observation.StrongBranchingScores()
        )
    else:
        raise NotImplementedError
    
    # Sanity check: the critic and target critic must not share weight storage.
    for p1, p2 in zip(critic.parameters(), target_critic.parameters()):
        assert p1.storage().data_ptr() != p2.storage().data_ptr(), "Weights are shared!"

    # Optimizers for actor and critic, updating only the unfrozen parameters.
    actor_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, actor.parameters()), lr=1e-4, amsgrad=True)
    critic_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, critic.parameters()), lr=1e-4, amsgrad=True)
    
    
    # Training loop: one expert rollout and one off_PAC update per instance.
    for episode, instance in enumerate(instances):

        if args.problem == 'item_placement':
            with open(instance.with_name(instance.stem).with_suffix('.json')) as f:
                instance_info = json.load(f)
            # set up the reward function parameters for that instance
            initial_primal_bound = instance_info["primal_bound"]
            initial_dual_bound = instance_info["dual_bound"]
        else:
            # set up the reward function parameters for that instance
            initial_primal_bound = instance.primal_bound
            initial_dual_bound = instance.dual_bound
        
        objective_offset = 0
        # integral_function = BoundIntegral()
        integral_function = ecole.reward.NNodes()

        env = Environment(
            time_limit=time_limit,
            observation_function=observation_function,
            reward_function=-integral_function,  # negated node count (minimization)
            scip_params={'limits/memory': memory_limit},
        )
        env.seed(seed)

        # integral_function.set_parameters(
        #         initial_primal_bound=initial_primal_bound,
        #         initial_dual_bound=initial_dual_bound,
        #         objective_offset=objective_offset)
        
        # One buffer per episode: the expert rollout fills it, and a single
        # off_PAC update consumes it below.
        timestep_buffer = ReplayBuffer(capacity=int(1e3))

        # reset the environment
        if args.problem == 'item_placement':
            observation, action_set, reward, done, info = env.reset(str(instance), objective_limit=initial_primal_bound)
        else:
            observation, action_set, reward, done, info = env.reset(instance, objective_limit=initial_primal_bound)
       
        cumulated_reward = 0  # discard initial reward
        
        # Roll out the episode with the expert as the behaviour policy,
        # recording (state, log-prob, reward, done) transitions.
        while not done:
            action, log_prob = expert(action_set, observation)
            next_observation, next_action_set, reward, done, info = env.step(action.item())
            cumulated_reward += reward
            
            state = format_observation(observation, action, action_set)

            experience = (
                state,
                torch.tensor(log_prob, dtype=torch.float32),
                torch.tensor(reward, dtype=torch.float32),
                torch.tensor(done, dtype=torch.long),
            )

            timestep_buffer.push(*experience)

            observation = next_observation
            action_set = next_action_set

        if len(timestep_buffer) <= 1:
            print("episode was too short, skipping...")
            continue
        
        # One off-policy actor-critic update over the stored episode.
        actor_loss, critic_loss = off_PAC(actor, critic, target_critic, actor_optimizer, critic_optimizer, timestep_buffer)
        
        run["Loss/Actor"].append(actor_loss)
        run["Loss/Critic"].append(critic_loss)

        print(f"Run:{episode}, Actor loss: {actor_loss}, Critic loss: {critic_loss}") 
        
        # Checkpoint every 10 episodes, and stop after 100 episodes.
        if episode % 10 == 0:
            torch.save(actor.state_dict(), os.path.join(path, name_actor))
            torch.save(critic.state_dict(), os.path.join(path, name_critic))
            print("saving...")
        if episode == 99:
            print('Finished training')
            torch.save(actor.state_dict(), os.path.join(path, name_actor))
            torch.save(critic.state_dict(), os.path.join(path, name_critic))
            print("saving...")
            break
    


if __name__ == "__main__":
    main()