# aqetuner/train.py
from tqdm import tqdm
from network import LatentModel
from torch.utils.data import Dataset, DataLoader

import torch as t
from pair_encoder import TPair
import argparse
import os
import json
import random
import time

# Train on GPU when available, otherwise fall back to CPU.
DEVICE = 'cuda' if t.cuda.is_available() else 'cpu'
# Target database name; verify() uses it to locate data/{DB}_samples.
DB = os.getenv('DB')

# Dataset wrapper over the pre-sampled (context, target) batches built by load_dataset().
class CustomDataset(Dataset):
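    """Holds pre-sampled context/target batches.

    Each item is one (pairs_ctx, elapsed_ctx, pairs_tar, elapsed_tar) tuple,
    i.e. a full context set and target set produced by load_dataset().
    """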
    def __init__(self, pairs_ctx, elapsed_ctx, pairs_tar, elapsed_tar):
        self.pairs_ctx = pairs_ctx
        self.elapsed_ctx = elapsed_ctx
        self.pairs_tar = pairs_tar
        self.elapsed_tar = elapsed_tar

    def __getitem__(self, index):
        return self.pairs_ctx[index], self.elapsed_ctx[index], \
            self.pairs_tar[index], self.elapsed_tar[index]

    def __len__(self):
        return len(self.pairs_ctx)

def load_dataset(path, data_size=300):
    pairs, elapsed = [], []
    with open(path) as f:
        line = f.readline()
        while line:
            items = line.split('\t')
            stats = json.loads(items[3])
            if len(stats) > 0:
                pairs.append(TPair(json.loads(items[1])['plan'], json.loads(items[2])))
                elapsed.append([stats['elapsed'], stats['fail']])
            line = f.readline()
    
    item_len = len(pairs)
    ctx_num = item_len // 2   # context points per sampled batch
    tar_num = item_len // 4   # target points per sampled batch
    print("read dataset: %d samples" % item_len)
    pairs_ctx, elapseds_ctx, pairs_tar, elapseds_tar = [], [], [], []
    for _ in range(data_size):
        pair_ctx, elapsed_ctx, pair_tar, elapsed_tar = [], [], [], []
        for i in range(ctx_num):
            idx = random.choice(range(item_len))
            pair_ctx.append(pairs[idx])
            elapsed_ctx.append(elapsed[idx])
        for i in range(tar_num):
            idx = random.choice(range(item_len))
            pair_tar.append(pairs[idx])
            elapsed_tar.append(elapsed[idx])
        pairs_ctx.append(pair_ctx)
        elapseds_ctx.append(t.Tensor(elapsed_ctx))
        pairs_tar.append(pair_tar)
        elapseds_tar.append(t.Tensor(elapsed_tar))
    
    return CustomDataset(pairs_ctx, elapseds_ctx, pairs_tar, elapseds_tar)


def adjust_learning_rate(optimizer, step_num, warmup_step=4000):
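    """Noam-style warmup: the learning rate rises linearly to 0.001 over the
    first `warmup_step` steps, then decays proportionally to step_num**-0.5."""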
    lr = 0.001 * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def verify(dict_name):
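    """Loads a checkpoint and evaluates it on one randomly sampled batch,
    printing the inference time, the predicted (stddev, mean) against the true
    targets, and the MAE of the predicted mean latency."""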
    dataset = load_dataset(f"data/{DB}_samples", 1)
    dloader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0], shuffle=True)
    for _, data in enumerate(dloader):
        x1, y1, x2, y2 = data
        # map_location keeps the checkpoint loadable on CPU-only machines.
        state_dict = t.load(dict_name, map_location=DEVICE)
        model = LatentModel(32).to(DEVICE)
        model.load_state_dict(state_dict['model'])
        model.eval()
        start = time.time()
        res, status, loss = model(x1, y1, x2, None)
        print("cost %f" % (time.time() - start))
        latency = y2[:, 0].unsqueeze(-1)
        print(y2, t.cat([res.stddev, res.mean], dim=1))
        mae = t.abs(res.mean - latency).mean()
        print("mae %f" % mae)
        break

        
def main(file_name, output_path, epochs, lr):
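    """Trains the latent model on batches sampled from `file_name` and saves a
    checkpoint (model + optimizer state) to `output_path` after every epoch.

    Note: adjust_learning_rate() overrides the initial `lr` from the first
    optimization step onwards.
    """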
    model = LatentModel(32).to(DEVICE)
    model.train()
    
    optim = t.optim.Adam(model.parameters(), lr=lr)
    global_step = 0
    train_dataset = load_dataset(file_name)
    for epoch in range(epochs):
        # batch_size=1 with an identity collate: each dataset item is already a full batch.
        dloader = DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x[0], shuffle=True)
        pbar = tqdm(dloader)
        loss_accum = 0.
        for i, data in enumerate(pbar):
            global_step += 1
            adjust_learning_rate(optim, global_step)
            x1, y1, x2, y2 = data
            
            # Pass the context set and target inputs through the latent model.
            y_pred_l, y_pred_e, loss = model(x1, y1, x2, y2)
            # Accumulate a Python float so the computation graph is not kept alive.
            loss_accum += loss.item()
            
            # Training step
            optim.zero_grad()
            loss.backward()
            optim.step()
            # print("loss : {} kl: {}".format(loss.mean(), kl.mean()) ) 
            # Logging
            # writer.add_scalars('training_loss',{
            #         'loss':loss,
            #         'kl':kl.mean(),

            #     }, global_step)
            
        # Save a checkpoint (model + optimizer state) at the end of each epoch.
        t.save({'model': model.state_dict(),
                'optimizer': optim.state_dict()},
               output_path)
        print("epoch %d, mean loss %f" % (epoch, loss_accum / len(train_dataset)))
        
        
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training Knob-Plan Encoder & Dual-Task Predictor')
    parser.add_argument('--sample_file', type=str, required=True,
                        help='specifies the file that contains sampled data')
    parser.add_argument('--model_output', type=str, required=True,
                        help='specifies the path where the trained model is output')
    parser.add_argument('--epoch', type=int, required=True,
                        help='specifies the number of training epochs')
    parser.add_argument('--lr', type=float, required=True,
                        help='specifies the initial learning rate for training')

    args = parser.parse_args()
    
    main(f"data/{args.sample_file}", args.model_output, args.epoch, args.lr)
    # verify("checkpoint/stats.pth.tar")