# aqetuner/pair_encoder.py
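"""Encoders that map (query plan, knob configuration) pairs to fixed-size
vectors: TPair turns a JSON plan plus knob list into tensors, AttnEncoder
applies masked self- and cross-attention over them, and TreeEncoder is a
tree-network alternative."""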

import os

import torch as t
from torch import nn
from scipy.sparse import csgraph
from scipy.sparse.linalg import eigsh

from tree_utils import generate_plan_tree, TreeNet


HIDDEN_DIM = 32       # width of the encoder output vectors
MAX_NODES = 200       # maximum number of plan nodes per query plan
DB = os.getenv('DB')  # statistics file name, read from the DB env variable

RowsNorm = 1e7        # normalization constant for row-count statistics

# Plan-operator types emitted by the engine; '' is the fallback for nodes
# without a NodeType field.
CKNodeTypes = [
    "Projection",
    "MergingAggregated",
    "Exchange",
    "Aggregating",
    "Join",
    "Filter",
    "TableScan",
    "Limit",
    "Sorting",
    "CTERef",
    "Buffer",
    "Union",
    "EnforceSingleRow",
    "Window",
    "Values",
    "PartitionTopN",
    ""
]

# Comparison operators recognized in filter predicates.
COMPARATORS = [
    ">",
    "<",
    "=",
    ">=",
    "<=",
    "!=",
    "LIKE",
    "IN",
    "NOT LIKE"
]

NODE_DIM = 256  # length of each per-node feature vector
KNOB_DIM = 16   # maximum number of knobs / knob-token dimension

# Lazily loaded column statistics (see load_column_data).
COL_NAME, MIN_MAX = [], []

class TPair:
    """A (query plan, knob configuration) pair encoded as tensors: per-node
    feature vectors (Vecs), a boolean attention mask over plan edges (Mask1),
    and knob representations (Knobs, plat_knobs)."""

    def __init__(self, json_plan, knobs):
        # self.tree_plan = generate_plan_tree(json_plan)
        self.NodeTypes = CKNodeTypes
        self._knobs = knobs
        self._parse_plan(json_plan)
        # Knobs: one "token" row per knob, with the knob value on the diagonal
        # (unset slots keep the identity's 1.0).
        self.Knobs = t.eye(KNOB_DIM)
        for i, v in enumerate(knobs):
            if i < KNOB_DIM:
                self.Knobs[i][i] = v
        # plat_knobs: flat knob vector, zero-padded on a copy (so the caller's
        # list is not mutated) and truncated to KNOB_DIM.
        padded = list(knobs[:KNOB_DIM]) + [0.] * max(0, KNOB_DIM - len(knobs))
        self.plat_knobs = t.Tensor(padded)

    def _node_to_vec(self, node):
        """Encode one plan node as a NODE_DIM-length feature vector:
        [node-type one-hot | depth | normalized row count | optional
        predicate/join/aggregate embedding]."""
        if 'NodeType' not in node:
            node['NodeType'] = ''
        vec_len = len(CKNodeTypes) + 2
        arr = [0. for _ in range(NODE_DIM)]
        stats = node.get('Statistic', {})
        arr[vec_len - 1] = stats.get('RowCount', 0.) / RowsNorm
        arr[vec_len - 2] = node['depth']
        arr[self.NodeTypes.index(node['NodeType'])] = 1.

        # Append the operator-specific one-hot encoding, if any.
        if node['NodeType'] == 'TableScan' and 'Where' in node:
            emb = encode_predicate(node['Where'])
        elif node['NodeType'] == 'Join' and 'Condition' in node:
            emb = encode_join(node['Condition'])
        elif node['NodeType'] == 'Aggregating' and 'GroupByKeys' in node:
            emb = encode_aggregate(node['GroupByKeys'])
        else:
            return arr

        # Overwrite the tail with the embedding, truncated to fit, then
        # re-pad to NODE_DIM.
        arr[vec_len:] = emb[:NODE_DIM - vec_len]
        while len(arr) < NODE_DIM:
            arr.append(0.)
        return arr

    def _parse_plan(self, root):
        """DFS over the plan tree, building per-node feature vectors and a
        boolean attention mask (True = attention blocked, False = allowed)."""
        vec_len = NODE_DIM
        nodes = [root]
        res = []
        mask = t.ones(MAX_NODES, MAX_NODES, dtype=t.bool)
        root['depth'] = 1.

        visited = []
        while len(nodes) > 0:
            node = nodes.pop()
            arr = self._node_to_vec(node)
            nid = len(res)
            res.append(arr)
            if 'parent_id' in node:
                # Allow attention along parent-child edges and the self-loop.
                mask[node['parent_id']][nid] = False
                mask[nid][node['parent_id']] = False
                mask[nid][nid] = False
            if 'Children' in node:
                for p in node['Children']:
                    if p['NodeId'] not in visited:
                        visited.append(p['NodeId'])
                        p['parent_id'] = nid
                        p['depth'] = node['depth'] + 1
                        nodes.append(p)

        # Zero-pad the node list up to MAX_NODES.
        for _ in range(MAX_NODES - len(res)):
            res.append([0. for _ in range(vec_len)])

        self.Vecs = t.Tensor(res)
        self.Mask1 = mask

    def _spectral_encoding(self, k):
        """Laplacian positional encoding for the plan graph: the k eigenvectors
        of the graph Laplacian with the smallest eigenvalues give each node a
        k-dimensional coordinate. Mask1 marks blocked attention (True = no
        edge), so the adjacency matrix is its negation."""
        adjacency = (~self.Mask1).int().numpy()
        laplacian_matrix = csgraph.laplacian(adjacency, normed=False)
        # 'SM' selects the smallest eigenvalues.
        _, eigenvectors = eigsh(laplacian_matrix.astype(float), k=k, which='SM')
        # eigenvectors shape: MAX_NODES x k
        return t.Tensor(eigenvectors)
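
# Usage sketch (the plan dict below is a hypothetical example, shaped like the
# structures TPair._parse_plan reads):
#
#     plan = {"NodeId": 0, "NodeType": "TableScan",
#             "Children": [{"NodeId": 1, "NodeType": "Filter"}]}
#     pair = TPair(plan, [0.1, 0.3])
#     pair.Vecs.shape   # torch.Size([200, 256])
#     pair.Mask1.shape  # torch.Size([200, 200])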

class FFN(nn.Module):
    """Two-layer feed-forward block with a ReLU in between."""

    def __init__(self, in_dim=16, out_dim=16):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim)
        )

    def forward(self, vecs):
        return self.layers(vecs)

class AttnEncoder(nn.Module):
    """Self-attention over plan nodes and knob tokens, followed by
    cross-attention from nodes to knobs and max-pooling over nodes."""

    def __init__(self, node_dim=NODE_DIM, knob_dim=KNOB_DIM, heads=4):
        super().__init__()
        self._node_dim = node_dim
        self._knob_dim = knob_dim
        self._heads = heads
        self.nodes_attn = nn.MultiheadAttention(node_dim, heads, batch_first=True)
        self.ffn1 = FFN(node_dim, HIDDEN_DIM)
        self.norm1 = nn.LayerNorm(HIDDEN_DIM)
        self.knobs_attn = nn.MultiheadAttention(knob_dim, heads, batch_first=True)
        self.ffn2 = FFN(knob_dim, HIDDEN_DIM)
        self.norm2 = nn.LayerNorm(HIDDEN_DIM)
        self.cross_attn = nn.MultiheadAttention(HIDDEN_DIM, heads, batch_first=True)
        self.norm3 = nn.LayerNorm(HIDDEN_DIM)

        # self.test_linear = nn.Linear(node_dim + knob_dim, HIDDEN_DIM)

    def forward(self, pairs: list):
        # nodes: batch x MAX_NODES x node_dim
        # knobs: batch x KNOB_DIM x knob_dim
        # mask1: (batch * heads) x MAX_NODES x MAX_NODES, True = blocked
        nodes = t.stack([p.Vecs for p in pairs])
        knobs = t.stack([p.Knobs for p in pairs])
        # nn.MultiheadAttention expects a 3D mask ordered batch-major with the
        # heads interleaved per sample, hence repeat_interleave.
        mask1 = t.stack([p.Mask1 for p in pairs]).repeat_interleave(self._heads, dim=0)
        x, _ = self.nodes_attn(nodes, nodes, nodes, attn_mask=mask1)
        # Rows of padded (fully masked) nodes come out as NaN; zero them.
        x = t.nan_to_num(x)
        node_vecs = self.norm1(self.ffn1(x))
        # return self.test_linear(t.cat([x.sum(dim=1), knobs.sum(dim=1)], dim=1))
        x, _ = self.knobs_attn(knobs, knobs, knobs)
        x = t.nan_to_num(x)
        knob_vecs = self.norm2(self.ffn2(x))
        # Cross-attend plan nodes (queries) to knob tokens (keys/values), then
        # max-pool over the node dimension.
        x, _ = self.cross_attn(node_vecs, knob_vecs, knob_vecs)
        x = t.max(x, dim=1).values
        return self.norm3(x)
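
# Usage sketch: a batch of TPair objects maps to a (batch, HIDDEN_DIM) tensor:
#
#     encoder = AttnEncoder()
#     vecs = encoder([pair_a, pair_b])  # shape: (2, HIDDEN_DIM)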


class TreeEncoder(nn.Module):
    """Alternative encoder: a TreeNet over the plan tree, concatenated with
    the flat knob vector. Requires TPair.tree_plan, i.e. the
    generate_plan_tree call in TPair.__init__ must be re-enabled."""

    def __init__(self, node_dim=NODE_DIM, knob_dim=KNOB_DIM):
        super().__init__()
        self.treenet = TreeNet(node_dim, HIDDEN_DIM - knob_dim)

    def forward(self, pairs: list):
        # trees: list of plan trees; knobs: batch x KNOB_DIM
        trees = [p.tree_plan for p in pairs]
        knobs = t.stack([p.plat_knobs for p in pairs])
        tree_vecs = self.treenet(trees)
        return t.cat([tree_vecs, knobs], dim=1)


def load_column_data(fname):
    """Load per-column statistics: the column name (second field of each line)
    and, when present, its integer (min, max) value range."""
    column_name, min_max_vals = [], []
    with open(f"statistics/{fname}") as f:
        lines = f.readlines()
    for line in lines:
        items = line.split()
        column_name.append(items[1])
        if len(items) == 4:
            min_max_vals.append((int(items[2]), int(items[3])))
        else:
            min_max_vals.append(())
    return column_name, min_max_vals
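
# Expected statistics file layout (assumed from the parsing above): one line
# per column, whitespace-separated, column name in the second field, with
# optional integer min/max values, e.g.
#
#     0 o_custkey 1 150000
#     1 s_comment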


def encode_predicate(s: str):
    """Encode a filter string such as "(a > 1) AND (b LIKE '%x%')" as the
    concatenation of per-predicate vectors: [column one-hot | comparator
    one-hot | normalized literal]."""
    global COL_NAME, MIN_MAX
    if len(COL_NAME) == 0:
        COL_NAME, MIN_MAX = load_column_data(DB)
    res = []
    predicates = []
    depth = 0
    p = ""
    # Collect each top-level parenthesized group as one predicate string;
    # commas are dropped.
    for c in s:
        if c == '(':
            p = "" if depth == 0 else p
            depth += 1
        elif c == ')':
            if depth == 1:
                predicates.append(p)
            depth -= 1
        elif c == ',':
            pass
        elif depth > 0:
            p += c
    if len(predicates) == 0 and 'Filters' not in s:
        # No parentheses: treat the whole string as a single predicate.
        predicates.append(s)

    for predicate in predicates:
        items = predicate.split()
        try:
            if items[1] == 'NOT':
                # Fold "NOT LIKE" into a single comparator token.
                items[1] = 'NOT LIKE'
                items.remove('LIKE')
            emb = [0. for _ in range(len(COL_NAME) + len(COMPARATORS) + 6)]
            emb[COL_NAME.index(items[0])] = 1.
            emb[len(COL_NAME) + COMPARATORS.index(items[1])] = 1.
            vals = MIN_MAX[COL_NAME.index(items[0])]
            if len(vals) == 2:
                # Min-max normalize the numeric literal.
                emb[len(COL_NAME) + len(COMPARATORS)] = \
                    (int(items[2]) - vals[0]) / (vals[1] - vals[0])
            res += emb
        except Exception:
            # Skip predicates that do not match the "<col> <op> <val>" shape.
            pass
    return res

def encode_join(conds: list):
    """Encode join conditions as a one-hot vector over the columns they
    reference (comparator and value slots are left zero)."""
    global COL_NAME, MIN_MAX
    if len(COL_NAME) == 0:
        COL_NAME, MIN_MAX = load_column_data(DB)
    emb = [0. for _ in range(len(COL_NAME) + len(COMPARATORS) + 6)]
    for s in conds:
        for token in s.split():
            if token in COL_NAME:
                emb[COL_NAME.index(token)] = 1.
    return emb

def encode_aggregate(keys: list):
    """Encode GROUP BY keys as a one-hot vector over the referenced columns."""
    global COL_NAME, MIN_MAX
    if len(COL_NAME) == 0:
        COL_NAME, MIN_MAX = load_column_data(DB)
    emb = [0. for _ in range(len(COL_NAME) + len(COMPARATORS) + 6)]
    for key in keys:
        if key in COL_NAME:
            emb[COL_NAME.index(key)] = 1.
    return emb


if __name__ == "__main__":
    # with open("encoder/test_plan") as f:
    #     c = json.loads(f.read())
    # p = TPair(c, [0.1, 0.3, 0.4, 0.34, 0.5])
    # encoder = TreeEncoder(16, 16)
    # print(encoder([p, p, p]))
    x = encode_predicate("(s_comment LIKE '%Customer%') AND (s_comment LIKE '%Complaints%')")
    print(x)
    x = encode_join(["o_custkey == c_custkey"])
    print(x)
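
    # A minimal AttnEncoder round-trip (a sketch: the plan dict below is a
    # hypothetical example, shaped like what TPair._parse_plan reads).
    plan = {"NodeId": 0, "NodeType": "Projection",
            "Children": [{"NodeId": 1, "NodeType": "TableScan"}]}
    pair = TPair(plan, [0.1, 0.3, 0.4, 0.34, 0.5])
    encoder = AttnEncoder()
    print(encoder([pair, pair]).shape)  # torch.Size([2, 32])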