import copy
from collections import deque

# Tags stored as the third element of every (left_node, right_node, tag) path
# tuple; the printing code labels them 'hyper' and 'shuffle' respectively.
HYPER = 1
SHUFFLE = -1


class Node:
    """A table taking part in a join plan.

    Attributes:
        table: the table name this node was created with.
        column_w: column name -> result size, filled by ``add_column``.
        adj: neighbouring ``Node`` -> value returned by
            ``JoinGraph.cpt_edge_weight`` (only ``JoinGraph`` fills this).
        adj_col: neighbouring ``Node`` -> name of THIS table's column used in
            the join with that neighbour.
    """

    def __init__(self, tab_name):
        self.table = tab_name
        self.column_w = {}
        self.adj = {}
        self.adj_col = {}

    def add_column(self, col_name, result_size):
        """Record ``result_size`` under ``col_name`` in ``column_w``."""
        self.column_w[col_name] = result_size


def _format_join_order(paths):
    """Render path tuples as a single line terminated by a newline.

    Tuples tagged ``HYPER`` are written in square brackets, tuples tagged
    ``SHUFFLE`` in parentheses; tuples with any other tag are left out.
    """
    pieces = []
    for tup in paths:
        left, right, tag = tup[0], tup[1], tup[2]
        if tag not in (HYPER, SHUFFLE):
            continue
        label = 'hyper' if tag == HYPER else 'shuffle'
        edge = (
            f"{left.table}.{left.adj_col[right]}=>{label}=>"
            f"{right.table}.{right.adj_col[left]}"
        )
        pieces.append(f"[{edge}] " if tag == HYPER else f"({edge}) ")
    return ''.join(pieces) + '\n'


class TraditionalJoinOrder:
    """Walk a list of binary joins in order and tag each one hyper or shuffle.

    The result is stored in ``self.paths`` as ``(left_node, right_node, tag)``
    tuples and printed once at the end of construction.
    """

    def __init__(self, join_raltions, metadata):
        """Build ``self.paths`` from ``join_raltions``.

        Args:
            join_raltions: sequence of mappings with exactly two keys, one per
                joined table; each value is the key/index of the join column
                inside ``metadata[table]['used_cols']``.
            metadata: ``metadata[table]['used_cols'][key]`` must yield the
                column name for every table/key used above.

        Raises:
            ValueError: if some join can never be attached because neither of
                its tables is reachable from the tables already placed. (The
                previous implementation re-queued such joins forever.)
        """
        # Work on a private queue so the caller's list is never modified.
        queue = deque(copy.deepcopy(join_raltions))
        placed = {}  # table name -> Node for every table already in `paths`
        paths = []
        stalled = 0  # consecutive deferrals since the last successful join

        while queue:
            join_op = queue.popleft()
            left_table, right_table = list(join_op.keys())
            left_col = metadata[left_table]['used_cols'][join_op[left_table]]
            right_col = metadata[right_table]['used_cols'][join_op[right_table]]

            if not paths:
                # The very first join seeds the plan and is always tagged hyper.
                left_node, right_node = Node(left_table), Node(right_table)
                placed[left_table] = left_node
                placed[right_table] = right_node
                left_node.adj_col[right_node] = left_col
                right_node.adj_col[left_node] = right_col
                paths.append((left_node, right_node, HYPER))
                continue

            left_new = left_table not in placed
            right_new = right_table not in placed
            if left_new and right_new:
                # Neither side is connected yet: retry after the other joins.
                queue.append(join_op)
                stalled += 1
                if stalled >= len(queue):
                    # Every remaining join was deferred back-to-back, so no
                    # further progress is possible.
                    raise ValueError(
                        "join relations not connected to the first join: "
                        f"{list(queue)}"
                    )
                continue
            stalled = 0

            if left_new:
                left_node = placed[left_table] = Node(left_table)
            else:
                left_node = placed[left_table]
            if right_new:
                right_node = placed[right_table] = Node(right_table)
            else:
                right_node = placed[right_table]

            # A freshly created node never vetoes a hyper join; an already
            # placed node does unless `_hyper_allowed` accepts it.
            hyper_ok = (
                (left_new or self._hyper_allowed(left_node, paths))
                and (right_new or self._hyper_allowed(right_node, paths))
            )

            left_node.adj_col[right_node] = left_col
            right_node.adj_col[left_node] = right_col

            # When the left table is the new one, store the already-placed
            # table first.
            if left_new:
                first, second = right_node, left_node
            else:
                first, second = left_node, right_node
            paths.append((first, second, HYPER if hyper_ok else SHUFFLE))

        self.paths = paths
        self.print_mst(paths)

    @staticmethod
    def _hyper_allowed(node, paths):
        """Return True if already-placed ``node`` still permits a hyper join.

        ``paths`` must be non-empty. ``Node`` defines no ``__eq__``, so the
        identity tests below match the ``==`` comparisons used previously.
        """
        for left, right, tag in paths:
            if left is node:
                # Not necessarily hyper, but the old node is at the front
                # position of an earlier join.
                return False
            if right is node and tag == HYPER:
                # If a hyper join appeared before, the new join must be shuffle.
                return False
        last_left, last_right, _ = paths[-1]
        return last_left is node or last_right is node

    def print_mst(self, mst):
        """Print ``mst`` (a list of path tuples) as one formatted line."""
        print(_format_join_order(mst))


