EECE571F-project / ml4co-competition / baseline / dual / train_files / 02_train.py
import os
import sys
import glob
import argparse
import pathlib
import numpy as np
sys.path.append('.')


def pretrain(policy, pretrain_loader):
    """
    Pre-trains all PreNorm layers in the model.

    Parameters
    ----------
    policy : torch.nn.Module
        Model to pre-train.
    pretrain_loader : torch_geometric.loader.DataLoader
        Pre-loaded dataset of pre-training samples.

    Returns
    -------
    i : int
        Number of pre-trained layers.
    """
    policy.pre_train_init()
    i = 0
    while True:
        for batch in pretrain_loader:
            batch = batch.to(device)
            if not policy.pre_train(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features):
                break

        if policy.pre_train_next() is None:
            break
        i += 1
    return i
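
# `pad_tensor` (imported from utilities in the main block below) is assumed to
# reshape the flat per-candidate vector into a (num_graphs, max_candidates)
# matrix, padding shorter rows with a large negative value so padding never
# wins a max/topk. A minimal sketch under that assumption:
#
#   def pad_tensor(input_, pad_sizes, pad_value=-1e8):
#       max_pad_size = pad_sizes.max()
#       output = input_.split(pad_sizes.cpu().numpy().tolist())
#       return torch.stack([F.pad(x, (0, max_pad_size - x.size(0)), 'constant', pad_value)
#                           for x in output], dim=0)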


def process(policy, data_loader, top_k=[1, 3, 5, 10], optimizer=None):
    """
    Process samples. If an optimizer is given, also train on those samples.

    Parameters
    ----------
    policy : torch.nn.Module
        Model to train/evaluate.
    data_loader : torch_geometric.loader.DataLoader
        Pre-loaded dataset of training samples.
    top_k : list
        Accuracy will be computed for the top k elements, for k in this list.
    optimizer : torch.optim (optional)
        Optimizer object. If not None, will be used for updating the model parameters.

    Returns
    -------
    mean_loss : float in [0, 1e+20]
        Mean cross entropy loss.
    mean_kacc : np.ndarray
        Mean top k accuracy, for k in the user-provided list top_k.
    """
    mean_loss = 0
    mean_kacc = np.zeros(len(top_k))

    n_samples_processed = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)
            logits = policy(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features)
            logits = pad_tensor(logits[batch.candidates], batch.nb_candidates)
            cross_entropy_loss = F.cross_entropy(logits, batch.candidate_choices, reduction='mean')

            # if an optimizer is provided, update parameters
            if optimizer is not None:
                optimizer.zero_grad()
                cross_entropy_loss.backward()
                optimizer.step()

            true_scores = pad_tensor(batch.candidate_scores, batch.nb_candidates)
            true_bestscore = true_scores.max(dim=-1, keepdim=True).values

            # calculate top k accuracy
            kacc = []
            for k in top_k:
                # check if there are at least k candidates
                if logits.size()[-1] < k:
                    kacc.append(1.0)
                    continue
                pred_top_k = logits.topk(k).indices
                pred_top_k_true_scores = true_scores.gather(-1, pred_top_k)
                accuracy = (pred_top_k_true_scores == true_bestscore).any(dim=-1).float().mean().item()
                kacc.append(accuracy)
            kacc = np.asarray(kacc)

            mean_loss += cross_entropy_loss.item() * batch.num_graphs
            mean_kacc += kacc * batch.num_graphs
            n_samples_processed += batch.num_graphs

    mean_loss /= n_samples_processed
    mean_kacc /= n_samples_processed
    return mean_loss, mean_kacc
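
# Illustrative example of the accuracy computation above (hypothetical numbers,
# not executed): with two graphs padded to 3 candidates each,
#   logits         = [[0.9, 0.1, -1e8], [0.2, 0.5, 0.3]]
#   true_scores    = [[1.0, 3.0, -1e8], [2.0, 2.0, 0.5]]
#   true_bestscore = [[3.0], [2.0]]
# for k = 1 the predicted indices are [[0], [1]] and their true scores
# [[1.0], [2.0]] match the best score only for the second graph, so acc@1 = 0.5.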


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=[
            'item_placement', 'load_balancing', 'anonymous',
            'set_cover', 'combinatorial_auction',
            'capacitated_facility_location', 'independent_set',
        ],
    )
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=int,
        default=0,
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=-1,
    )
    args = parser.parse_args()

    # hyper parameters
    max_epochs = 100
    batch_size = 32
    pretrain_batch_size = 128
    valid_batch_size = 128
    lr = 1e-3
    top_k = [1, 3, 5, 10]

    # get sample directory
    if args.problem == 'item_placement':
        train_files = glob.glob('baseline/dual/train_files/samples/1_item_placement/train/sample_*.pkl')
        valid_files = glob.glob('baseline/dual/train_files/samples/1_item_placement/valid/sample_*.pkl')
        running_dir = 'baseline/dual/train_files/trained_models/item_placement'

    elif args.problem == 'load_balancing':
        train_files = glob.glob('baseline/dual/train_files/samples/2_load_balancing/train/sample_*.pkl')
        valid_files = glob.glob('baseline/dual/train_files/samples/2_load_balancing/valid/sample_*.pkl')
        running_dir = 'baseline/dual/train_files/trained_models/load_balancing'

    elif args.problem == 'anonymous':
        train_files = glob.glob('baseline/dual/train_files/samples/3_anonymous/train/sample_*.pkl')
        valid_files = glob.glob('baseline/dual/train_files/samples/3_anonymous/valid/sample_*.pkl')
        running_dir = 'baseline/dual/train_files/trained_models/anonymous'
    
    elif args.problem == 'set_cover':
        train_files = glob.glob('baseline/dual/train_files/samples/4_set_cover/train/sample_*.pkl')
        valid_files = glob.glob('baseline/dual/train_files/samples/4_set_cover/valid/sample_*.pkl')
        running_dir = 'baseline/dual/train_files/trained_models/set_cover'
    
    elif args.problem == 'combinatorial_auction':
        train_files = glob.glob('baseline/dual/train_files/samples/5_combinatorial_auction/train/sample_*.pkl')
        valid_files = glob.glob('baseline/dual/train_files/samples/5_combinatorial_auction/valid/sample_*.pkl')
        running_dir = 'baseline/dual/train_files/trained_models/combinatorial_auction'
    
    elif args.problem == 'capacitated_facility_location':
        train_files = glob.glob('baseline/dual/train_files/samples/6_capacitated_facility_location/train/sample_*.pkl')
        valid_files = glob.glob('baseline/dual/train_files/samples/6_capacitated_facility_location/valid/sample_*.pkl')
        running_dir = 'baseline/dual/train_files/trained_models/capacitated_facility_location'
    
    elif args.problem == 'independent_set':
        train_files = glob.glob('baseline/dual/train_files/samples/7_independent_set/train/sample_*.pkl')
        valid_files = glob.glob('baseline/dual/train_files/samples/7_independent_set/valid/sample_*.pkl')
        running_dir = 'baseline/dual/train_files/trained_models/independent_set'
    else:
        raise NotImplementedError

    # every 10th training sample is set aside for pre-training the PreNorm layers
    pretrain_files = [f for i, f in enumerate(train_files) if i % 10 == 0]

    # working directory setup
    os.makedirs(running_dir, exist_ok=True)

    # cuda setup
    if args.gpu == -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        device = "cpu"
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}'
        device = "cuda:0"

    # import pytorch **after** cuda setup
    import torch
    import torch.nn.functional as F
    import torch_geometric
    sys.path.insert(0, '.')
    from baseline.dual.train_files.utilities import log, pad_tensor, GraphDataset, Scheduler
    from baseline.dual.model import GNNPolicy

    # randomization setup
    rng = np.random.RandomState(args.seed)
    torch.manual_seed(args.seed)

    # logging setup
    logfile = os.path.join(running_dir, 'train_log.txt')
    if os.path.exists(logfile):
        os.remove(logfile)

    log(f"max_epochs: {max_epochs}", logfile)
    log(f"batch_size: {batch_size}", logfile)
    log(f"pretrain_batch_size: {pretrain_batch_size}", logfile)
    log(f"valid_batch_size : {valid_batch_size }", logfile)
    log(f"lr: {lr}", logfile)
    log(f"top_k: {top_k}", logfile)
    log(f"gpu: {args.gpu}", logfile)
    log(f"seed {args.seed}", logfile)

    # data setup
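    # `GraphDataset` (from utilities) is assumed to read each sample_*.pkl into a
    # torch_geometric Data object carrying the bipartite constraint/variable graph
    # (constraint_features, edge_index, edge_attr, variable_features) along with
    # the branching candidates, their scores and the expert choice used above.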
    valid_data = GraphDataset(valid_files)
    pretrain_data = GraphDataset(pretrain_files)
    valid_loader = torch_geometric.loader.DataLoader(valid_data, batch_size=valid_batch_size, shuffle=False)
    pretrain_loader = torch_geometric.loader.DataLoader(pretrain_data, batch_size=pretrain_batch_size, shuffle=False)
    
    policy = GNNPolicy().to(device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    scheduler = Scheduler(optimizer, mode='min', patience=10, factor=0.2, verbose=True)
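    # `Scheduler` (from utilities) is assumed to be a ReduceLROnPlateau-style
    # scheduler whose num_bad_epochs counter keeps growing across learning-rate
    # drops and is only reset on a new best validation loss, which is what the
    # 0 / 10 / 20 checks in the loop below rely on.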

    for epoch in range(max_epochs + 1):
        log(f"EPOCH {epoch}...", logfile)
        if epoch == 0:
            n = pretrain(policy, pretrain_loader)
            log(f"PRETRAINED {n} LAYERS", logfile)
        else:
            # epoch_train_files = rng.choice(train_files, int(np.floor(10000/batch_size))*batch_size, replace=False)
            train_data = GraphDataset(train_files)
            train_loader = torch_geometric.loader.DataLoader(train_data, batch_size, shuffle=True)
            train_loss, train_kacc = process(policy, train_loader, top_k, optimizer)
            log(f"TRAIN LOSS: {train_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, train_kacc)]), logfile)

        # validate
        valid_loss, valid_kacc = process(policy, valid_loader, top_k, None)
        log(f"VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]), logfile)

        scheduler.step(valid_loss)
        if scheduler.num_bad_epochs == 0:
            torch.save(policy.state_dict(), pathlib.Path(running_dir)/'best_params.pkl')
            log(f"  best model so far", logfile)
        elif scheduler.num_bad_epochs == 10:
            log(f"  10 epochs without improvement, decreasing learning rate", logfile)
        elif scheduler.num_bad_epochs == 20:
            log(f"  20 epochs without improvement, early stopping", logfile)
            break

    # load best parameters and run a final validation step
    policy.load_state_dict(torch.load(pathlib.Path(running_dir)/'best_params.pkl'))
    valid_loss, valid_kacc = process(policy, valid_loader, top_k, None)
    log(f"BEST VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]), logfile)