# aqetuner/encoder/attn_encoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import json
import queue
import shap
from sklearn.tree import DecisionTreeRegressor
from utils import flatten_tree, generate_plan_tree, _pad_and_combine, exec_analyze_query, generate_latency


def left_child(x):
    """Return the left child of a tree node.

    A node is treated as having children only when it has exactly three
    entries; the left child is then the entry at index 1.

    Returns:
        ``x[1]`` when ``len(x) == 3``, otherwise ``None``.
    """
    return x[1] if len(x) == 3 else None


def right_child(x):
    """Return the right child of a tree node.

    A node is treated as having children only when it has exactly three
    entries; the right child is then the entry at index 2.

    Returns:
        ``x[2]`` when ``len(x) == 3``, otherwise ``None``.
    """
    return x[2] if len(x) == 3 else None


def features(x):
    """Return the first entry of a tree node, which holds the node's features."""
    node_features = x[0]
    return node_features


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with an optional multiplicative mask.

    The module holds no parameters; it only combines the tensors it is given.
    """

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask, d_k=3):
        """Compute attention weights and the attended values.

        Args:
            Q: queries, [batch_size x n_heads x len_q x d_k].
            K: keys, [batch_size x n_heads x len_k x d_k].
            V: values, [batch_size x n_heads x len_k x d_v].
            attn_mask: optional tensor that the raw scores are right-multiplied
                with (after transposing its last two axes), or ``None`` to
                skip masking. It is cast to float before use.
            d_k: value whose square root scales the raw scores.

        Returns:
            A ``(context, attn)`` tuple: ``attn`` is the softmax over the last
            axis of the (optionally masked) scores, ``context`` is
            ``attn @ V``.
        """
        # scores: [batch_size x n_heads x len_q x len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)

        # Identity test: `!= None` on a tensor goes through tensor comparison
        # dispatch and only happens to produce a usable bool.
        if attn_mask is not None:
            # The mask is applied as a matrix product, not added to or used to
            # fill the scores, so every score row is mixed according to the
            # mask rows.
            mask = attn_mask.float()
            scores = torch.matmul(scores, mask.transpose(-1, -2))

        # Functional softmax avoids building a throw-away nn.Softmax per call.
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn


class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and layer norm.

    The query and key head widths (``d_q`` and ``d_k``) must be equal for the
    score product to be defined, and ``d_model`` is expected to equal
    ``n_heads * d_q`` (per the original author's note).
    """

    def __init__(self, d_model=9, n_heads=3, d_q=3, d_k=3, d_v=3):
        """Create the projection layers.

        Args:
            d_model: width of the input and output vectors.
            n_heads: number of attention heads.
            d_q: per-head width of the projected queries.
            d_k: per-head width of the projected keys.
            d_v: per-head width of the projected values.
        """
        super().__init__()
        self.n_heads = n_heads
        self.d_q = d_q
        self.d_k = d_k
        self.d_v = d_v
        # Projection matrices for queries, keys and values.
        self.W_Q = nn.Linear(d_model, d_q * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        # Maps the concatenated heads back to the model width.
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        """Attend ``Q`` over ``K``/``V``.

        Args:
            Q: [batch_size x len_q x d_model].
            K: [batch_size x len_k x d_model].
            V: [batch_size x len_v x d_model].
            attn_mask: optional [batch_size x len_q x len_k] tensor that is
                replicated across heads, or ``None``.

        Returns:
            ``(output, attn)`` where ``output`` is
            ``layer_norm(linear(context) + Q)`` with shape
            [batch_size x len_q x d_model] and ``attn`` holds the per-head
            attention weights.
        """
        residual, batch_size = Q, len(Q)

        # Project, then split the last axis into heads:
        # [batch_size x len x n_heads x d] -> [batch_size x n_heads x len x d].
        # Queries are reshaped with their own width (d_q) so that a d_q/d_k
        # mismatch fails in the score product instead of silently reshuffling
        # the query tensor.
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_q).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        # attn_mask: [batch_size x len_q x len_k] -> [batch_size x n_heads x len_q x len_k]
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)

        # Forward the configured key width so the score scaling follows this
        # layer's d_k instead of the attention module's default of 3.
        # context: [batch_size x n_heads x len_q x d_v]
        # attn:    [batch_size x n_heads x len_q x len_k]
        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask, d_k=self.d_k)

        # Merge heads: [batch_size x len_q x n_heads * d_v]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        output = self.linear(context)
        return self.layer_norm(output + residual), attn


