import ecole
import ecole.environment
import ecole.instance
import ecole.observation
import ecole.typing
import hydra
from omegaconf import DictConfig, OmegaConf
import gzip
import pickle
from pathlib import Path
from data import BipartiteDataset
from torch_geometric.loader import DataLoader
from model import GNNPolicy, GNNActor
from observation_functions import HybridBranch
import utils
import random
import torch
import torch.nn.functional as F
import numpy as np
import os
def create_folder(dirname):
    """Create directory ``dirname`` (including parents) if it does not exist.

    Prints a confirmation only when the directory was newly created, matching
    the original behavior.
    """
    if not os.path.exists(dirname):
        # exist_ok=True guards against the TOCTOU race between the exists()
        # check above and the actual creation (e.g. concurrent runs).
        os.makedirs(dirname, exist_ok=True)
        print(f"Created: {dirname}")
# NOTE: legacy data-collection loop below is superseded by the probability-based
# sampling version of collect_data further down; kept for reference only.
# def collect_data(cfg: DictConfig, gen: ecole.typing.InstanceGenerator, env: ecole.environment.Environment):
# output_file = "learn2branch_gasse/samples"
# Path(output_file).mkdir(exist_ok=True)
# samples_collected_global = 0
# episode_i = 0
# while samples_collected_global < cfg.training.n_samples:
# episode_i += 1
# instance = next(gen)
# observation, action_set, _, done, _ = env.reset(instance)
# samples_collected_local = 0
# while not done:
# (scores, used_sb_scores), graph = observation
# action = action_set[scores[action_set].argmax()] # choose the best action (greedy)
# if used_sb_scores and samples_collected_global < cfg.training.n_samples:
# samples_collected_global += 1
# samples_collected_local += 1
# # Record state, action
# data_example = [graph, action, action_set, scores]
# # Save samples after each episode and clear memory
# with open(f"{output_file}/sample_{cfg.problem}_{episode_i}_{samples_collected_local}.pkl", "wb") as f:
# pickle.dump(data_example, f)
# # Take action and go to the next state
# observation, action_set, _, done, _ = env.step(action)
# print(f"Collected {samples_collected_local} samples during episode {episode_i}; progress {samples_collected_global/cfg.training.n_samples:.2%}")
def collect_data(cfg: DictConfig, gen: ecole.typing.InstanceGenerator, env: ecole.environment.Environment, generator):
    """Roll out the expert branching policy and pickle sampled (state, action) pairs.

    Episodes are played greedily w.r.t. the expert scores; each branching
    decision is recorded with probability ``cfg.training.sample_prob`` until
    ``cfg.training.n_samples`` samples have been written to disk.

    Args:
        cfg: Hydra config; reads ``cfg.problem``, ``cfg.training.sample_prob``
            and ``cfg.training.n_samples``.
        gen: (infinite) problem-instance generator; advanced with ``next``.
        env: ecole branching environment whose observation unpacks as
            ``(scores, graph)`` — assumes the observation functions configured
            in ``main()`` (expert scores + NodeBipartite); verify if reused.
        generator: seeded ``torch.Generator`` so the accept/reject sampling
            is reproducible.
    """
    output_file = "learn2branch_gasse/samples"
    # parents=True: without it mkdir raises FileNotFoundError when the
    # top-level "learn2branch_gasse" directory does not exist yet.
    Path(output_file).mkdir(parents=True, exist_ok=True)
    samples_collected_global = 0
    episode_i = 0
    while samples_collected_global < cfg.training.n_samples:
        episode_i += 1
        instance = next(gen)
        observation, action_set, _, done, _ = env.reset(instance)
        samples_collected_local = 0
        while not done:
            scores, graph = observation
            action = action_set[scores[action_set].argmax()]  # choose the best action (greedy)
            # Bernoulli(sample_prob) coin flip deciding whether to keep this state.
            u = torch.rand(1, generator=generator)
            if u < cfg.training.sample_prob and samples_collected_global < cfg.training.n_samples:
                samples_collected_global += 1
                samples_collected_local += 1
                # Record state, action
                data_example = [graph, action, action_set, scores]
                # One file per sample so partial runs still yield usable data.
                with open(f"{output_file}/sb_{cfg.problem}_{episode_i}_{samples_collected_local}.pkl", "wb") as f:
                    pickle.dump(data_example, f)
            # Take action and go to the next state
            observation, action_set, _, done, _ = env.step(action)
        print(f"Collected {samples_collected_local} samples during episode {episode_i}; progress {samples_collected_global/cfg.training.n_samples:.2%}")
def _batch_loss_and_accuracy(batch, actor):
    """Run `actor` on one bipartite batch; return (loss tensor, accuracy float).

    Loss is cross-entropy against the expert's chosen candidate; accuracy is
    whether the predicted candidate attains the expert's best score (ties with
    the best score count as correct).
    """
    logits = actor(
        batch.x_c,
        batch.edge_index,
        batch.edge_features,
        batch.x_v
    )
    # Restrict logits to the branching candidates and pad to a dense
    # (batch, max_candidates) tensor since graphs have varying candidate counts.
    logits = utils.pad_tensor(logits[batch.candidates], pad_sizes=batch.nb_candidates)
    loss = F.cross_entropy(logits, batch.candidate_choice)  # logits: (32, max_candidates), batch.candidate_choice: (32,)
    groundtruth_scores = utils.pad_tensor(batch.candidates_scores, batch.nb_candidates)
    groundtruth_best_score = torch.max(groundtruth_scores, dim=-1, keepdim=True).values
    predicted_best_index = logits.max(dim=-1, keepdim=True).indices
    accuracy = (groundtruth_scores.gather(-1, predicted_best_index) == groundtruth_best_score).float().mean().item()
    return loss, accuracy

