import argparse
import json
import math
import queue
import threading
import time

import numpy as np
from scipy.stats import qmc

from connector import exec_query, get_query_plan


class SinglePSOSampler(threading.Thread):
    """Worker thread for particle-swarm (PSO) sampling of knob settings for one query.

    Threads come in two roles that share the same pair of queues:

    * the *producer* (``is_consumer=False``) seeds ``NUM_PARTICLES`` particles
      with Sobol points in the unit hypercube, reads measured results from
      ``res_q``, records them, updates per-particle / global bests and pushes
      each particle's next position onto ``arg_q``;
    * *consumers* (``is_consumer=True``) pop ``(particle_id, values)`` from
      ``arg_q``, map the normalised values onto concrete knob settings, run the
      query via ``exec_query`` and push ``(particle_id, values, stats)`` onto
      ``res_q``.

    Args:
        sql: Query text passed to ``exec_query``.
        db: Database identifier passed to ``exec_query``.
        knobs: Knob name -> spec dict. Specs whose ``type`` is
            ``'int/continuous'`` are read for ``lower`` / ``upper``; any other
            type is treated as categorical and read for its ``value`` list.
        p_num: Stored but not read anywhere in this class.
        res_q: Queue of measured results (consumer -> producer).
        arg_q: Queue of positions to evaluate (producer -> consumer).
        is_consumer: Selects the role executed by :meth:`run`.
        records: Shared list the producer appends ``(values, stats)`` to.
        max_tries: Number of results the producer accepts before stopping.
    """

    # Number of particles in the swarm.
    NUM_PARTICLES = 3
    # Seconds a consumer waits for new work before assuming the producer is done.
    ARG_TIMEOUT = 3
    # Initial "best" metric value; any measured value below it replaces it.
    _INIT_BEST = 10000.

    def __init__(self, sql, db, knobs: dict, p_num, res_q: queue.Queue,
                 arg_q: queue.Queue, is_consumer, records: list, max_tries=9,
                 *args, **kwargs):
        self._sql = sql
        self._db = db
        self._knobs = knobs
        self._res_q = res_q
        self._arg_q = arg_q
        self._p_num = p_num
        self._is_consumer = is_consumer
        self._records = records
        self._max_tries = max_tries
        super().__init__(*args, **kwargs)

    @staticmethod
    def _random_speed(dim):
        """Return a fresh random velocity vector with components in [-0.1, 0.1)."""
        return np.random.uniform(-0.1, 0.1, (1, dim))[0]

    def _values_to_settings(self, values):
        """Map normalised positions in [0, 1] onto concrete settings for each knob."""
        settings = {}
        for i, (name, spec) in enumerate(self._knobs.items()):
            if spec['type'] == 'int/continuous':
                settings[name] = round(
                    spec['lower'] + values[i] * (spec['upper'] - spec['lower']))
            else:
                choices = spec['value']
                # values[i] == 1.0 would index one past the end; clamp to the last choice.
                idx = min(math.floor(values[i] * len(choices)), len(choices) - 1)
                settings[name] = choices[idx]
        return settings

    def produce(self):
        """Run the PSO loop: seed particles, accept results, emit next positions.

        Stops after ``max_tries`` results have been accepted. ``stats`` is
        presumably a dict returned by ``connector.exec_query``; this method
        reads its ``'elapsed'`` and ``'fail'`` keys (confirm against the
        connector). ``None`` stats mark a consumer-side execution error and
        are not recorded.
        """
        metrics = ['elapsed']
        global_best = [self._INIT_BEST] * len(metrics)
        global_best_loc = [None] * len(metrics)
        local_best = [self._INIT_BEST] * self.NUM_PARTICLES
        local_best_loc = [None] * self.NUM_PARTICLES

        knobs_num = len(self._knobs)
        sampler = qmc.Sobol(d=knobs_num)
        speed = []
        for pid in range(self.NUM_PARTICLES):
            self._arg_q.put_nowait((pid, sampler.random(1)[0]))
            speed.append(self._random_speed(knobs_num))

        remaining = self._max_tries
        while True:
            remaining -= 1
            # Blocks until some consumer finishes evaluating a position.
            pid, values, stats = self._res_q.get()
            print("accept id {} stats {}".format(pid, stats))
            # Private copy: later position updates must never alias the arrays
            # stored as local/global best locations.
            values = np.array(values, dtype=float)
            if stats is not None:
                self._records.append((list(values), stats))
            if remaining <= 0:
                break

            if stats:
                # Update global best per metric.
                for i, metric in enumerate(metrics):
                    if stats[metric] < global_best[i]:
                        global_best[i] = stats[metric]
                        global_best_loc[i] = values.copy()
                # Update this particle's own best.
                if stats['elapsed'] < local_best[pid]:
                    local_best[pid] = stats['elapsed']
                    local_best_loc[pid] = values.copy()

            have_bests = (local_best_loc[pid] is not None
                          and global_best_loc[0] is not None)
            if not stats or stats['fail'] <= 1 or not have_bests:
                # Restart the particle at a fresh Sobol point with a new velocity
                # (also used when no best positions exist yet to steer towards).
                speed[pid] = self._random_speed(knobs_num)
                values = sampler.random(1)[0]
            else:
                # Standard PSO velocity update towards the local and global bests.
                speed[pid] = (speed[pid]
                              + 0.5 * np.random.rand() * (local_best_loc[pid] - values)
                              + 0.5 * np.random.rand() * (global_best_loc[0] - values))
                values = np.clip(values + speed[pid], 0, 1)
            self._arg_q.put_nowait((pid, values))

    def consume(self):
        """Evaluate positions from ``arg_q`` until none arrive for ``ARG_TIMEOUT`` seconds."""
        while True:
            time.sleep(0.01)
            try:
                pid, values = self._arg_q.get(timeout=self.ARG_TIMEOUT)
            except queue.Empty:
                # No more work: the producer has stopped emitting positions.
                break
            try:
                stats = exec_query(self._sql, self._db, self._values_to_settings(values))
            except Exception as e:
                # Report the failure instead of dying silently: a dead consumer
                # would leave the producer blocked forever on ``res_q.get()``.
                print("query failed for id {}: {}".format(pid, e))
                stats = None
            self._res_q.put((pid, values, stats))

    def run(self):
        """Thread entry point: dispatch to the consumer or the producer loop."""
        if self._is_consumer:
            self.consume()
        else:
            self.produce()