class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross attention: queries from ``x1``, keys/values from ``x2``.

    Note that the softmax in ``forward`` is taken over axis 1, i.e. across the
    heads, not across the key positions.
    """

    def __init__(self, in_dim1=9, in_dim2=9, k_dim=3, v_dim=3, num_heads=3):
        """Create the bias-free input projections and the output projection.

        Args:
            in_dim1: width of ``x1`` (and of the output).
            in_dim2: width of ``x2``.
            k_dim: per-head width of queries and keys.
            v_dim: per-head width of values.
            num_heads: number of heads.
        """
        # Original author's note: in_dim1 = k_dim * num_heads
        super().__init__()
        self.num_heads = num_heads
        self.k_dim = k_dim
        self.v_dim = v_dim

        qk_width = k_dim * num_heads
        v_width = v_dim * num_heads
        self.proj_q1 = nn.Linear(in_dim1, qk_width, bias=False)
        self.proj_k2 = nn.Linear(in_dim2, qk_width, bias=False)
        self.proj_v2 = nn.Linear(in_dim2, v_width, bias=False)
        self.proj_o = nn.Linear(v_width, in_dim1)

    def forward(self, x1, x2, mask=None):
        """Attend ``x1`` over ``x2``.

        Args:
            x1: 3-D tensor (batch_size, seq_len1, in_dim1).
            x2: tensor whose axis 1 is seq_len2 and last axis is in_dim2.
            mask: optional tensor reshaped to
                (batch_size, num_heads, seq_len1, seq_len2) and added to the
                scaled scores.

        Returns:
            Tensor of shape (batch_size, seq_len1, in_dim1).
        """
        batch_size, seq_len1, _ = x1.size()
        seq_len2 = x2.size(1)
        heads = self.num_heads

        # queries: (batch_size, num_heads, seq_len1, k_dim)
        queries = self.proj_q1(x1).view(batch_size, seq_len1, heads, self.k_dim).permute(0, 2, 1, 3)
        # keys_t: (batch_size, num_heads, k_dim, seq_len2), already transposed for the product
        keys_t = self.proj_k2(x2).view(batch_size, seq_len2, heads, self.k_dim).permute(0, 2, 3, 1)
        # values: (batch_size, num_heads, seq_len2, v_dim)
        values = self.proj_v2(x2).view(batch_size, seq_len2, heads, self.v_dim).permute(0, 2, 1, 3)

        # scores: (batch_size, num_heads, seq_len1, seq_len2)
        scores = torch.matmul(queries, keys_t) / math.sqrt(self.k_dim)
        if mask is not None:
            scores = scores + mask.reshape(batch_size, heads, seq_len1, seq_len2)

        # Normalised across the head axis (dim=1).
        weights = F.softmax(scores, dim=1)

        # (batch_size, num_heads, seq_len1, v_dim) -> (batch_size, seq_len1, num_heads * v_dim)
        merged = torch.matmul(weights, values).permute(0, 2, 1, 3).contiguous()
        merged = merged.view(batch_size, seq_len1, -1)
        # (batch_size, seq_len1, in_dim1)
        return self.proj_o(merged)


class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward block with a residual connection and layer norm.

    Two kernel-size-1 convolutions expand each position from ``d_model`` to
    ``d_ff`` channels (with a ReLU in between) and project it back.
    """

    def __init__(self, d_model=9, d_ff=2048):
        """Create the two 1x1 convolutions and the layer norm.

        Args:
            d_model: width of each input/output position.
            d_ff: hidden width between the two convolutions.
        """
        super().__init__()
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        """Apply the feed-forward block.

        Args:
            inputs: tensor of shape [batch_size, len_q, d_model].

        Returns:
            ``layer_norm(ffn(inputs) + inputs)`` with the same shape as
            ``inputs``.
        """
        # Conv1d expects channels on axis 1, so move d_model there and back.
        channels_first = inputs.transpose(1, 2)
        hidden = F.relu(self.conv1(channels_first))
        projected = self.conv2(hidden).transpose(1, 2)
        return self.layer_norm(projected + inputs)


