# learn2branch_gasse/imitation.py

import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.typing
import hydra
from omegaconf import DictConfig, OmegaConf


import gzip
import pickle
from pathlib import Path

from data import BipartiteDataset
from torch_geometric.loader import DataLoader
from model import GNNPolicy, GNNActor
from observation_functions import HybridBranch
import utils
import random

import torch
import torch.nn.functional as F
import numpy as np
import os

def create_folder(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)
        print(f"Created: {dirname}")



# def collect_data(cfg: DictConfig, gen: ecole.typing.InstanceGenerator, env: ecole.environment.Environment):
#     output_file = "learn2branch_gasse/samples"
#     Path(output_file).mkdir(exist_ok=True)
#     samples_collected_global = 0
#     episode_i = 0

#     while samples_collected_global < cfg.training.n_samples:
            
#             episode_i += 1
#             instance = next(gen)
#             observation, action_set, _, done, _ = env.reset(instance)
#             samples_collected_local = 0
            
#             while not done:
#                 (scores, used_sb_scores), graph = observation
#                 action = action_set[scores[action_set].argmax()]  # choose the best action (greedy)

#                 if used_sb_scores and samples_collected_global < cfg.training.n_samples:
#                     samples_collected_global += 1
#                     samples_collected_local += 1
#                     # Record state, action
#                     data_example = [graph, action, action_set, scores]

#                     # Save samples after each episode and clear memory
#                     with open(f"{output_file}/sample_{cfg.problem}_{episode_i}_{samples_collected_local}.pkl", "wb") as f:
#                         pickle.dump(data_example, f)

#                 # Take action and go to the next state
#                 observation, action_set, _, done, _ = env.step(action)

#             print(f"Collected {samples_collected_local} samples during episode {episode_i}; progress {samples_collected_global/cfg.training.n_samples:.2%}")

def collect_data(cfg: DictConfig, gen: ecole.typing.InstanceGenerator, env: ecole.environment.Environment, generator):
    """Roll out the expert brancher and pickle (state, action) samples to disk."""
    output_dir = "learn2branch_gasse/samples"
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    samples_collected_global = 0
    episode_i = 0

    while samples_collected_global < cfg.training.n_samples:
        episode_i += 1
        instance = next(gen)
        observation, action_set, _, done, _ = env.reset(instance)
        samples_collected_local = 0

        while not done:
            scores, graph = observation

            # choose the best-scoring candidate (greedy expert policy)
            action = action_set[scores[action_set].argmax()]
            u = torch.rand(1, generator=generator)

            # keep this node as a training sample with probability sample_prob
            if u.item() < cfg.training.sample_prob and samples_collected_global < cfg.training.n_samples:
                samples_collected_global += 1
                samples_collected_local += 1
                # record the state, expert action, candidate set, and expert scores
                data_example = [graph, action, action_set, scores]

                # write the sample to disk immediately so samples never accumulate in memory
                with open(f"{output_dir}/sb_{cfg.problem}_{episode_i}_{samples_collected_local}.pkl", "wb") as f:
                    pickle.dump(data_example, f)

            # take the expert action and go to the next state
            observation, action_set, _, done, _ = env.step(action)

        print(f"Collected {samples_collected_local} samples during episode {episode_i}; progress {samples_collected_global/cfg.training.n_samples:.2%}")

def train_expert(cfg: DictConfig, train_dl: DataLoader, valid_dl: DataLoader, optimizer: torch.optim.Optimizer, actor: torch.nn.Module):
    """Imitation learning: fit the actor to the expert's branching choices with cross-entropy."""
    for epoch_i in range(cfg.training.n_epochs):
        mean_loss = [0, 0]      # [train, valid]
        mean_acc = [0, 0]       # [train, valid]
        samples_seen = [0, 0]   # [train, valid]

        actor.train()
        for batch in train_dl:
            
            # clear gradient buffer
            optimizer.zero_grad()
            
            # sample from actor
            logits = actor(
                batch.x_c,
                batch.edge_index,
                batch.edge_features,
                batch.x_v
            )
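            # (x_c / x_v are presumably the constraint- and variable-node features
            #  from NodeBipartite, with edge_index/edge_features the bipartite
            #  edges; the actor returns one logit per variable node)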
            # pad since we have batched data
            logits = utils.pad_tensor(logits[batch.candidates], pad_sizes=batch.nb_candidates)
            
            # compute loss; logits: (batch_size, max_candidates), batch.candidate_choice: (batch_size,)
            loss = F.cross_entropy(logits, batch.candidate_choice)
            
            # backprop and step optimizer
            loss.backward()
            optimizer.step()

            # calculate accuracy: a prediction counts as correct if the chosen
            # candidate attains the expert's best score (so ties are allowed)
            groundtruth_scores = utils.pad_tensor(batch.candidates_scores, batch.nb_candidates)
            groundtruth_best_score = torch.max(groundtruth_scores, dim=-1, keepdim=True).values
            predicted_best_index = logits.max(dim=-1, keepdim=True).indices
            accuracy = (groundtruth_scores.gather(-1, predicted_best_index) == groundtruth_best_score).float().mean().item()
            
            # record loss and accuracy
            mean_loss[0] += loss.item() * batch.num_graphs
            mean_acc[0] += accuracy * batch.num_graphs
            samples_seen[0] += batch.num_graphs

        actor.eval()
        with torch.no_grad():
            for batch in valid_dl:
                # sample from actor
                logits = actor(
                    batch.x_c,
                    batch.edge_index,
                    batch.edge_features,
                    batch.x_v
                )
                # pad since we have batched data
                logits = utils.pad_tensor(logits[batch.candidates], pad_sizes=batch.nb_candidates)
                
                # compute loss; logits: (batch_size, max_candidates), batch.candidate_choice: (batch_size,)
                loss = F.cross_entropy(logits, batch.candidate_choice)
                
                # calculate accuracy: a prediction counts as correct if the chosen
                # candidate attains the expert's best score (so ties are allowed)
                groundtruth_scores = utils.pad_tensor(batch.candidates_scores, batch.nb_candidates)
                groundtruth_best_score = torch.max(groundtruth_scores, dim=-1, keepdim=True).values
                predicted_best_index = logits.max(dim=-1, keepdim=True).indices
                accuracy = (groundtruth_scores.gather(-1, predicted_best_index) == groundtruth_best_score).float().mean().item()

                # record loss and accuracy
                mean_loss[1] += loss.item() * batch.num_graphs
                mean_acc[1] += accuracy * batch.num_graphs
                samples_seen[1] += batch.num_graphs

        mean_loss[0] /= samples_seen[0]
        mean_acc[0] /= samples_seen[0]
        mean_loss[1] /= samples_seen[1]
        mean_acc[1] /= samples_seen[1]
        print(f"==== epoch: {epoch_i} ==== ")
        print(f"(train) mean loss {mean_loss[0]:.3f}, mean acc {mean_acc[0]:.3f}")
        print(f"(valid) mean loss {mean_loss[1]:.3f}, mean acc {mean_acc[1]:.3f}")


    path = os.path.join(os.getcwd(), "models")
    name = "learn2branch-set_cover-1000-v3.pt"
    create_folder(path)
    torch.save(actor.state_dict(), os.path.join(path, name))
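

# To reuse the trained policy later, the checkpoint saved above can be restored
# with something like the following sketch (it assumes the same GNNActor
# architecture and the path/name used inside train_expert):
#
#   actor = GNNActor()
#   actor.load_state_dict(torch.load(os.path.join(path, name)))
#   actor.eval()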


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg : DictConfig) -> None:
    # print(OmegaConf.to_yaml(cfg))
    torch.manual_seed(cfg.training.seed)
    random.seed(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    generator = torch.Generator()
    generator.manual_seed(cfg.training.seed)
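
    # For reference, a hedged sketch of the Hydra config this script expects at
    # conf/config.yaml (field names taken from the cfg accesses in this file;
    # the values shown are placeholders, not the project's actual settings):
    #
    #   problem: set_cover
    #   n_rows: 500
    #   n_cols: 1000
    #   density: 0.05
    #   sample: true
    #   training:
    #     seed: 0
    #     n_samples: 10000
    #     sample_prob: 0.05
    #     expert: strong_branch        # or hybrid_branch
    #     expert_probability: 0.05     # only used by HybridBranch
    #     n_epochs: 10
    #     lr: 1e-3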

    # define the problem instance generator
    if cfg.problem == "set_cover":
        gen = ecole.instance.SetCoverGenerator(
            n_rows=cfg.n_rows,
            n_cols=cfg.n_cols,
            density=cfg.density,
        )
    else:
        raise ValueError(f"Unsupported problem: {cfg.problem}")

    gen.seed(cfg.training.seed)
    
    # define the observation function (expert scores + bipartite graph features)
    if cfg.training.expert == "hybrid_branch":
        observation_functions = (
            HybridBranch(cfg.training.expert_probability),
            ecole.observation.NodeBipartite()
        )
    elif cfg.training.expert == "strong_branch":
        observation_functions = (
            ecole.observation.StrongBranchingScores(),
            ecole.observation.NodeBipartite()
        )
    else:
        raise ValueError(f"Unsupported expert: {cfg.training.expert}")

    # SCIP parameters used in the Gasse et al. paper: cutting planes only at the
    # root node, no presolve restarts, and a 1-hour time limit per solve
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 3600,
    }
    
    # define the environment
    env = ecole.environment.Branching(
        observation_function=observation_functions,
        scip_params=scip_parameters
    )
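
    # Because observation_function is a tuple, every observation returned by
    # env.reset/env.step is itself a tuple (expert scores, NodeBipartite graph),
    # which is exactly what collect_data unpacks.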

    env.seed(cfg.training.seed) # seed environment

    ## Start collecting training data ##
    if cfg.sample:
        collect_data(cfg, gen, env, generator)


    # # open sample files if they exist (the glob prefix must match the file
    # # names written by collect_data above)
    # sample_files = Path("learn2branch_gasse/samples/").glob("sb_*.pkl")
    # samples = []
    # for sample_i in sample_files:
    #     with open(sample_i, "rb") as f:
    #             samples.append(pickle.load(f))

    # # split into 80/20
    # train_samples = samples[: int(0.8*len(samples))]
    # valid_samples = samples[int(0.8*len(samples)) :]

    # # make the dataset 
    # train_ds = BipartiteDataset(train_samples)
    # valid_ds = BipartiteDataset(valid_samples)

    # # declare dataloader for training
    # train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, generator=generator)
    # valid_dl = DataLoader(valid_ds, batch_size=64, shuffle=True, generator=generator)

    # # define actor network
    # actor = GNNActor()

    # # define an optimizer
    # actor_opt = torch.optim.Adam(actor.parameters(), lr=cfg.training.lr)

    # # train the network
    # train_expert(cfg, train_dl, valid_dl, actor_opt, actor)


if __name__ == "__main__":
    main()