"""Collect strong-branching samples with Ecole and imitation-train a GNN branching policy.

Pipeline (driven by a Hydra config):
  1. generate MILP instances (set cover) with ``ecole.instance``,
  2. roll out a branching expert inside ``ecole.environment.Branching`` and
     pickle (graph, action, action_set, scores) tuples as training samples,
  3. (commented out below) fit a GNN actor on those samples by cross-entropy
     imitation of the expert's branching choice.
"""

import os
import pickle
import random
from pathlib import Path

import gzip  # NOTE(review): not used in this file; kept in case other chunks rely on it

import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.typing
import hydra
import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from torch_geometric.loader import DataLoader

import utils
from data import BipartiteDataset
from model import GNNPolicy, GNNActor
from observation_functions import HybridBranch


def create_folder(dirname):
    """Create *dirname* (including any missing parents) and log the creation.

    No-op if the directory already exists.
    """
    if not os.path.exists(dirname):
        os.makedirs(dirname)
        print(f"Created: {dirname}")


def collect_data(cfg: DictConfig, gen: ecole.typing.InstanceGenerator,
                 env: ecole.environment.Environment, generator) -> None:
    """Roll out the expert policy and pickle branching samples to disk.

    Episodes are played greedily w.r.t. the expert scores; each visited node is
    recorded with probability ``cfg.training.sample_prob`` until
    ``cfg.training.n_samples`` samples have been written.

    Args:
        cfg: Hydra config; reads ``cfg.problem``, ``cfg.training.n_samples``
            and ``cfg.training.sample_prob``.
        gen: infinite instance generator (``next(gen)`` yields a new instance).
        env: Ecole branching environment whose observation is a
            ``(scores, graph)`` pair (e.g. StrongBranchingScores + NodeBipartite).
        generator: seeded ``torch.Generator`` used for the sampling coin flips.
    """
    output_file = "learn2branch_gasse/samples"
    # parents=True: mkdir(exist_ok=True) alone raises FileNotFoundError when
    # the "learn2branch_gasse" parent directory does not exist yet.
    Path(output_file).mkdir(parents=True, exist_ok=True)

    samples_collected_global = 0
    episode_i = 0
    while samples_collected_global < cfg.training.n_samples:
        episode_i += 1
        instance = next(gen)
        observation, action_set, _, done, _ = env.reset(instance)
        samples_collected_local = 0
        while not done:
            # NOTE(review): assumes the observation is (scores, graph); the
            # "hybrid_branch" observation in main() returns ((scores, used_sb),
            # graph) instead — confirm before using collect_data with it.
            scores, graph = observation
            # Greedy expert move: the candidate with the highest expert score.
            action = action_set[scores[action_set].argmax()]

            # Bernoulli(sample_prob) coin flip decides whether to record this node.
            u = torch.rand(1, generator=generator)
            if u < cfg.training.sample_prob and samples_collected_global < cfg.training.n_samples:
                samples_collected_global += 1
                samples_collected_local += 1
                # Record (state, action) for imitation learning.
                data_example = [graph, action, action_set, scores]
                with open(f"{output_file}/sb_{cfg.problem}_{episode_i}_{samples_collected_local}.pkl", "wb") as f:
                    pickle.dump(data_example, f)

            # Take the expert action and advance to the next node.
            observation, action_set, _, done, _ = env.step(action)

        print(f"Collected {samples_collected_local} samples during episode {episode_i}; "
              f"progress {samples_collected_global/cfg.training.n_samples:.2%}")


def _expert_accuracy(logits, batch) -> float:
    """Fraction of graphs in *batch* whose argmax-logit candidate attains the
    best ground-truth expert score (ties with the expert count as correct)."""
    groundtruth_scores = utils.pad_tensor(batch.candidates_scores, batch.nb_candidates)
    groundtruth_best_score = torch.max(groundtruth_scores, dim=-1, keepdim=True).values
    predicted_best_index = logits.max(dim=-1, keepdim=True).indices
    return (groundtruth_scores.gather(-1, predicted_best_index)
            == groundtruth_best_score).float().mean().item()