def train_expert(cfg: DictConfig, train_dl: DataLoader, valid_dl: DataLoader, optimizer: torch.optim.Optimizer, actor: torch.nn.Module):
    """Train `actor` by imitation of the expert branching scores.

    Runs ``cfg.training.n_epochs`` epochs; each epoch does one optimization
    pass over ``train_dl`` and one no-grad evaluation pass over ``valid_dl``,
    printing sample-weighted mean loss/accuracy for both splits. The final
    model weights are saved to ``models/learn2branch-set_cover-1000-v3.pt``
    in the current working directory.

    Args:
        cfg: Hydra config; reads ``cfg.training.n_epochs``.
        train_dl: training batches (torch_geometric bipartite batches).
        valid_dl: validation batches.
        optimizer: optimizer over ``actor.parameters()``.
        actor: GNN policy mapping a bipartite graph to per-variable logits.
    """
    for epoch_i in range(cfg.training.n_epochs):
        # Index 0 tracks the train split, index 1 the validation split.
        mean_loss = [0, 0]
        mean_acc = [0, 0]
        samples_seen = [0, 0]
        actor.train()
        for batch in train_dl:
            optimizer.zero_grad()
            loss, accuracy = _batch_loss_and_accuracy(batch, actor)
            loss.backward()
            optimizer.step()
            # Weight per-batch stats by graph count so means are per-sample.
            mean_loss[0] += loss.item() * batch.num_graphs
            mean_acc[0] += accuracy * batch.num_graphs
            samples_seen[0] += batch.num_graphs
        actor.eval()
        with torch.no_grad():
            for batch in valid_dl:
                loss, accuracy = _batch_loss_and_accuracy(batch, actor)
                mean_loss[1] += loss.item() * batch.num_graphs
                mean_acc[1] += accuracy * batch.num_graphs
                samples_seen[1] += batch.num_graphs
        mean_loss[0] /= samples_seen[0]
        mean_acc[0] /= samples_seen[0]
        mean_loss[1] /= samples_seen[1]
        mean_acc[1] /= samples_seen[1]
        print(f"==== epoch: {epoch_i} ==== ")
        print(f"(train) mean loss {mean_loss[0]:.3f}, mean acc {mean_acc[0]:.3f}")
        print(f"(valid) mean loss {mean_loss[1]:.3f}, mean acc {mean_acc[1]:.3f}")
    # Persist only the final epoch's weights (no best-checkpoint tracking).
    path = os.path.join(os.getcwd(), "models")
    name = "learn2branch-set_cover-1000-v3.pt"
    create_folder(path)
    torch.save(actor.state_dict(), os.path.join(path, name))
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg : DictConfig) -> None:
    """Entry point: seed RNGs, build the instance generator, expert observation
    functions and branching environment, then (optionally) collect samples.

    Raises:
        ValueError: if ``cfg.problem`` or ``cfg.training.expert`` names an
            unsupported option (previously these fell through silently and
            caused a NameError further down).
    """
    # print(OmegaConf.to_yaml(cfg))
    # Seed every RNG source (torch, python, numpy, plus a dedicated torch
    # Generator for the sampling coin flips) for reproducibility.
    torch.manual_seed(cfg.training.seed)
    random.seed(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    generator = torch.Generator()
    generator.manual_seed(cfg.training.seed)
    # define problem instance generator
    if cfg.problem == "set_cover":
        gen = ecole.instance.SetCoverGenerator(
            n_rows=cfg.n_rows,
            n_cols=cfg.n_cols,
            density=cfg.density,
        )
        gen.seed(cfg.training.seed)
    else:
        raise ValueError(f"Unsupported problem: {cfg.problem}")
    # define the observation function
    if cfg.training.expert == "hybrid_branch":
        observation_functions = (
            HybridBranch(cfg.training.expert_probability),
            ecole.observation.NodeBipartite()
        )
    elif cfg.training.expert == "strong_branch":
        observation_functions = (
            ecole.observation.StrongBranchingScores(),
            ecole.observation.NodeBipartite()
        )
    else:
        raise ValueError(f"Unsupported expert: {cfg.training.expert}")
    # scip parameters used in paper
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 3600,  # 1hr time limit to run
    }
    # define the environment
    env = ecole.environment.Branching(
        observation_function=observation_functions,
        scip_params=scip_parameters
    )
    env.seed(cfg.training.seed)  # seed environment
    ## Start collecting training data ##
    if cfg.sample:
        collect_data(cfg, gen, env, generator)
    # # open sample files if they exist
    # sample_files = Path("learn2branch_gasse/samples/").glob("sample_*.pkl")
    # samples = []
    # for sample_i in sample_files:
    #     with open(sample_i, "rb") as f:
    #         samples.append(pickle.load(f))
    # # split into 80/20
    # train_samples = samples[: int(0.8*len(samples))]
    # valid_samples = samples[int(0.8*len(samples)) :]
    # # make the dataset
    # train_ds = BipartiteDataset(train_samples)
    # valid_ds = BipartiteDataset(valid_samples)
    # train_ds[0]
    # # declare dataloader for training
    # train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, generator=generator)
    # valid_dl = DataLoader(valid_ds, batch_size=64, shuffle=True, generator=generator)
    # # define actor network
    # actor = GNNActor()
    # # define an optimizer
    # actor_opt = torch.optim.Adam(actor.parameters(), lr=cfg.training.lr)
    # # train the network
    # train_expert(cfg, train_dl, valid_dl, actor_opt, actor)
if __name__ == "__main__":
main()