import torch as t
from torch import nn
from scipy.sparse import csgraph
from scipy.sparse.linalg import eigsh
from tree_utils import generate_plan_tree, TreeNet
import os
HIDDEN_DIM = 32
MAX_NODES = 200  # upper bound on plan-tree nodes per query
DB = os.getenv('DB')  # name of the statistics file under statistics/
RowsNorm = 1e7  # normalization constant for row-count estimates
CKNodeTypes = [
"Projection",
"MergingAggregated",
"Exchange",
"Aggregating",
"Join",
"Filter",
"TableScan",
"Limit",
"Sorting",
"CTERef",
"Buffer",
"Union",
"EnforceSingleRow",
"Window",
"Values",
"PartitionTopN",
""
]
COMPARATORS = [
">",
"<",
"=",
">=",
"<=",
"!=",
"LIKE",
"IN",
"NOT LIKE"
]
NODE_DIM = 256  # width of each node's feature vector
KNOB_DIM = 16  # number of knob slots encoded per query
COL_NAME, MIN_MAX = [], []  # column statistics, lazily loaded by load_column_data
class TPair(object):
def __init__(self, json_plan, knobs):
        self.tree_plan = generate_plan_tree(json_plan)  # required by TreeEncoder
self.NodeTypes = CKNodeTypes
self._knobs = knobs
self._parse_plan(json_plan)
        # Encode knobs as scaled one-hot rows: an identity matrix whose
        # diagonal is overwritten with the knob values.
        self.Knobs = t.eye(KNOB_DIM)
        for i, v in enumerate(knobs):
            if i < KNOB_DIM:
                self.Knobs[i][i] = v
        # Pad a copy to KNOB_DIM so the caller's list is not mutated.
        padded = list(knobs[:KNOB_DIM])
        padded += [0.] * (KNOB_DIM - len(padded))
        self.plat_knobs = t.tensor(padded)
    def _node_to_vec(self, node):
        if 'NodeType' not in node:
            node['NodeType'] = ''
        vec_len = len(CKNodeTypes) + 2
        arr = [0. for _ in range(NODE_DIM)]
        stats = node.get('Statistic', {})
        arr[vec_len - 1] = stats.get('RowCount', 0.) / RowsNorm
        arr[vec_len - 2] = node['depth']
        # One-hot encoding of the operator type.
        arr[self.NodeTypes.index(node['NodeType'])] = 1.
        # Concatenate the operator-specific embedding, if any.
        if node['NodeType'] == 'TableScan' and 'Where' in node:
            emb = encode_predicate(node['Where'])
        elif node['NodeType'] == 'Join' and 'Condition' in node:
            emb = encode_join(node['Condition'])
        elif node['NodeType'] == 'Aggregating' and 'GroupByKeys' in node:
            emb = encode_aggregate(node['GroupByKeys'])
        else:
            return arr
        # Truncate or zero-pad so the vector stays exactly NODE_DIM wide.
        arr[vec_len:] = emb[:NODE_DIM - vec_len]
        while len(arr) < NODE_DIM:
            arr.append(0.)
        return arr
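    # Resulting NODE_DIM-wide layout (vec_len = len(CKNodeTypes) + 2):
    #   [0, len(CKNodeTypes))    one-hot operator type
    #   [len(CKNodeTypes)]       tree depth
    #   [len(CKNodeTypes) + 1]   row-count estimate / RowsNorm
    #   [vec_len, NODE_DIM)      operator-specific embedding, zero-padded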
    def _parse_plan(self, root):
        vec_len = NODE_DIM
        nodes = [root]
        res = []
        # Boolean attention mask (nn.MultiheadAttention convention: True blocks
        # attention); parent-child edges and self-loops are opened below.
        mask = t.ones(MAX_NODES, MAX_NODES, dtype=t.bool)
        root['depth'] = 1.
        vis = []
        # Iterative DFS over the plan tree, assigning each node a row id.
        while len(nodes) > 0:
            node = nodes.pop()
            arr = self._node_to_vec(node)
            node_id = len(res)
            res.append(arr)
            if 'parent_id' in node:
                mask[node['parent_id']][node_id] = False
                mask[node_id][node['parent_id']] = False
            mask[node_id][node_id] = False
            if 'Children' in node:
                for p in node['Children']:
                    if p['NodeId'] not in vis:  # skip shared subtrees already queued
                        vis.append(p['NodeId'])
                        p['parent_id'] = node_id
                        p['depth'] = node['depth'] + 1
                        nodes.append(p)
        # Zero-pad the node matrix to MAX_NODES rows.
        for i in range(MAX_NODES - len(res)):
            res.append([0. for _ in range(vec_len)])
        self.Vecs = t.Tensor(res)
        self.Mask1 = mask
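    # For a hypothetical 3-node plan (root 0 with two children), the upper-left
    # corner of Mask1 ends up as (False = attention allowed):
    #   [[F, F, F],
    #    [F, F, T],
    #    [F, T, F]]
    # i.e. each node attends to itself, its parent, and its children.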
    def _spectral_encoding(self, k):
        # Mask1 is False exactly where an edge (or self-loop) exists, so invert
        # it to recover the adjacency matrix; csgraph.laplacian ignores the
        # diagonal, so the self-loops are harmless.
        adjacency = (~self.Mask1).float().numpy()
        # Graph Laplacian L = D - A.
        laplacian_matrix = csgraph.laplacian(adjacency, normed=False)
        # 'SM' selects the k smallest eigenvalues.
        eigenvalues, eigenvectors = eigsh(laplacian_matrix, k=k, which='SM')
        # eigenvectors.shape: MAX_NODES x k
        return t.from_numpy(eigenvectors).float()
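    # Hypothetical usage (not wired into the pipeline yet): append the k
    # eigenvector columns to the node features as Laplacian positional encodings.
    #   pe = self._spectral_encoding(k=8)             # MAX_NODES x 8
    #   vecs_with_pe = t.cat([self.Vecs, pe], dim=1)  # MAX_NODES x (NODE_DIM + 8)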
class FFN(nn.Module):
    # Two-layer feed-forward block: Linear -> ReLU -> Linear.
    def __init__(self, in_dim=16, out_dim=16):
        super(FFN, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim)
        )
def forward(self, vecs):
return self.layers(vecs)
class AttnEncoder(nn.Module):
def __init__(self, node_dim=NODE_DIM, knob_dim=KNOB_DIM, heads=4):
super(AttnEncoder, self).__init__()
self._node_dim = node_dim
self._knob_dim = knob_dim
self._heads = heads
self.nodes_attn = nn.MultiheadAttention(node_dim, heads, batch_first=True)
self.ffn1 = FFN(node_dim, HIDDEN_DIM)
self.norm1 = nn.LayerNorm(HIDDEN_DIM)
self.knobs_attn = nn.MultiheadAttention(knob_dim, heads, batch_first=True)
self.ffn2 = FFN(knob_dim, HIDDEN_DIM)
self.norm2 = nn.LayerNorm(HIDDEN_DIM)
self.cross_attn = nn.MultiheadAttention(HIDDEN_DIM, heads, batch_first=True)
self.norm3 = nn.LayerNorm(HIDDEN_DIM)
    def forward(self, pairs: list):
        # nodes: batch x MAX_NODES x node_dim
        # knobs: batch x KNOB_DIM x knob_dim
        # mask1: (batch * heads) x MAX_NODES x MAX_NODES
        nodes = t.stack([p.Vecs for p in pairs])
        knobs = t.stack([p.Knobs for p in pairs])
        # nn.MultiheadAttention expects each sample's mask repeated num_heads
        # times consecutively along dim 0.
        mask1 = t.stack([p.Mask1 for p in pairs for _ in range(self._heads)])
        x, weights = self.nodes_attn(nodes, nodes, nodes, attn_mask=mask1)
        # Fully masked padding rows produce NaNs; zero them out.
        x = t.nan_to_num(x)
        node_vecs = self.norm1(self.ffn1(x))
        x, weights = self.knobs_attn(knobs, knobs, knobs)
        x = t.nan_to_num(x)
        knob_vecs = self.norm2(self.ffn2(x))
        # Plan nodes attend over knob representations, then max-pool over nodes.
        x, _ = self.cross_attn(node_vecs, knob_vecs, knob_vecs)
        x = t.max(x, dim=1).values
        return self.norm3(x)
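    # Minimal usage sketch (hypothetical `plan_json`; any plan JSON with the
    # fields read by TPair would do):
    #   pairs = [TPair(plan_json, [0.1, 0.3, 0.4, 0.34, 0.5])] * 4
    #   out = AttnEncoder()(pairs)  # shape: (4, HIDDEN_DIM)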
class TreeEncoder(nn.Module):
def __init__(self, node_dim=NODE_DIM, knob_dim=KNOB_DIM):
super(TreeEncoder, self).__init__()
self.treenet = TreeNet(node_dim, HIDDEN_DIM - knob_dim)
    def forward(self, pairs: list):
        # trees: one plan tree per pair (built by generate_plan_tree)
        # knobs: batch x KNOB_DIM
        trees = [p.tree_plan for p in pairs]
        knobs = t.stack([p.plat_knobs for p in pairs])
        tree_vecs = self.treenet(trees)
        # Concatenate tree embedding and raw knobs: batch x HIDDEN_DIM
        x = t.cat([tree_vecs, knobs], dim=1)
        return x
def load_column_data(fname):
    # Parse the statistics file: each whitespace-separated line holds a column
    # name in field 1 and, optionally, integer min/max values in fields 2-3.
    column_name, min_max_vals = [], []
    with open(f"statistics/{fname}") as f:
        for line in f:
            items = line.split()
            column_name.append(items[1])
            if len(items) == 4:
                min_max_vals.append((int(items[2]), int(items[3])))
            else:
                min_max_vals.append(())  # no min/max statistics for this column
    return column_name, min_max_vals
def encode_predicate(s: str):
    global COL_NAME, MIN_MAX
    if len(COL_NAME) == 0:
        COL_NAME, MIN_MAX = load_column_data(DB)
    res = []
    predicates = []
    depth = 0  # parenthesis nesting depth
    p = ""
    # Split the WHERE clause into its top-level parenthesized predicates.
    for c in s:
        if c == '(':
            p = "" if depth == 0 else p
            depth += 1
        elif c == ')':
            if depth == 1:
                predicates.append(p)
            depth -= 1
        elif c == ',':
            pass
        elif depth > 0:
            p += c
    # A bare predicate without parentheses is taken as-is.
    if len(predicates) == 0 and 'Filters' not in s:
        predicates.append(s)
    for predicate in predicates:
        items = predicate.split()
        try:
            # Fold "NOT LIKE" into a single comparator token.
            if items[1] == 'NOT':
                items[1] = 'NOT LIKE'
                items.remove('LIKE')
            emb = [0. for _ in range(len(COL_NAME) + len(COMPARATORS) + 6)]
            emb[COL_NAME.index(items[0])] = 1.
            emb[len(COL_NAME) + COMPARATORS.index(items[1])] = 1.
            vals = MIN_MAX[COL_NAME.index(items[0])]
            if len(vals) == 2:
                # Min-max normalize the literal operand.
                emb[len(COL_NAME) + len(COMPARATORS)] = \
                    (int(items[2]) - vals[0]) / (vals[1] - vals[0])
            res += emb
        except Exception:
            # Skip predicates with unknown columns/comparators or
            # non-integer operands.
            pass
    return res
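# Example (hypothetical; the column names must exist in the statistics file):
# "(s_acctbal > 100) AND (s_comment LIKE '%x%')" is split into the predicates
# "s_acctbal > 100" and "s_comment LIKE '%x%'", each encoded as a fixed-width
# block and concatenated into `res`.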
def encode_join(conds: list):
    global COL_NAME, MIN_MAX
    if len(COL_NAME) == 0:
        COL_NAME, MIN_MAX = load_column_data(DB)
    # Bag-of-columns: mark every known column appearing in a join condition.
    emb = [0. for _ in range(len(COL_NAME) + len(COMPARATORS) + 6)]
    for s in conds:
        for token in s.split():
            if token in COL_NAME:
                emb[COL_NAME.index(token)] = 1.
    return emb


def encode_aggregate(keys: list):
    global COL_NAME, MIN_MAX
    if len(COL_NAME) == 0:
        COL_NAME, MIN_MAX = load_column_data(DB)
    # Same bag-of-columns layout as encode_join, over the GROUP BY keys.
    emb = [0. for _ in range(len(COL_NAME) + len(COMPARATORS) + 6)]
    for key in keys:
        if key in COL_NAME:
            emb[COL_NAME.index(key)] = 1.
    return emb
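# For example (assuming o_custkey and c_custkey appear in the statistics
# file), encode_join(["o_custkey == c_custkey"]) sets exactly those two
# column bits and leaves every comparator slot at zero.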
if __name__ == "__main__":
# with open("encoder/test_plan") as f:
# c = json.loads(f.read())
# p = TPair(c, [0.1, 0.3, 0.4, 0.34, 0.5])
# encoder = TreeEncoder(16, 16)
# print(encoder([p, p, p]))
x = encode_predicate("(s_comment LIKE '%Customer%') AND (s_comment LIKE '%Complaints%')")
print(x)
    x = encode_join(["o_custkey == c_custkey"])
print(x)