def train_expert(cfg: DictConfig, train_dl: DataLoader, valid_dl: DataLoader,
                 optimizer: torch.optim.Optimizer, actor: torch.nn.Module) -> None:
    """Imitation-train *actor* on expert branching decisions by cross-entropy.

    Runs ``cfg.training.n_epochs`` epochs; after each epoch prints mean loss and
    top-1 accuracy on the train and validation loaders, then checkpoints the
    actor's state dict under ``<cwd>/models``.

    Args:
        cfg: Hydra config; reads ``cfg.training.n_epochs``.
        train_dl / valid_dl: loaders of batched bipartite graphs exposing
            ``x_c, edge_index, edge_features, x_v, candidates, nb_candidates,
            candidate_choice, candidates_scores, num_graphs``.
        optimizer: optimizer over ``actor.parameters()``.
        actor: GNN policy mapping a bipartite graph to per-variable logits.
    """
    for epoch_i in range(cfg.training.n_epochs):
        # Index 0 = train split, index 1 = validation split.
        mean_loss = [0, 0]
        mean_acc = [0, 0]
        samples_seen = [0, 0]

        actor.train()
        for batch in train_dl:
            optimizer.zero_grad()
            logits = actor(batch.x_c, batch.edge_index, batch.edge_features, batch.x_v)
            # Pad ragged per-graph candidate sets so CE sees a dense (B, max_cand) matrix.
            logits = utils.pad_tensor(logits[batch.candidates], pad_sizes=batch.nb_candidates)
            # logits: (B, max_candidates); batch.candidate_choice: (B,)
            loss = F.cross_entropy(logits, batch.candidate_choice)
            loss.backward()
            optimizer.step()

            accuracy = _expert_accuracy(logits, batch)
            # Weight by batch size so epoch means are per-sample, not per-batch.
            mean_loss[0] += loss.item() * batch.num_graphs
            mean_acc[0] += accuracy * batch.num_graphs
            samples_seen[0] += batch.num_graphs

        actor.eval()
        with torch.no_grad():
            for batch in valid_dl:
                logits = actor(batch.x_c, batch.edge_index, batch.edge_features, batch.x_v)
                logits = utils.pad_tensor(logits[batch.candidates], pad_sizes=batch.nb_candidates)
                loss = F.cross_entropy(logits, batch.candidate_choice)

                accuracy = _expert_accuracy(logits, batch)
                mean_loss[1] += loss.item() * batch.num_graphs
                mean_acc[1] += accuracy * batch.num_graphs
                samples_seen[1] += batch.num_graphs

        # NOTE(review): raises ZeroDivisionError on an empty loader — assumed non-empty.
        mean_loss[0] /= samples_seen[0]
        mean_acc[0] /= samples_seen[0]
        mean_loss[1] /= samples_seen[1]
        mean_acc[1] /= samples_seen[1]

        print(f"==== epoch: {epoch_i} ==== ")
        print(f"(train) mean loss {mean_loss[0]:.3f}, mean acc {mean_acc[0]:.3f}")
        print(f"(valid) mean loss {mean_loss[1]:.3f}, mean acc {mean_acc[1]:.3f}")

        # Checkpoint after every epoch (overwrites the previous one).
        # NOTE(review): Hydra usually changes the working directory per run, so
        # "models" lands inside the run's output dir — confirm that is intended.
        path = os.path.join(os.getcwd(), "models")
        # TODO(review): file name hard-codes problem/size/version; consider
        # deriving it from cfg so different runs don't overwrite each other.
        name = "learn2branch-set_cover-1000-v3.pt"
        create_folder(path)
        torch.save(actor.state_dict(), os.path.join(path, name))


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: seed RNGs, build generator + environment, collect data."""
    # print(OmegaConf.to_yaml(cfg))

    # Seed every RNG in play for reproducibility.
    torch.manual_seed(cfg.training.seed)
    random.seed(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    generator = torch.Generator()
    generator.manual_seed(cfg.training.seed)

    # Problem-instance generator.
    if cfg.problem == "set_cover":
        gen = ecole.instance.SetCoverGenerator(
            n_rows=cfg.n_rows,
            n_cols=cfg.n_cols,
            density=cfg.density,
        )
        gen.seed(cfg.training.seed)
    else:
        # Fail fast instead of a later NameError on an unbound `gen`.
        raise ValueError(f"Unsupported problem: {cfg.problem!r}")

    # Expert observation function (scores used as branching ground truth).
    if cfg.training.expert == "hybrid_branch":
        observation_functions = (
            HybridBranch(cfg.training.expert_probability),
            ecole.observation.NodeBipartite(),
        )
    elif cfg.training.expert == "strong_branch":
        observation_functions = (
            ecole.observation.StrongBranchingScores(),
            ecole.observation.NodeBipartite(),
        )
    else:
        raise ValueError(f"Unsupported expert: {cfg.training.expert!r}")

    # SCIP parameters used in the Gasse et al. learn-to-branch paper.
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 3600,  # 1 hr time limit per solve
    }

    env = ecole.environment.Branching(
        observation_function=observation_functions,
        scip_params=scip_parameters,
    )
    env.seed(cfg.training.seed)

    # Collect training data.
    if cfg.sample:
        collect_data(cfg, gen, env, generator)

    # TODO(review): training pipeline, currently disabled — re-enable once the
    # sample files above have been generated.
    # sample_files = Path("learn2branch_gasse/samples/").glob("sample_*.pkl")
    # samples = []
    # for sample_i in sample_files:
    #     with open(sample_i, "rb") as f:
    #         samples.append(pickle.load(f))
    # # split into 80/20
    # train_samples = samples[: int(0.8 * len(samples))]
    # valid_samples = samples[int(0.8 * len(samples)) :]
    # train_ds = BipartiteDataset(train_samples)
    # valid_ds = BipartiteDataset(valid_samples)
    # train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, generator=generator)
    # valid_dl = DataLoader(valid_ds, batch_size=64, shuffle=True, generator=generator)
    # actor = GNNActor()
    # actor_opt = torch.optim.Adam(actor.parameters(), lr=cfg.training.lr)
    # train_expert(cfg, train_dl, valid_dl, actor_opt, actor)


if __name__ == "__main__":
    main()