import copy
class Node:
    """One table in a join graph.

    Instances are used as dictionary keys and compared with ``==`` elsewhere
    in this file, so the class deliberately keeps Python's default
    identity-based equality and hashing.
    """

    def __init__(self, tab_name):
        """Create a node for the table called *tab_name* with empty maps."""
        self.table = tab_name
        # column name -> result size, filled through add_column().
        self.column_w: dict = {}
        # neighbouring Node -> edge weight record (populated by JoinGraph).
        self.adj: dict = {}
        # neighbouring Node -> name of THIS table's column used in that join.
        self.adj_col: dict = {}

    def add_column(self, col_name, result_size):
        """Record (or overwrite) the result size associated with *col_name*."""
        self.column_w[col_name] = result_size
class TraditionalJoinOrder:
    """Tag a caller-supplied sequence of joins as hyper or shuffle joins.

    The relations are consumed in the given order; a relation whose two
    tables are both still unknown is postponed until one of them has been
    joined.  The result is stored in ``self.paths`` as a list of
    ``(front_node, back_node, kind)`` tuples where ``kind`` is ``1`` (hyper)
    or ``-1`` (shuffle), and is printed once construction finishes.
    """

    _HYPER = 1
    _SHUFFLE = -1

    def __init__(self, join_raltions, metadata):
        """Build ``self.paths`` from *join_raltions* and print it.

        Args:
            join_raltions: iterable of two-key mappings
                ``{left_table: col_key, right_table: col_key}``; the key
                order decides which table is "left".  The input is not
                modified.
            metadata: mapping ``table -> {'used_cols': ...}``; the join column
                name is ``metadata[table]['used_cols'][col_key]``.

        Raises:
            ValueError: if some relations can never be attached to the tables
                joined so far (the relations are disconnected).  Previously
                this case looped forever.
        """
        from collections import deque

        paths = []
        nodes_by_table = {}
        pending = deque(join_raltions)
        # Number of relations postponed in a row without any being placed.
        postponed_in_a_row = 0

        while pending:
            join_op = pending.popleft()
            left_table, right_table = list(join_op.keys())
            left_col = metadata[left_table]['used_cols'][join_op[left_table]]
            right_col = metadata[right_table]['used_cols'][join_op[right_table]]

            if not paths:
                # The very first relation seeds the path and is always hyper.
                left_node, right_node = Node(left_table), Node(right_table)
                nodes_by_table[left_table] = left_node
                nodes_by_table[right_table] = right_node
                left_node.adj_col[right_node] = left_col
                right_node.adj_col[left_node] = right_col
                paths.append((left_node, right_node, self._HYPER))
                continue

            left_known = left_table in nodes_by_table
            right_known = right_table in nodes_by_table
            if not left_known and not right_known:
                # Cannot be attached yet: retry after the other relations.
                pending.append(join_op)
                postponed_in_a_row += 1
                if postponed_in_a_row >= len(pending):
                    # Every remaining relation was postponed once and nothing
                    # was placed in between, so none of them ever will be.
                    raise ValueError(
                        "join relations are not connected to the tables "
                        f"joined so far: {list(pending)!r}"
                    )
                continue
            postponed_in_a_row = 0

            hyper_ok = True
            if left_known:
                left_node = nodes_by_table[left_table]
                hyper_ok = self._allows_hyper(left_node, paths)
            else:
                left_node = nodes_by_table[left_table] = Node(left_table)
            if right_known:
                right_node = nodes_by_table[right_table]
                hyper_ok = self._allows_hyper(right_node, paths) and hyper_ok
            else:
                right_node = nodes_by_table[right_table] = Node(right_table)

            left_node.adj_col[right_node] = left_col
            right_node.adj_col[left_node] = right_col

            # When the left table is the new one, put the already-joined
            # table in front so the new table is always the back of the join.
            if left_known:
                front, back = left_node, right_node
            else:
                front, back = right_node, left_node
            kind = self._HYPER if hyper_ok else self._SHUFFLE
            paths.append((front, back, kind))

        self.paths = paths
        self.print_mst(paths)

    @staticmethod
    def _allows_hyper(node, paths):
        """Return True if already-joined *node* may take part in a hyper join.

        That is refused when the node was the front of any earlier join, when
        it already took part in a hyper join, or when it is not part of the
        most recent join.
        """
        for front, back, kind in paths:
            if front is node or (back is node and kind == 1):
                return False
        last_front, last_back = paths[-1][0], paths[-1][1]
        return last_front is node or last_back is node

    def print_mst(self, mst):
        """Print *mst* on one line: hyper joins in ``[...]``, shuffle in ``(...)``.

        Args:
            mst: iterable of ``(front_node, back_node, kind)`` tuples; entries
                whose kind is neither ``1`` nor ``-1`` are skipped.
        """
        pieces = []
        for tup in mst:
            front, back, kind = tup[0], tup[1], tup[2]
            if kind == 1:
                pieces.append(
                    f"[{front.table}.{front.adj_col[back]}=>hyper=>"
                    f"{back.table}.{back.adj_col[front]}] "
                )
            elif kind == -1:
                pieces.append(
                    f"({front.table}.{front.adj_col[back]}=>shuffle=>"
                    f"{back.table}.{back.adj_col[front]}) "
                )
        print(''.join(pieces) + '\n')