class PSOSampler(object):
    """Collect PSO-guided knob/metric samples for SQL queries.

    For each query, every recorded sample is appended to ``fname`` as one
    tab-separated line of JSON fields:
    ``analyze_plan, general_plan, knob_values, stats``.
    """

    def __init__(self, sqls, db, knobs, fname='data/sample'):
        self.sqls = sqls
        self.db = db
        self.knobs = knobs
        self.knobs_num = len(knobs)
        self.fname = fname

    def sample_on_sql(self, sql: str, max_threads=3):
        """Sample one query with ``max_threads`` consumers plus one producer thread.

        Args:
            sql: Query text to sample.
            max_threads: Number of consumer threads running ``exec_query``.
        """
        res_q = queue.Queue()
        arg_q = queue.Queue()
        records = []

        # Executed once without explicit settings before sampling; result unused.
        exec_query(sql, self.db)
        try:
            analyze_plan = json.dumps(get_query_plan(sql, self.db, analyze=True))
        except Exception:
            # Must stay a JSON *string*: it is concatenated into each output line.
            analyze_plan = json.dumps({})
        general_plan = json.dumps(get_query_plan(sql, self.db, analyze=False))

        workers = [
            SinglePSOSampler(sql, self.db, self.knobs, 100, res_q, arg_q, True, records)
            for _ in range(max_threads)
        ]
        workers.append(
            SinglePSOSampler(sql, self.db, self.knobs, 100, res_q, arg_q, False, records))
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        with open(self.fname, 'a') as f:
            for values, stats in records:
                f.write("\t".join([analyze_plan, general_plan,
                                   json.dumps(values), json.dumps(stats)]) + "\n")

    def run_samples(self, sqls: list, max_threads=3):
        """Run :meth:`sample_on_sql` for every query in ``sqls``, in order."""
        for i, sql in enumerate(sqls):
            print("sample on index %d sql %s" % (i, sql))
            self.sample_on_sql(sql, max_threads)


def main():
    """Command-line entry point: sample every query listed in ``--sqls``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--knob_file', '-k', required=True)
    parser.add_argument('--db', '-db')
    parser.add_argument('--sqls', '-s', required=True)
    # Same default path as PSOSampler; without it open(None) fails after sampling.
    parser.add_argument('--output', '-o', default='data/sample')
    parser.add_argument('--threads', '-t', type=int, default=1)
    args = parser.parse_args()

    print("start collecting warm-start samples.")
    with open(args.knob_file) as f:
        knobs = json.load(f)
    sampler = PSOSampler([], args.db, knobs, args.output)
    with open(args.sqls) as f:
        lines = f.read().splitlines()
    sampler.run_samples(lines, args.threads)


if __name__ == "__main__":
    main()