import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.reward
import ecole.typing

import matplotlib.pyplot as plt
import torch
import numpy as np

from torchrl.modules import MaskedCategorical
from torch.distributions import Categorical



from replay import Replay, Transition
from environments import Branching as Environment
from rewards import TimeLimitDualIntegral as BoundIntegral
from agents.model import GNNActor, GNNCritic
from agents.dual import ObservationFunction
from algorithms import REINFORCE, GAE

import pathlib
import json
import os
import random
import time
import neptune
import argparse




def main() -> None:
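    """Train a GNN branching policy with REINFORCE on a set of MILP instances.

    For each instance, the script runs `num_epochs` branch-and-bound episodes,
    stores the transitions of every episode in a replay buffer, and updates the
    actor and critic after the episode ends. Metrics are logged to Neptune
    unless `debug` is set.
    """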
    parser = argparse.ArgumentParser()
    debug = False

    device = torch.device('cpu')
    if not debug:
        run = neptune.init_run(
            project="571-project/learn-to-branch",
        )

    args = parser.parse_args()
    time_limit = 5*60
    memory_limit = 8*1024  # 8GB
    replay_size = int(1e3)
    num_epochs = 100
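    # run configuration is hard-coded below; no CLI options are registered on the parser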
    args.problem = 'set_cover'
    args.task = 'dual'
    args.folder = 'train'

    if args.problem == 'set_cover':
        # NOTE: assumed directory layout, mirroring the competition problems below;
        # point this at wherever the set cover instances actually live
        instances_path = pathlib.Path(f"../../instances/set_cover/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/set_cover.csv")
    elif args.problem == 'item_placement':
        instances_path = pathlib.Path(f"../../instances/1_item_placement/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/1_item_placement.csv")
    elif args.problem == 'load_balancing':
        instances_path = pathlib.Path(f"../../instances/2_load_balancing/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/2_load_balancing.csv")
    elif args.problem == 'anonymous':
        instances_path = pathlib.Path(f"../../instances/3_anonymous/{args.folder}/")
        results_file = pathlib.Path(f"results/{args.task}/3_anonymous.csv")
    else:
        raise ValueError(f"unknown problem: {args.problem}")

    print(f"Processing instances from {instances_path.resolve()}")
    instance_files = list(instances_path.glob('*.mps.gz'))

    
    # define the actor and critic networks
    actor = GNNActor().to(device)
    critic_main = GNNCritic().to(device)


    for seed, instance in enumerate(instance_files):
        observation_function = ecole.observation.NodeBipartite()
        integral_function = BoundIntegral()

        env = Environment(
            time_limit=time_limit,  
            observation_function=observation_function,
            reward_function=-integral_function,  # negated dual integral: maximizing reward minimizes the integral
            scip_params={'limits/memory': memory_limit},
        )
        
        env.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

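        # NOTE: optimizers are re-created for every instance, so Adam's moment
        # estimates are reset whenever a new instance is processed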
        actor_opt = torch.optim.Adam(actor.parameters(), lr=1e-4)
        critic_opt = torch.optim.Adam(critic_main.parameters(), lr=1e-4, weight_decay=0.01)

        # set up the reward function parameters for that instance
        # NOTE: assumes each *.mps.gz instance ships with a companion JSON file
        # (as in the ML4CO instance sets) holding its initial bounds
        with open(instance.with_name(instance.stem).with_suffix('.json')) as f:
            instance_info = json.load(f)

        initial_primal_bound = instance_info['primal_bound']
        initial_dual_bound = instance_info['dual_bound']
        objective_offset = 0

        integral_function.set_parameters(
                initial_primal_bound=initial_primal_bound,
                initial_dual_bound=initial_dual_bound,
                objective_offset=objective_offset)
        
        print()
        print(f"Instance {instance.name}")
        print(f"  seed: {seed}")
        print(f"  initial primal bound: {initial_primal_bound}")
        print(f"  initial dual bound: {initial_dual_bound}")
        print(f"  objective offset: {objective_offset}")

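        # run num_epochs training episodes on this instance, one full
        # branch-and-bound solve per episode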
        for epoch_i in range(num_epochs):
            # reset the environment
            observation, action_set, reward, done, info = env.reset(str(instance), objective_limit=initial_primal_bound)

            cumulated_reward = 0  # discard initial reward
            replay_buffer = Replay(replay_size)
            
            # loop over the environment
            while not done:
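                # drop two variable-feature columns (removing index 14 first keeps
                # index 13 valid); presumably columns the GNN is not built to consume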
                variable_features = observation.variable_features
                variable_features = np.delete(variable_features, 14, axis=1)
                variable_features = np.delete(variable_features, 13, axis=1)

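                # assemble the bipartite-graph observation: constraint (row) features,
                # edge indices, edge values, and variable (column) features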
                observation_tuple = (
                    torch.from_numpy(observation.row_features.astype(np.float32)).to(device),
                    torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(device),
                    torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1).to(device),
                    torch.from_numpy(variable_features.astype(np.float32)).to(device),
                )

                action_set_tensor = torch.tensor(action_set, dtype=torch.int32, device=device)

                # pass state to actor and get distribution and create valid action mask
                with torch.no_grad():
                    logits = actor(*observation_tuple)
                    mask = torch.zeros_like(logits, dtype=torch.bool)
                    mask[action_set] = True

                    if mask.all():
                        # every variable is a branching candidate: plain categorical
                        action_dist = Categorical(logits=logits)
                    else:
                        # restrict sampling to the valid branching candidates
                        action_dist = MaskedCategorical(logits=logits, mask=mask)

                    
                    # sample from action distribution
                    action = action_dist.sample()
                    old_log_probs = action_dist.log_prob(action).detach()

                observation, action_set, reward, done, info = env.step(action.item())

                cumulated_reward += reward
                done_tensor = torch.tensor(done, dtype=torch.int32, device=device)
                reward_tensor = torch.tensor(reward, dtype=torch.float32, device=device)

                replay_buffer.push(
                        observation_tuple,  # current state
                        action,             # action taken (stored as a tensor)
                        reward_tensor,      # reward received
                        action_set_tensor,  # valid action set
                        done_tensor,        # marks the end of the episode
                        old_log_probs       # log-probability under the behaviour policy
                    )
            
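            # update the actor and critic only if the episode produced more than one transition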
            if len(replay_buffer) > 1:
                actor_loss, critic_loss, advantages, entropy_values, kl_div = REINFORCE(actor, critic_main, None, actor_opt, critic_opt, replay_buffer, epoch_i)
                
                if not debug:
                    run["Loss/Actor"].append(actor_loss)
                    run["Loss/Critic"].append(critic_loss)
                    run["Info/KL_divergence"].append(kl_div)
                    run["Info/Entropy"].append(entropy_values)
                    run["Info/Advantages"].append(advantages)

                print(f"seed: {seed}, actor loss: {actor_loss:>.4f}, critic loss: {critic_loss:>.4f}, return: {cumulated_reward}")
        
                # actor_scheduler.step()
                # critic_scheduler.step()

        path = os.path.join(os.getcwd(), "models")
        os.makedirs(path, exist_ok=True)  # make sure the output directory exists
        name = "learn2branch-set_cover-250-250-actor_critic-gae.pt"
        torch.save(actor.state_dict(), os.path.join(path, name))
    if not debug:
        run.stop()

if __name__ == "__main__":
    main()