class JoinGraph:
    """Weighted join graph plus a greedy search for a cheap join tree.

    ``self.GraphJ`` is a list of ``Node`` objects, one per table, in the order
    the tables first appear in ``scan_block_dict['card']``.  Every relation
    adds a directed weight record in both directions (``node.adj``) and the
    join column of each side (``node.adj_col``).
    """

    # Multipliers used by cpt_edge_weight(); presumably tuned cost-model
    # constants for the hyper and shuffle strategies -- confirm with the owner.
    _HYPER_FACTOR = 1.2
    _SHUFFLE_FACTOR = 3

    def __init__(self, scan_block_dict, hyper_nodes):
        """Build the graph.

        Args:
            scan_block_dict: mapping with two entries:
                ``'card'`` maps ``"table.column"`` to a result size, and
                ``'relation'`` is an iterable of
                ``("table.column", "table.column", join_type)`` items.
            hyper_nodes: mapping ``table -> container of column names``,
                consulted only for relations whose join type is ``'Hyper'``.

        Raises:
            ValueError: if a relation names a table absent from ``'card'``,
                or a ``"table.column"`` string does not contain exactly one
                dot.
        """
        nodes_by_table = {}
        for table_col, result_size in scan_block_dict['card'].items():
            tab_name, tab_column = table_col.split('.')
            node = nodes_by_table.get(tab_name)
            if node is None:
                node = nodes_by_table[tab_name] = Node(tab_name)
            node.add_column(tab_column, result_size)

        for item in scan_block_dict['relation']:
            left_tab_name, left_col_name = item[0].split('.')
            right_tab_name, right_col_name = item[1].split('.')
            try:
                left_node = nodes_by_table[left_tab_name]
                right_node = nodes_by_table[right_tab_name]
            except KeyError as err:
                # Keep the exception type callers saw from list.index().
                raise ValueError(
                    f"{err.args[0]!r} is not a table listed in 'card'"
                ) from err
            join_type = item[2]
            left_node.adj[right_node] = self.cpt_edge_weight(
                left_node, left_col_name, right_node, right_col_name,
                join_type, hyper_nodes)
            left_node.adj_col[right_node] = left_col_name
            right_node.adj[left_node] = self.cpt_edge_weight(
                right_node, right_col_name, left_node, left_col_name,
                join_type, hyper_nodes)
            right_node.adj_col[left_node] = right_col_name

        # dict preserves insertion order, so this matches first-seen order.
        self.GraphJ = list(nodes_by_table.values())

    def generate_MST(self):
        """Greedily grow a join tree from several start edges; keep the cheapest.

        Start edges are the single edge of every degree-1 node plus the edge
        with the largest ``csy_w - chy_w`` difference.  The largest difference
        found and the chosen tree are printed.

        Returns:
            list of ``(front_node, back_node, kind)`` tuples, ``kind`` being
            ``1`` (hyper) or ``-1`` (shuffle).

        Raises:
            ValueError: if the graph has no edge to start from or is not
                connected (previously an IndexError / TypeError).
        """
        start_edges = []
        max_edge, max_hyper_weight = (), -1
        for node in self.GraphJ:
            if self.return_degree(node) == 1:
                start_edges.append((node, next(iter(node.adj))))
            for joined_node, wrap in node.adj.items():
                saving = wrap['csy_w'] - wrap['chy_w']
                if saving > max_hyper_weight:
                    max_hyper_weight = saving
                    max_edge = (node, joined_node)
        print(max_hyper_weight)
        if max_edge and max_edge not in start_edges:
            start_edges.append(max_edge)
        if not start_edges:
            raise ValueError(
                "cannot order joins: the join graph has no usable start edge")

        min_MST, min_MST_weight = None, float('inf')
        for first, second in start_edges:
            paths, total_weight = self._grow_tree(first, second)
            if total_weight < min_MST_weight:
                min_MST_weight = total_weight
                min_MST = paths
        print("The minimum spanning tree is:")
        self.print_mst(min_MST)
        return min_MST

    def _grow_tree(self, first, second):
        """Grow one tree from the hyper edge ``first -> second``.

        Returns ``(paths, total_weight)``.
        """
        paths = [(first, second, 1)]
        total_weight = first.adj[second]['chy_w']
        used_nodes = {first, second}
        # (node, flag) pairs; flag is the kind of the join that last involved
        # the node and decides which strategy it may use next.
        frontier = [(first, 1), (second, 1)]

        while len(paths) + 1 < len(self.GraphJ):
            min_edge, min_edge_weight = None, float('inf')
            last_joined = paths[-1][1]
            for cur_node, cur_flag in frontier:
                for joined_node, wrap in cur_node.adj.items():
                    if joined_node in used_nodes:
                        continue
                    if cur_flag == -1 and last_joined is cur_node:
                        # Hyper is allowed only when the front of the new
                        # join is the back of the previous (shuffle) join.
                        cur_weight, new_flag = wrap['chy_w'], 1
                    else:
                        cur_weight, new_flag = wrap['csy_w'], -1
                    if cur_weight < min_edge_weight:
                        min_edge_weight = cur_weight
                        min_edge = (cur_node, joined_node, new_flag)

            if min_edge is None:
                raise ValueError(
                    "cannot order joins: the join graph is not connected")

            front, back, kind = min_edge
            paths.append(min_edge)
            used_nodes.add(back)
            if kind == 1:
                # The front node has now taken part in a hyper join.
                frontier.remove((front, -1))
                frontier.append((front, 1))
            frontier.append((back, kind))
            total_weight += min_edge_weight
        return paths, total_weight

    def return_degree(self, node):
        """Return the number of neighbours recorded in ``node.adj``."""
        return len(node.adj)

    def cpt_edge_weight(self, n_x, col_x, n_y, col_y, join_type, hyper_nodes):
        """Return the weight record for the directed edge ``n_x -> n_y``.

        With ``x`` and ``y`` the column sizes of the two sides, ``csy_w`` is
        always ``x + 3*y``.  ``chy_w`` is ``1.2 * (x + y)`` when *join_type*
        is ``'Hyper'`` and both columns are listed in *hyper_nodes* for their
        tables; otherwise it equals ``csy_w``.

        Returns:
            dict with keys ``'chy_w'`` and ``'csy_w'``.
        """
        size_x = n_x.column_w[col_x]
        size_y = n_y.column_w[col_y]
        csy = size_x + size_y * self._SHUFFLE_FACTOR
        if (join_type == 'Hyper'
                and col_x in hyper_nodes[n_x.table]
                and col_y in hyper_nodes[n_y.table]):
            chy = (size_x + size_y) * self._HYPER_FACTOR
        else:
            chy = csy
        return {'chy_w': chy, 'csy_w': csy}

    def print_graph(self, paths):
        """Print the strings in *paths*, each prefixed with ``->``."""
        print(''.join('->' + table_col for table_col in paths))

    def print_mst(self, mst):
        """Print *mst* on one line: hyper joins in ``[...]``, shuffle in ``(...)``.

        Args:
            mst: iterable of ``(front_node, back_node, kind)`` tuples; entries
                whose kind is neither ``1`` nor ``-1`` are skipped.
        """
        pieces = []
        for tup in mst:
            front, back, kind = tup[0], tup[1], tup[2]
            if kind == 1:
                pieces.append(
                    f"[{front.table}.{front.adj_col[back]}=>hyper=>"
                    f"{back.table}.{back.adj_col[front]}] "
                )
            elif kind == -1:
                pieces.append(
                    f"({front.table}.{front.adj_col[back]}=>shuffle=>"
                    f"{back.table}.{back.adj_col[front]}) "
                )
        print(''.join(pieces) + '\n')