class QueryPlanEncoderLayer(nn.Module):
    """Self-attention encoder over flattened query-plan trees.

    Uses the default-sized ``MultiHeadAttention`` and ``PoswiseFeedForwardNet``
    and a tree-adjacency mask produced by ``PositionalEncoding``.
    """

    def __init__(self):
        super().__init__()
        self.multihead = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, query, latency):
        """Encode a batch of plan trees.

        Args:
            query: sequence of plan trees, one per sample.
            latency: sequence of latency trees, indexed in step with ``query``.

        Returns:
            ``(enc_outputs, attn, L)``: the encoded nodes
            ([batch_size x len_q x d_model] per the original comments), the
            attention weights, and the padded latency tensor
            ([batch_size x len_q x 1] per the original comments).
        """
        # Flatten every plan tree and its latency tree into per-node rows.
        node_rows = []
        latency_rows = []
        for idx in range(len(query)):
            node_rows.append(np.array(flatten_tree(query[idx], [])))
            # Latencies get a trailing axis of size 1.
            flat_latency = np.expand_dims(flatten_tree(latency[idx], []), axis=1)
            latency_rows.append(np.array(flat_latency))

        padded_nodes = _pad_and_combine(node_rows)
        padded_latency = _pad_and_combine(latency_rows)
        L = torch.Tensor(padded_latency)
        Q = torch.Tensor(padded_nodes)  # example shape from the author: 1x24x9

        # One len_q x len_q mask per plan tree, stacked into a batch.
        len_q = Q.size(1)
        pos_encode = PositionalEncoding()
        attn_mask = torch.tensor([pos_encode([tree], len_q) for tree in query])

        # Self-attention: the same tensor serves as queries, keys and values.
        enc_outputs, attn = self.multihead(Q, Q, Q, attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn, L


class PairEncoderLayer(nn.Module):
    """Cross-attention encoder pairing plan-node vectors with knob vectors."""

    def __init__(self, in_dim1, in_dim2, k_dim, v_dim, n_heads):
        """Create the cross-attention and feed-forward sub-layers.

        Args:
            in_dim1: width of the (expanded) plan input; also the feed-forward width.
            in_dim2: width of the (expanded) knob input.
            k_dim: per-head query/key width.
            v_dim: per-head value width.
            n_heads: number of attention heads.
        """
        super().__init__()
        self.multihead = MultiHeadCrossAttention(in_dim1, in_dim2, k_dim, v_dim, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(in_dim1)

    def forward(self, plan, knobs, W):
        """Cross-attend plan nodes over the knob vector.

        Args:
            plan: tensor with at least three axes; axis 2 is the plan width.
            knobs: tensor with at least three axes; axis 2 is the knob width.
            W: mask handed to the cross-attention layer.

        Returns:
            The output of the feed-forward sub-layer applied to the
            cross-attention result.
        """
        # The third axis of plan and knobs must have the same width, so each
        # is repeated element-wise by the other's original width.
        plan_width = plan.size(2)
        knob_width = knobs.size(2)
        expanded_knobs = knobs.repeat_interleave(plan_width, dim=2)
        expanded_plan = plan.repeat_interleave(knob_width, dim=2)

        attended = self.multihead(expanded_plan, expanded_knobs, W)
        return self.pos_ffn(attended)


class KnobsEncoderLayer(nn.Module):
    """Self-attention encoder over knob configuration vectors.

    One attention head of width 1 is used per knob.
    """

    def __init__(self, knobs_dim):
        """Create the attention and feed-forward sub-layers.

        Args:
            knobs_dim: number of knobs; used as both the model width and the
                number of heads.
        """
        super().__init__()
        self.knobs_dim = knobs_dim
        self.multihead = MultiHeadAttention(d_model=knobs_dim, n_heads=knobs_dim, d_q=1, d_k=1, d_v=1)
        self.pos_ffn = PoswiseFeedForwardNet(d_model=knobs_dim)

    def forward(self, query):
        """Encode a batch of knob settings.

        Args:
            query: iterable of knob settings, one per sample.

        Returns:
            ``(enc_outputs, attn)`` from the attention layer followed by the
            feed-forward layer.
        """
        # Wrap every sample in a one-element list so the tensor gains an axis
        # of size 1 at position 1 (author's example shape: 1x1x3).
        Q = torch.tensor([[q] for q in query])

        # Self-attention without a mask. The all-ones mask the original code
        # built here was never passed on, so it is no longer allocated.
        enc_outputs, attn = self.multihead(Q, Q, Q, attn_mask=None)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn


class PositionalEncoding(nn.Module):
    """Builds the adjacency matrix of a plan tree, used as an attention mask.

    A tree node is a sequence whose first entry is the feature vector and
    whose entries at index 1 and 2 (when present) are the left and right
    subtrees. A child equal to an all-zero feature vector, or to a one-element
    list holding one, is a padding node and is ignored.

    Nodes are numbered in breadth-first order starting from 0.
    """

    def __init__(self, feature_dim=9):
        """Create the encoder.

        Args:
            feature_dim: length of a node's feature vector; determines what
                counts as an all-zero padding node.
        """
        super().__init__()
        self.feature_dim = feature_dim

    def list_to_tuple(self, lst):
        """Recursively convert nested lists into nested tuples.

        Non-list values are returned unchanged.
        """
        if isinstance(lst, list):
            return tuple(self.list_to_tuple(item) for item in lst)
        return lst

    def isEmptyNode(self, root):
        """Return True when ``root`` is a padding node.

        A padding node equals ``[0.] * feature_dim`` or a one-element list
        wrapping that vector.
        """
        zero_features = [0.] * self.feature_dim
        return root == zero_features or root == [zero_features]

    def _children(self, node):
        """Return the non-padding children of ``node`` (left first)."""
        # A padding node or a node without child slots has no children.
        if self.isEmptyNode(node) or len(node) <= 1:
            return []
        return [child for child in node[1:3] if not self.isEmptyNode(child)]

    def _traverse(self, root):
        """Walk the tree breadth-first.

        Returns:
            ``(node_count, edges)`` where ``edges`` is a list of
            ``(parent_index, child_index)`` pairs in breadth-first numbering.
        """
        # `order` doubles as the BFS queue: a node's position in it is its
        # index, so a child's index is known the moment it is appended. This
        # keeps structurally identical subtrees apart, which a lookup keyed
        # by node content cannot do.
        order = [root]
        edges = []
        head = 0
        while head < len(order):
            for child in self._children(order[head]):
                edges.append((head, len(order)))
                order.append(child)
            head += 1
        return len(order), edges

    def buildGraph(self, root, u, graph, map_):
        """Mark parent/child edges in ``graph`` by depth-first recursion.

        Kept for backward compatibility; ``convertTreeToMatrix`` no longer
        calls it. ``map_`` must map ``list_to_tuple(node)`` to the node's
        index, which means identical subtrees share a single index.

        Args:
            root: current node.
            u: index of ``root`` in ``graph``.
            graph: square list-of-lists adjacency matrix, updated in place.
            map_: node-content tuple -> node index.
        """
        if self.isEmptyNode(root):
            return
        if (self.isEmptyNode(root[1]) == False):
            tmp_node = self.list_to_tuple(root[1])
            v = map_[tmp_node]
            graph[u][v] = graph[v][u] = 1
            self.buildGraph(root[1], v, graph, map_)
        if (self.isEmptyNode(root[2]) == False):
            tmp_node = self.list_to_tuple(root[2])
            v = map_[tmp_node]
            graph[u][v] = graph[v][u] = 1
            self.buildGraph(root[2], v, graph, map_)

    def getNodeCount(self, tree):
        """Return the number of nodes reachable from ``tree[0]``.

        The root is always counted; padding children are not.
        """
        node_count, _ = self._traverse(tree[0])
        return node_count

    def convertTreeToMatrix(self, tree, num):
        """Return the ``num x num`` adjacency matrix of the tree rooted at ``tree[0]``.

        Args:
            tree: sequence whose first entry is the root node.
            num: side length of the returned matrix; it must be at least the
                number of nodes, rows/columns beyond that stay all zero.

        Returns:
            A list of ``num`` lists of ``num`` ints. Entry ``[i][j]`` is 1
            when ``i == j`` is a real node or when nodes ``i`` and ``j`` are
            parent and child, else 0.
        """
        # Level-order traversal yields the node count and every node's index.
        node_count, edges = self._traverse(tree[0])
        matrix = [[0] * num for _ in range(num)]
        for i in range(node_count):
            matrix[i][i] = 1
        for parent, child in edges:
            matrix[parent][child] = matrix[child][parent] = 1
        return matrix

    def forward(self, root, num):
        """Return the adjacency matrix for ``root``, or ``[]`` when it is empty.

        Args:
            root: sequence whose first entry is the tree's root node.
            num: side length of the returned matrix.
        """
        if len(root) == 0:
            return []
        return self.convertTreeToMatrix(root, num)


class AttrEncoder(nn.Module):
    """Encodes (query plan, knob setting) pairs into one value per plan node.

    Plans and knobs are first encoded separately with self-attention. A
    depth-1 decision tree is then fitted per sample on the concatenated
    node/knob vectors against node latency, and its SHAP values for the knob
    columns are used as the mask of the plan/knob cross-attention.
    """

    def __init__(self, plan_in_channel, knobs_dim):
        """Create the sub-layers.

        Args:
            plan_in_channel: width of a plan-node vector.
            knobs_dim: number of knobs.
        """
        super().__init__()
        self.channel = plan_in_channel
        self.knobs_dim = knobs_dim
        self.query_plan_layer = QueryPlanEncoderLayer()
        self.knob_layer = KnobsEncoderLayer(knobs_dim)
        pair_width = plan_in_channel * knobs_dim
        self.pair_layer = PairEncoderLayer(pair_width, pair_width, plan_in_channel, plan_in_channel, knobs_dim)
        self.single_split_model = DecisionTreeRegressor(max_depth=1)
        self.l = nn.Linear(pair_width, 1)

    def train_latency(self, pair_vec, latency):
        """Fit the single-split tree per sample and return knob SHAP values.

        Args:
            pair_vec: tensor whose first axis is the batch and whose last axis
                ends with the ``knobs_dim`` knob features.
            latency: tensor with the same first two axis sizes as ``pair_vec``.

        Returns:
            Float tensor holding, for every sample and node, the SHAP values
            of the last ``knobs_dim`` features.
        """
        assert pair_vec.shape[0] == latency.shape[0]
        shap_rows = []
        # Learn the samples one by one. fit() rebuilds the tree from scratch,
        # so each sample is explained right after its own fit; explaining in a
        # second pass would use the tree fitted on the last sample for all.
        for node_vecs, node_latency in zip(pair_vec, latency):
            assert node_vecs.shape[0] == node_latency.shape[0]
            node_features = node_vecs.detach().numpy()
            self.single_split_model.fit(node_features, node_latency.detach().numpy())
            explainer = shap.TreeExplainer(self.single_split_model)
            shap_rows.append(explainer.shap_values(node_features))
        # Stack in numpy first; building a tensor from a list of arrays is slow.
        W = torch.as_tensor(np.array(shap_rows))
        # Only the trailing knob features are needed.
        return W[:, :, -self.knobs_dim:].float()

    def forward(self, plans, knobs):
        """Encode a batch of plans together with their knob settings.

        Args:
            plans: sequence of plans accepted by ``generate_latency`` and
                ``generate_plan_tree``.
            knobs: sequence of knob settings, same length as ``plans``.

        Returns:
            The cross-attention output projected to width 1 by ``self.l``.
        """
        assert len(plans) == len(knobs)
        plan = []
        latency = []
        for p in plans:
            latency.append(generate_latency(p))
            plan.append(generate_plan_tree(p))
        # Query-plan self-attention (author's example shape: 1x24x9).
        query_plan_out, query_plan_attn, L = self.query_plan_layer(plan, latency)
        # Knob self-attention (author's example shape: 1x1x3).
        knob_out, knob_attn = self.knob_layer(knobs)
        # Cross-attention: the tree model learns, per node, the relation
        # between vec(plan, knob) and latency. The knob vector is repeated for
        # every plan node and appended to it (author's example: 1x24x12).
        repeated_knobs = knob_out.repeat(1, query_plan_out.size(1), 1)
        pair_vec = torch.cat((query_plan_out, repeated_knobs), dim=2)
        # Knob/node weight matrix.
        W = self.train_latency(pair_vec, L)
        pair_out = self.pair_layer(query_plan_out, knob_out, W)
        pair_out = self.l(pair_out)
        return pair_out


if __name__ == "__main__":
    # Manual smoke run: analyze one query, then encode the same plan twice
    # with two different knob settings and print the result.
    encoder = AttrEncoder(9, 3)
    sql = "SELECT COUNT(*) FROM cast_info ci WHERE ci.person_id>583692 AND ci.role_id=2"
    plan_analyze, _ = exec_analyze_query(sql, 'imdb', {'enable_optimizer': 1})
    plan_analyze_list = [plan_analyze] * 2
    knob_settings = [[0., 0.3, 0.4], [0.2, 0.1, 0.]]
    print(encoder(plan_analyze_list, knob_settings))