class JoinGraph:
    """Join graph whose nodes are tables and whose edges carry weight records.

    NOTE(review): this file reached review with part of its text missing.
    ``__init__`` and ``generate_MST`` call ``self.cpt_edge_weight`` and
    ``self.return_degree``, but neither definition survives in this file;
    restore them from version control before using this class.
    """

    def __init__(self, scan_block_dict, hyper_nodes):
        """Create one ``Node`` per table and connect them.

        Args:
            scan_block_dict: mapping with
                ``'card'``: ``{"table.column": result_size}`` and
                ``'relation'``: iterable of items where ``item[0]`` and
                ``item[1]`` are ``"table.column"`` strings and ``item[2]`` is
                forwarded untouched to ``cpt_edge_weight``.
            hyper_nodes: forwarded untouched to ``cpt_edge_weight``.

        Raises:
            ValueError: if a relation names a table that has no entry in
                ``'card'``, or a ``"table.column"`` string does not contain
                exactly one dot.
        """
        nodes = {}  # table name -> Node, in first-seen order
        for table_col, result_size in scan_block_dict['card'].items():
            tab_name, tab_column = table_col.split('.')
            if tab_name not in nodes:
                nodes[tab_name] = Node(tab_name)
            nodes[tab_name].add_column(tab_column, result_size)

        def lookup(tab_name):
            try:
                return nodes[tab_name]
            except KeyError:
                raise ValueError(
                    f"table {tab_name!r} is used in 'relation' but has no "
                    "entry in 'card'"
                ) from None

        for item in scan_block_dict['relation']:
            left_tab_name, left_col_name = item[0].split('.')
            right_tab_name, right_col_name = item[1].split('.')
            left_node = lookup(left_tab_name)
            right_node = lookup(right_tab_name)

            # The weight is computed once per direction because the arguments
            # are passed in swapped order.
            left_node.adj[right_node] = self.cpt_edge_weight(
                left_node, left_col_name, right_node, right_col_name,
                item[2], hyper_nodes,
            )
            left_node.adj_col[right_node] = left_col_name
            right_node.adj[left_node] = self.cpt_edge_weight(
                right_node, right_col_name, left_node, left_col_name,
                item[2], hyper_nodes,
            )
            right_node.adj_col[left_node] = right_col_name

        self.GraphJ = list(nodes.values())

    def generate_MST(self):
        """Pick candidate start edges, then search for the cheapest plan.

        Only the start-edge selection could be recovered. Candidates are every
        edge leaving a node for which ``return_degree(node) == 1`` plus the
        edge with the largest ``csy_w - chy_w`` difference.

        Raises:
            NotImplementedError: always, at the point where the recovered
                source text breaks off.
        """
        start_edges = []
        max_edge, max_hyper_weight = (), -1
        for node in self.GraphJ:
            if self.return_degree(node) == 1:
                start_edges.append((node, list(node.adj.keys())[0]))
            for joined_node, wrap in node.adj.items():
                gain = wrap['csy_w'] - wrap['chy_w']
                if gain > max_hyper_weight:
                    max_hyper_weight = gain
                    max_edge = (node, joined_node)
        print(max_hyper_weight)
        if max_edge not in start_edges:
            start_edges.append(max_edge)

        min_MST, min_MST_weight = None, float('inf')
        for start_edge in start_edges:
            paths = [(start_edge[0], start_edge[1], HYPER)]
            total_weight = start_edge[0].adj[start_edge[1]]['chy_w']
            used_nodes = [start_edge[0], start_edge[1]]
            candidated_explore_nodes = [(start_edge[0], 1), (start_edge[1], 1)]
            # NOTE(review): the source breaks off inside the header of the
            # loop that followed ("while len(paths)+1 ...").  The loop body,
            # the rest of this method, ``cpt_edge_weight``, ``return_degree``
            # and the head of a final printing method are missing.  Fail
            # loudly instead of guessing.  The surviving tail of the file is
            # kept verbatim below for whoever restores it:
            #
            #   {join_type}=>{tup[1].table}.{tup[1].adj_col[tup[0]]}] "
            #   if tup[2]==-1: join_order_str+=f"({tup[0].table}.{tup[0].adj_col[tup[1]]}=>{join_type}=>{tup[1].table}.{tup[1].adj_col[tup[0]]}) "
            #   join_order_str+='\n'
            #   print(join_order_str)
            raise NotImplementedError(
                "JoinGraph.generate_MST: the remainder of this method was "
                "lost from the source file; restore it from version control"
            )