from tqdm import tqdm
from network import LatentModel
from torch.utils.data import Dataset, DataLoader
import torch as t
from pair_encoder import TPair
import argparse
import os
import json
import random
import time
DEVICE = 'cuda' if t.cuda.is_available() else 'cpu'
DB = os.getenv('DB')
# Dataset of pre-sampled (context, target) batches of knob-plan pairs and their runtimes.
class CustomDataset(Dataset):
    def __init__(self, pairs_ctx, elapsed_ctx, pairs_tar, elapsed_tar):
        self.pairs_ctx = pairs_ctx
        self.elapsed_ctx = elapsed_ctx
        self.pairs_tar = pairs_tar
        self.elapsed_tar = elapsed_tar

    def __getitem__(self, index):
        return self.pairs_ctx[index], self.elapsed_ctx[index], \
               self.pairs_tar[index], self.elapsed_tar[index]

    def __len__(self):
        return len(self.pairs_ctx)
def load_dataset(path, data_size=300):
    """Reads the sample file and builds `data_size` random (context, target) batches."""
    pairs, elapsed = [], []
    with open(path) as f:
        line = f.readline()
        while line:
            # Each line is tab-separated: the 2nd field holds the plan JSON,
            # the 3rd the knob configuration, and the 4th the execution stats.
            items = line.split('\t')
            stats = json.loads(items[3])
            if len(stats) > 0:
                pairs.append(TPair(json.loads(items[1])['plan'], json.loads(items[2])))
                elapsed.append([stats['elapsed'], stats['fail']])
            line = f.readline()
    item_len = len(pairs)
    ctx_num = item_len // 2
    tar_num = item_len // 4
    print("read dataset, size %d" % item_len)
    pairs_ctx, elapseds_ctx, pairs_tar, elapseds_tar = [], [], [], []
    for _ in range(data_size):
        # Sample context and target sets with replacement from the full pool.
        pair_ctx, elapsed_ctx, pair_tar, elapsed_tar = [], [], [], []
        for _ in range(ctx_num):
            idx = random.choice(range(item_len))
            pair_ctx.append(pairs[idx])
            elapsed_ctx.append(elapsed[idx])
        for _ in range(tar_num):
            idx = random.choice(range(item_len))
            pair_tar.append(pairs[idx])
            elapsed_tar.append(elapsed[idx])
        pairs_ctx.append(pair_ctx)
        elapseds_ctx.append(t.Tensor(elapsed_ctx))
        pairs_tar.append(pair_tar)
        elapseds_tar.append(t.Tensor(elapsed_tar))
    return CustomDataset(pairs_ctx, elapseds_ctx, pairs_tar, elapseds_tar)
def adjust_learning_rate(optimizer, step_num, warmup_step=4000):
    # Transformer-style schedule: the rate grows linearly for `warmup_step` steps,
    # then decays proportionally to step_num ** -0.5.
    lr = 0.001 * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
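# Note: the two arguments of min() coincide at step_num == warmup_step, where
# lr = 0.001 * warmup_step**0.5 * warmup_step**-0.5 = 0.001; that is the peak
# learning rate before the inverse-square-root decay takes over.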
def verify(dict_name):
    # Load one sampled batch and run a single forward pass with the checkpointed model.
    dataset = load_dataset(f"data/{DB}_samples", 1)
    dloader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0], shuffle=True)
    for data in dloader:
        x1, y1, x2, y2 = data
        state_dict = t.load(dict_name)
        model = LatentModel(32).to(DEVICE)
        model.load_state_dict(state_dict['model'])
        model.eval()
        start = time.time()
        res, status, loss = model(x1, y1, x2, None)
        print("cost %f" % (time.time() - start))
        latency = y2[:, 0].unsqueeze(-1)
        print(y2, t.cat([res.stddev, res.mean], dim=1))
        mae = t.abs(res.mean - latency).mean()
        print("mae %f" % mae)
        break
def main(file_name, output_path, epochs, lr):
    model = LatentModel(32).to(DEVICE)
    model.train()
    optim = t.optim.Adam(model.parameters(), lr=lr)
    global_step = 0
    train_dataset = load_dataset(file_name)
    for epoch in range(epochs):
        dloader = DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x[0], shuffle=True)
        pbar = tqdm(dloader)
        loss_accum = 0.
        for data in pbar:
            global_step += 1
            adjust_learning_rate(optim, global_step)
            x1, y1, x2, y2 = data
            # pass through the latent model
            y_pred_l, y_pred_e, loss = model(x1, y1, x2, y2)
            loss_accum += loss.item()
            # Training step
            optim.zero_grad()
            loss.backward()
            optim.step()
            # print("loss : {} kl: {}".format(loss.mean(), kl.mean()))
            # Logging
            # writer.add_scalars('training_loss', {
            #     'loss': loss,
            #     'kl': kl.mean(),
            # }, global_step)
        # save a checkpoint (model + optimizer state) at the end of each epoch
        t.save({'model': model.state_dict(),
                'optimizer': optim.state_dict()},
               output_path)
        print("loss %f" % (loss_accum / len(train_dataset)))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training Knob-Plan Encoder & Dual-Task Predictor')
    parser.add_argument('--sample_file', type=str, required=True,
                        help='specifies the file (under data/) that contains sampled data')
    parser.add_argument('--model_output', type=str, required=True,
                        help='specifies the path where the trained model is output')
    parser.add_argument('--epoch', type=int, required=True,
                        help='specifies the number of training epochs')
    parser.add_argument('--lr', type=float, required=True,
                        help='specifies the learning rate of training')
    args = parser.parse_args()
    main(f"data/{args.sample_file}", args.model_output, args.epoch, args.lr)
    # verify("checkpoint/stats.pth.tar")
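    # Example invocation (script name, sample file, and hyperparameter values are
    # placeholders; the sample file is resolved relative to data/):
    #   python train.py --sample_file stats_samples --model_output checkpoint/stats.pth.tar --epoch 50 --lr 0.001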