import json
import numpy as np
import torch
import torch.optim
import tcnn
import utils
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
# Whether PyTorch reports an available CUDA device. Not referenced by the
# code visible in this file.
CUDA = torch.cuda.is_available()
# Argument handed to tcnn.TreeNet when TCNNRegression.fit builds the network.
# Presumably the number of input feature channels per tree node -- TODO
# confirm against tcnn.
CHANNELS = 9
# Number of full passes over the training data made by TCNNRegression.fit.
EPOCHS = 100
def same_seed(seed):
    """Seed NumPy and PyTorch RNGs and put cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to ``numpy.random``, to PyTorch's RNG and,
            when CUDA is available, to every CUDA device.
    """
    # Seed the generators first ...
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    # ... then set the cuDNN flags: deterministic on, benchmark off.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
def collate(x):
    """Split a batch of ``(tree, target)`` pairs into trees and a target tensor.

    Args:
        x: Iterable of 2-item ``(tree, target)`` pairs making up one batch.

    Returns:
        A tuple ``(trees, targets)``: ``trees`` is a list of the first items in
        batch order, and ``targets`` is ``torch.tensor`` built from the list
        of second items.
    """
    batch = list(x)
    tree_batch = [tree for tree, _ in batch]
    target_batch = [target for _, target in batch]
    return tree_batch, torch.tensor(target_batch)
class TCNNRegression:
    """Regressor that trains a ``tcnn.TreeNet`` on ``(tree, target)`` pairs.

    Attributes:
        trained: ``True`` once :meth:`fit` has completed.
        losses: Mean per-batch MSE of each epoch from the latest :meth:`fit`.
    """

    def __init__(self):
        self.trained = False
        self.losses = []
        self._net = None

    def fit(self, X, Y):
        """Train a fresh network on the given samples.

        Args:
            X: Iterable of tree inputs, in whatever form ``tcnn.TreeNet``
                accepts as batch elements.
            Y: One numeric target per element of ``X`` (list, tuple or array).

        Raises:
            ValueError: If no samples are given or ``X`` and ``Y`` disagree in
                length.
        """
        X = list(X)
        # Always normalise targets to a float32 column. The original only did
        # this for ``list`` inputs, so an ndarray/tuple of targets reached
        # MSELoss as a 1-D (possibly float64) tensor and was silently
        # broadcast against the prediction (assumed to be (batch, 1), as the
        # reshape implies -- confirm against tcnn.TreeNet).
        Y = np.asarray(Y).reshape(-1, 1).astype(np.float32)

        if not X:
            raise ValueError("fit() requires at least one training sample")
        if len(X) != len(Y):
            # zip() would otherwise drop the surplus samples without notice.
            raise ValueError(
                "X and Y must have the same length, got {} and {}".format(
                    len(X), len(Y)))

        pairs = list(zip(X, Y))
        dataset = DataLoader(pairs, batch_size=32,
                             shuffle=True, collate_fn=collate)

        self._net = tcnn.TreeNet(CHANNELS)
        self._net.train()

        optimizer = torch.optim.Adam(self._net.parameters())
        loss_fn = torch.nn.MSELoss()

        losses = []
        for epoch in range(1, 1 + EPOCHS):
            accum = 0.
            for x, y in dataset:
                y_pred = self._net(x)
                loss = loss_fn(y_pred, y)
                accum += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # len(dataset) is the number of batches, so this is the mean
            # batch loss for the epoch.
            accum /= len(dataset)
            losses.append(accum)
            print("epoch {} loss {}".format(epoch, accum))

        # Keep the loss history instead of discarding it.
        self.losses = losses
        self.trained = True

    def predict(self, X):
        """Run the trained network on ``X`` and return its output as NumPy.

        Args:
            X: Batch of inputs passed directly to the network.

        Returns:
            ``numpy.ndarray`` holding the network output.

        Raises:
            AssertionError: If called before :meth:`fit` has completed.
        """
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        if not self.trained:
            raise AssertionError("predict() called before fit()")

        self._net.eval()
        # Inference needs no autograd graph.
        with torch.no_grad():
            pred = self._net(X).cpu().detach().numpy()
        return pred
if __name__ == "__main__":
    # Script entry point: load (plan, latency) samples, train on an 80/20
    # split, and report test-set MSE and MAE.
    same_seed(1145141101)

    X, Y = [], []
    # Each record is one tab-separated line: field 1 is a JSON document whose
    # 'plan' entry is converted to a plan tree, field 2 is the latency target.
    with open("../data/plan_test/imdb_data") as f:
        # Stream the file rather than materialising it with readlines().
        for line in f:
            if not line.strip():
                # A blank (e.g. trailing) line would otherwise raise
                # IndexError on items[1].
                continue
            items = line.split('\t')
            plan = utils.generate_plan_tree(json.loads(items[1])['plan'])
            latency = float(items[2])
            X.append(plan)
            Y.append(latency)

    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=0.2, random_state=42)

    model = TCNNRegression()
    model.fit(X_train, y_train)

    # The flattened network output is compared against the list of latencies.
    preds = model.predict(X_test).flatten()
    mse = mean_squared_error(y_test, preds)
    mae = mean_absolute_error(y_test, preds)
    print("mean squared error: {}".format(mse))
    print("mean absolute error: {}".format(mae))