import torch as t
import argparse
import json
import os
from scipy.stats import norm
from connector import *
from network import LatentModel
from pair_encoder import TPair
DEVICE = 'cuda' if t.cuda.is_available() else "cpu"
DB = os.getenv('DB')
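# Constrained expected-improvement acquisition: standard EI over the best
# observed value, scaled by the feasibility probability `pr`. Defined here but
# not called in this script.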
def eic(mu, stdvar, best, pr):
z = (best - mu) / stdvar
return pr * ((best - mu) * norm.cdf(z) + stdvar * norm.pdf(z))
def initialize_model(model_path):
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    state_dict = t.load(model_path, map_location=DEVICE)
    model = LatentModel(32).to(DEVICE)
    model.load_state_dict(state_dict['model'])
    model.train(False)  # inference only
    return model
def initialize_data(workload_path):
    with open(workload_path) as f:
        queries = f.read().splitlines()
    plans = [get_query_plan(q, DB)['plan'] for q in queries]
    return queries, plans
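# The sample file is tab-separated: column 1 holds the plan JSON, column 2 the
# normalized knob vector, and column 3 an execution-stats JSON with 'elapsed'
# and 'fail' keys; column 0 is not read here.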
def initialize_observed_data(path):
    pairs, elapsed = [], []
    with open(path) as f:
        for line in f:
            items = line.split('\t')
            stats = json.loads(items[3])
            if len(stats) > 0:
                pairs.append(TPair(json.loads(items[1])['plan'], json.loads(items[2])))
                elapsed.append([stats['elapsed'], stats['fail']])
    return pairs, elapsed
def get_configuration(path):
with open(path) as f:
conf = json.load(f)
return conf
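# The knob file is assumed (from how denormalize_knobs reads it) to map knob
# names to descriptors; knob names and bounds below are illustrative only, and
# any 'type' other than 'int/continuous' falls through to the value-list branch:
#   {"work_mem": {"type": "int/continuous", "lower": 64, "upper": 4194304},
#    "wal_level": {"type": "enum", "value": ["minimal", "replica", "logical"]}}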
def denormalize_knobs(conf: dict, values: list) -> dict:
    res = {}
    for i, key in enumerate(conf.keys()):
        detail = conf[key]
        para = values[i]  # normalized knob value in [0, 1)
        if detail['type'] == 'int/continuous':
            res[key] = round(detail['lower'] + (detail['upper'] - detail['lower']) * para)
        else:
            # scale the normalized value across the whole value list so every
            # index is reachable
            idx = round(para * (len(detail['value']) - 1))
            res[key] = detail['value'][idx]
    return res
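# e.g. denormalize_knobs(conf, t.rand(len(conf)).tolist()) maps one random
# unit-interval sample per knob to a concrete setting, which is how candidate
# configurations are generated in main_np below.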
def main_np(sqls, sample_file, model_file, knob_file, iters, output_file):
    queries, plans = initialize_data(sqls)
    model = initialize_model(model_file)  # already set to eval mode inside
    obs_pairs, obs_elapseds = initialize_observed_data(sample_file)
    samples = 2**10  # number of random candidate configurations per query
    conf = get_configuration(knob_file)
    # 100. is both the initial worst case and the penalty assigned to failed runs
    best_elapsed = [100. for _ in range(len(queries))]
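    # Per iteration: draw random knob vectors for each query, score them with
    # the latent model conditioned on all observations so far, execute the
    # best-predicted candidate, and fold the measured result back into the
    # observation set.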
    for it in range(iters):
        for i in range(len(queries)):
            # recommend a configuration: draw `samples` random knob vectors
            # and keep the one with the lowest predicted latency
            candidate_pairs = [TPair(plans[i], t.rand(len(conf)).tolist()) for _ in range(samples)]
            preds_l, preds_e, _ = model(obs_pairs, t.Tensor(obs_elapseds), candidate_pairs)
            latency = preds_l.mean
            fail = preds_e.mean
            # mask out candidates the model predicts will fail
            cond = (fail > 0.5).all(1)
            latency[cond] = 1000.
            idx = latency.argmin()
            next_conf = denormalize_knobs(conf, candidate_pairs[idx]._knobs)
            # evaluate the recommended configuration against the live database
            stats = exec_query(queries[i], DB, next_conf)
            if len(stats) == 0:
                # execution failed outright; record penalty values so the
                # observation set stays consistent
                print('bad params')
                stats['elapsed'] = 100.
                stats['fail'] = 1
            print("query %d elapsed %f fail %d predict mean %f predict fail %f" % (i, stats['elapsed'], stats['fail'], latency[idx], fail[idx]))
            obs_pairs.append(candidate_pairs[idx])
            obs_elapseds.append([stats['elapsed'], stats['fail']])
            if stats['fail'] == 0:
                best_elapsed[i] = min(stats['elapsed'], best_elapsed[i])
        # checkpoint the running total after every iteration
        with open(output_file, "w") as fw:
            fw.write(str(sum(best_elapsed)))
print("best scores %f" % sum(best_elapsed))
sorted_elapsed = copy.deepcopy(best_elapsed)
sorted_elapsed.sort()
print(sorted_elapsed)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Configuration Tuning')
parser.add_argument('--knob_file', type=str, required=True,
help='specifies the path of knob file')
    parser.add_argument('--sample_file', type=str, required=True,
                        help='specifies the path of the file that contains sampled data')
parser.add_argument('--db', type=str, required=True,
help='specifies the database')
parser.add_argument('--sqls', type=str, required=True,
help='specifies the path of SQL workloads')
    parser.add_argument('--model', type=str, required=True,
                        help='specifies the path of the model file, corresponding to --model_output in the training phase')
parser.add_argument('--max_iteration', type=int, required=True,
help='specifies the maximum number of iterations for tuning')
parser.add_argument('--result_file', type=str, required=True,
help='specifies the file to save the tuning result')
args = parser.parse_args()
DB = args.db
main_np(args.sqls, args.sample_file, args.model, args.knob_file, args.max_iteration, args.result_file)
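# Example invocation (script and file names are illustrative):
#   python tune_np.py --knob_file knobs.json --sample_file samples.tsv \
#       --db tpch --sqls workload.sql --model model.pt \
#       --max_iteration 30 --result_file result.txt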