import os
import sys
import glob
import argparse
import pathlib
import numpy as np


def pretrain(policy, pretrain_loader):
    """
    Pre-trains all PreNorm layers in the model.

    Parameters
    ----------
    policy : torch.nn.Module
        Model to pre-train.
    pretrain_loader : torch_geometric.data.DataLoader
        Pre-loaded dataset of pre-training samples.

    Returns
    -------
    i : int
        Number of pre-trained layers.
    """
    policy.pre_train_init()
    i = 0
    while True:
        # feed pre-training batches to the current PreNorm layer until it signals completion
        for batch in pretrain_loader:
            batch.to(device)
            if not policy.pre_train(batch.constraint_features, batch.edge_index,
                                    batch.edge_attr, batch.variable_features):
                break

        # move on to the next PreNorm layer; stop when there are none left
        if policy.pre_train_next() is None:
            break
        i += 1

    return i


def process(policy, data_loader, top_k=[1, 3, 5, 10], optimizer=None):
    """
    Processes samples. If an optimizer is given, also trains on those samples.

    Parameters
    ----------
    policy : torch.nn.Module
        Model to train/evaluate.
    data_loader : torch_geometric.data.DataLoader
        Pre-loaded dataset of training samples.
    top_k : list
        Accuracy will be computed for the top k elements, for k in this list.
    optimizer : torch.optim (optional)
        Optimizer object. If not None, will be used for updating the model parameters.

    Returns
    -------
    mean_loss : float in [0, 1e+20]
        Mean cross entropy loss.
    mean_kacc : np.ndarray
        Mean top k accuracy, for k in the user-provided list top_k.
    """
    mean_loss = 0
    mean_kacc = np.zeros(len(top_k))

    n_samples_processed = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)
            logits = policy(batch.constraint_features, batch.edge_index,
                            batch.edge_attr, batch.variable_features)
            # keep only candidate variables and pad to one row per graph in the batch
            logits = pad_tensor(logits[batch.candidates], batch.nb_candidates)
            cross_entropy_loss = F.cross_entropy(logits, batch.candidate_choices, reduction='mean')

            # if an optimizer is provided, update parameters
            if optimizer is not None:
                optimizer.zero_grad()
                cross_entropy_loss.backward()
                optimizer.step()

            true_scores = pad_tensor(batch.candidate_scores, batch.nb_candidates)
            true_bestscore = true_scores.max(dim=-1, keepdim=True).values

            # calculate top k accuracy
            kacc = []
            for k in top_k:
                # check if there are at least k candidates
                if logits.size()[-1] < k:
                    kacc.append(1.0)
                    continue
                pred_top_k = logits.topk(k).indices
                pred_top_k_true_scores = true_scores.gather(-1, pred_top_k)
                accuracy = (pred_top_k_true_scores == true_bestscore).any(dim=-1).float().mean().item()
                kacc.append(accuracy)
            kacc = np.asarray(kacc)

            mean_loss += cross_entropy_loss.item() * batch.num_graphs
            mean_kacc += kacc * batch.num_graphs
            n_samples_processed += batch.num_graphs

    mean_loss /= n_samples_processed
    mean_kacc /= n_samples_processed

    return mean_loss, mean_kacc
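
# `pad_tensor` is imported from `utilities` in the __main__ block below; its actual
# implementation is not part of this file. The helper sketched here only illustrates
# what it is assumed to do: split a flat tensor of per-candidate values into one row
# per graph (row i has `pad_sizes[i]` entries) and right-pad each row with a very
# negative value, so the rows stack into a rectangular 2D tensor suitable for the
# cross-entropy loss and top-k accuracy above. It relies on the `torch`/`F` imports
# performed in __main__ and is never called by this script.
def _pad_tensor_sketch(input_, pad_sizes, pad_value=-1e8):
    max_pad_size = int(pad_sizes.max())
    chunks = input_.split(pad_sizes.cpu().numpy().tolist())
    return torch.stack([
        F.pad(chunk, (0, max_pad_size - chunk.size(0)), 'constant', pad_value)
        for chunk in chunks
    ], dim=0)
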

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=['item_placement', 'load_balancing', 'anonymous'],
    )
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=int,
        default=0,
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=0,
    )
    args = parser.parse_args()

    # hyper parameters
    max_epochs = 1000
    batch_size = 12
    pretrain_batch_size = 128
    valid_batch_size = 128
    lr = 1e-3
    top_k = [1, 3, 5, 10]

    # get sample directories
    if args.problem == 'item_placement':
        train_files = glob.glob('train_files/samples/1_item_placement/train/sample_*.pkl')
        valid_files = glob.glob('train_files/samples/1_item_placement/valid/sample_*.pkl')
        running_dir = 'train_files/trained_models/item_placement'
    elif args.problem == 'load_balancing':
        train_files = glob.glob('train_files/samples/2_load_balancing/train/sample_*.pkl')
        valid_files = glob.glob('train_files/samples/2_load_balancing/valid/sample_*.pkl')
        running_dir = 'train_files/trained_models/load_balancing'
    elif args.problem == 'anonymous':
        train_files = glob.glob('train_files/samples/3_anonymous/train/sample_*.pkl')
        valid_files = glob.glob('train_files/samples/3_anonymous/valid/sample_*.pkl')
        running_dir = 'train_files/trained_models/anonymous'
    else:
        raise NotImplementedError

    # every 10th training sample is reserved for pre-training the PreNorm layers
    pretrain_files = [f for i, f in enumerate(train_files) if i % 10 == 0]

    # working directory setup
    os.makedirs(running_dir, exist_ok=True)

    # cuda setup
    if args.gpu == -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        device = "cpu"
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}'
        device = "cuda:0"

    # import pytorch **after** cuda setup
    import torch
    import torch.nn.functional as F
    import torch_geometric
    from utilities import log, pad_tensor, GraphDataset, Scheduler

    sys.path.insert(0, '.')
    from model import GNNPolicy

    # randomization setup
    rng = np.random.RandomState(args.seed)
    torch.manual_seed(args.seed)

    # logging setup
    logfile = os.path.join(running_dir, 'train_log.txt')
    if os.path.exists(logfile):
        os.remove(logfile)

    log(f"max_epochs: {max_epochs}", logfile)
    log(f"batch_size: {batch_size}", logfile)
    log(f"pretrain_batch_size: {pretrain_batch_size}", logfile)
    log(f"valid_batch_size: {valid_batch_size}", logfile)
    log(f"lr: {lr}", logfile)
    log(f"top_k: {top_k}", logfile)
    log(f"gpu: {args.gpu}", logfile)
    log(f"seed: {args.seed}", logfile)

    # data setup
    valid_data = GraphDataset(valid_files)
    pretrain_data = GraphDataset(pretrain_files)
    valid_loader = torch_geometric.data.DataLoader(valid_data, valid_batch_size, shuffle=False)
    pretrain_loader = torch_geometric.data.DataLoader(pretrain_data, pretrain_batch_size, shuffle=False)

    policy = GNNPolicy().to(device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    scheduler = Scheduler(optimizer, mode='min', patience=10, factor=0.2, verbose=True)

    for epoch in range(max_epochs + 1):
        log(f"EPOCH {epoch}...", logfile)

        if epoch == 0:
            # epoch 0: only pre-train the PreNorm layers, no gradient updates
            n = pretrain(policy, pretrain_loader)
            log(f"PRETRAINED {n} LAYERS", logfile)
        else:
            # train on a fresh random subsample of roughly 10k samples each epoch
            epoch_train_files = rng.choice(train_files, int(np.floor(10000 / batch_size)) * batch_size, replace=True)
            train_data = GraphDataset(epoch_train_files)
            train_loader = torch_geometric.data.DataLoader(train_data, batch_size, shuffle=True)
            train_loss, train_kacc = process(policy, train_loader, top_k, optimizer)
            log(f"TRAIN LOSS: {train_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, train_kacc)]), logfile)

        # validate
        valid_loss, valid_kacc = process(policy, valid_loader, top_k, None)
        log(f"VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]), logfile)

        scheduler.step(valid_loss)
        if scheduler.num_bad_epochs == 0:
            torch.save(policy.state_dict(), pathlib.Path(running_dir) / 'best_params.pkl')
            log("  best model so far", logfile)
        elif scheduler.num_bad_epochs == 10:
            log("  10 epochs without improvement, decreasing learning rate", logfile)
        elif scheduler.num_bad_epochs == 20:
            log("  20 epochs without improvement, early stopping", logfile)
            break

    # load best parameters and run a final validation step
    policy.load_state_dict(torch.load(pathlib.Path(running_dir) / 'best_params.pkl'))
    valid_loss, valid_kacc = process(policy, valid_loader, top_k, None)
    log(f"BEST VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]), logfile)
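
# Example invocations, assuming this file is saved as train.py (the file name is an
# assumption, not fixed by the script itself):
#   python train.py item_placement --gpu 0 --seed 0
#   python train.py load_balancing --gpu -1    # run on CPU
# Training/validation sample .pkl files are expected under the hard-coded
# train_files/samples/ directories above, and the best model checkpoint is written
# to <running_dir>/best_params.pkl.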