from partition_algorithm import PartitionAlgorithm import pickle import itertools import time from join_eval import JoinEvaluator import os from join_order import JoinGraph,TraditionalJoinOrder from conf.context import Context import pandas as pd import concurrent.futures from db.conf import table_suffix import argparse from rich.console import Console from rich.logging import RichHandler import logging base_dir=os.path.dirname(os.path.abspath(__file__)) parser = argparse.ArgumentParser() parser.add_argument("--command",type=int, default=0, help="choose the command to run") parser.add_argument("--benchmark",type=str, default='tpcds', help="choose the evaluated benchmark") parser.add_argument("--mode", type=str, default='debug', help="choose the mode to run") parser.add_argument("--init", type=bool, default=False, help="initialize base layouts and join queries") args = parser.parse_args() settings=Context() settings.benchmark=args.benchmark #'tpch' imdb # settings.block_size=200 # 5000, 20000,50000 class Debugger: def __init__(self): self.join_cost_distributions = {} self.query_cost_distributions = {} class TwoLayerTree: def __init__(self,trees): self.candidate_trees=trees class QdTree: def __init__(self,trees): self.candidate_trees=trees # 超图基类 class BaseHypergraph: def __init__(self, benchmark, candidate_nodes): self.benchmark = benchmark self.candidate_nodes = candidate_nodes self.hyper_nodes = {table: [] for table in candidate_nodes} def clear_hyper_nodes(self): self.hyper_nodes = {table: [] for table in self.candidate_nodes} def show_selections(self): for table, columns in self.hyper_nodes.items(): print(f"Hypernode: {table}, Selected Columns: {columns}") # TPC-H 超图类 class HypergraphTPC(BaseHypergraph): def __init__(self): candidate_nodes = { "nation": ["n_nationkey", "n_regionkey"], "region": ["r_regionkey"], "customer": ["c_custkey", "c_nationkey"], "orders": ["o_orderkey", "o_custkey"], "lineitem": ["l_partkey", "l_orderkey", "l_suppkey"], "part": ["p_partkey"], "supplier": ["s_suppkey", "s_nationkey"], "partsupp": ["ps_suppkey", "ps_partkey"] } super().__init__('tpch', candidate_nodes) # JOB 超图类 class JobHypergraph(BaseHypergraph): def __init__(self): candidate_nodes = { "aka_name": ["id", "person_id"], 'aka_title': ["id", "movie_id", "kind_id"], "cast_info": ["id", "person_id", "movie_id", "person_role_id", "role_id"], "char_name": ["id"], "comp_cast_type": ["id"], "company_name": ["id"], "company_type": ["id"], "complete_cast": ["id", "movie_id", "subject_id", "status_id"], "info_type": ["id"], "keyword": ["id"], "kind_type": ["id"], "link_type": ["id"], "movie_companies": ["id", "movie_id", "company_id", "company_type_id"], "movie_info": ["id", "movie_id", "info_type_id"], "movie_info_idx": ["id", "movie_id", "info_type_id"], "movie_keyword": ["id", "movie_id", "keyword_id"], "movie_link": ["link_type_id", "movie_id", "linked_movie_id"], "name": ["id"], "person_info": ["id", "person_id", "info_type_id"], "role_type": ["id"], "title": ["id", "kind_id"], } super().__init__('imdb', candidate_nodes) # TPC-DS 超图类 class HypergraphTPCDS(BaseHypergraph): def __init__(self): candidate_nodes = { # Tables that have joins in queries "catalog_sales": ["cs_bill_customer_sk", "cs_ship_customer_sk", "cs_sold_date_sk", "cs_item_sk", "cs_bill_cdemo_sk"], "customer": ["c_customer_sk", "c_current_addr_sk", "c_current_cdemo_sk"], "customer_address": ["ca_address_sk"], "customer_demographics": ["cd_demo_sk"], "date_dim": ["d_date_sk"], "household_demographics": ["hd_demo_sk"], "inventory": ["inv_date_sk", "inv_item_sk"], "item": ["i_item_sk"], "promotion": ["p_promo_sk"], "store": ["s_store_sk"], "store_returns": ["sr_customer_sk", "sr_item_sk", "sr_ticket_number", "sr_returned_date_sk"], "store_sales": ["ss_sold_date_sk", "ss_item_sk", "ss_customer_sk", "ss_cdemo_sk", "ss_hdemo_sk", "ss_addr_sk", "ss_store_sk", "ss_promo_sk", "ss_ticket_number"], "web_sales": ["ws_bill_customer_sk", "ws_item_sk", "ws_sold_date_sk"], # Tables that don't have joins in current query set, selecting their most likely join keys "call_center": ["cc_call_center_sk"], # Primary key "catalog_page": ["cp_catalog_page_sk"], # Primary key "catalog_returns": ["cr_item_sk", "cr_order_number", "cr_returned_date_sk"], # Composite primary key and common foreign keys "income_band": ["ib_income_band_sk"], # Primary key "reason": ["r_reason_sk"], # Primary key "ship_mode": ["sm_ship_mode_sk"], # Primary key "time_dim": ["t_time_sk"], # Primary key "warehouse": ["w_warehouse_sk"], # Primary key "web_page": ["wp_web_page_sk"], # Primary key "web_returns": ["wr_item_sk", "wr_order_number", "wr_returned_date_sk"], # Composite primary key and common foreign keys "web_site": ["web_site_sk"] # Primary key } super().__init__('tpcds', candidate_nodes) def create_join_trees(hypergraph): tree_dict={} partitioner=PartitionAlgorithm(benchmark=hypergraph.benchmark,block_size=settings.block_size) partitioner.load_join_query() for table, columns in hypergraph.candidate_nodes.items(): tree_dict[table]={} partitioner.table_name=table partitioner.load_data() partitioner.load_query() for column in columns: partitioner.LogicalJoinTree(join_col=column) tree_dict[table][column]=partitioner.partition_tree return TwoLayerTree(tree_dict) def create_color_logger(name='color_logger', level=logging.INFO): console = Console() logging.basicConfig( level=level, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(console=console, rich_tracebacks=True)] ) return logging.getLogger(name) def create_qdtrees(hypergraph): tree_dict={} pa=PartitionAlgorithm(benchmark=hypergraph.benchmark,block_size=settings.block_size) pa.load_join_query(join_indeuced='PAW') for table, _ in hypergraph.candidate_nodes.items(): pa.table_name=table pa.load_data() pa.load_query(join_indeuced='PAW') pa.InitializeWithQDT() tree_dict[table]=pa.partition_tree return QdTree(tree_dict) def load_join_queries(hypergraph,is_join_indeuced): pa=PartitionAlgorithm(benchmark=hypergraph.benchmark,block_size=settings.block_size) pa.load_join_query(join_indeuced=is_join_indeuced) return pa.join_queries dubug_logger=Debugger() logger = create_color_logger() def init_partitioner(): if settings.benchmark=='tpch': hypergraph=HypergraphTPC() elif settings.benchmark=='imdb': hypergraph=JobHypergraph() elif settings.benchmark=='tpcds': hypergraph=HypergraphTPCDS() # Create join trees for all tables joinTreer = create_join_trees(hypergraph) join_trees_dir = f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/base_trees' os.makedirs(join_trees_dir, exist_ok=True) with open(f'{join_trees_dir}/join-trees.pkl', 'wb') as f: pickle.dump(joinTreer, f) qdTreer = create_qdtrees(hypergraph) qd_trees_dir = f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/base_trees' os.makedirs(qd_trees_dir, exist_ok=True) with open(f'{qd_trees_dir}/qd-trees.pkl', 'wb') as f: pickle.dump(qdTreer, f) def init_workload(): if settings.benchmark=='tpch': hypergraph=HypergraphTPC() elif settings.benchmark=='imdb': hypergraph=JobHypergraph() elif settings.benchmark=='tpcds': hypergraph=HypergraphTPCDS() queries_base_dir = f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/used_queries' os.makedirs(queries_base_dir, exist_ok=True) paw_queries = load_join_queries(hypergraph, is_join_indeuced='PAW') paw_queries_path = os.path.join(queries_base_dir, 'paw-queries.pkl') with open(paw_queries_path, 'wb') as f: pickle.dump(paw_queries, f) mto_queries = load_join_queries(hypergraph, is_join_indeuced='MTO') mto_queries_path = os.path.join(queries_base_dir, 'mto-queries.pkl') with open(mto_queries_path, 'wb') as f: pickle.dump(mto_queries, f) pac_queries = load_join_queries(hypergraph, is_join_indeuced='PAC') pac_queries_path = os.path.join(queries_base_dir, 'pac-queries.pkl') with open(pac_queries_path, 'wb') as f: pickle.dump(pac_queries, f) # ~~~~~~~~loading global variables~~~~~~~~ if args.init: init_workload() init_partitioner() exit(0) # if args.mode=='debug': # init_workload() class PartitionContext: def __init__(self): self.joinTreer = None self.qdTreer = None self.paw_queries = None self.mto_queries = None self.pac_queries = None self.metadata = None self.table_metadata = None context=PartitionContext() def load_tree_context(): context.joinTreer=pickle.load(open(f'{base_dir}/../layouts/bs={settings.block_size}/{settings.benchmark}/base_trees/join-trees.pkl','rb')) context.qdTreer=pickle.load(open(f'{base_dir}/../layouts/bs={settings.block_size}/{settings.benchmark}/base_trees/qd-trees.pkl','rb')) context.paw_queries=pickle.load(open(f'{base_dir}/../layouts/bs={settings.block_size}/{settings.benchmark}/used_queries/paw-queries.pkl','rb')) context.mto_queries=pickle.load(open(f'{base_dir}/../layouts/bs={settings.block_size}/{settings.benchmark}/used_queries/mto-queries.pkl','rb')) context.pac_queries=pickle.load(open(f'{base_dir}/../layouts/bs={settings.block_size}/{settings.benchmark}/used_queries/pac-queries.pkl','rb')) # ~~~~~~~~Load partial table metadata information~~~~~~~~ context.metadata={} for table,tree in context.joinTreer.candidate_trees.items(): sample_tree=list(tree.values())[0] context.metadata[table]={'row_count':sample_tree.pt_root.node_size,'used_cols':sample_tree.used_columns} context.table_metadata=pickle.load(open(f'{base_dir}/../dataset/{settings.benchmark}/metadata.pkl','rb')) #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def estimate_join_cost(hypergraph,current_table): shuffle_times,hyper_times={},{} join_cost=0 for q_id, join_query in enumerate(context.mto_queries): for join_op in join_query['join_relations']: left_table,left_col_idx=list(join_op.items())[0] right_table,right_col_idx=list(join_op.items())[1] left_col,right_col=context.metadata[left_table]["used_cols"][left_col_idx],context.metadata[right_table]["used_cols"][right_col_idx] if left_table!=current_table or right_table!=current_table: continue left_query,right_query=join_query['vectors'][left_table],join_query['vectors'][right_table] is_hyper=True if left_col in hypergraph.hyper_nodes[left_table]: left_tree=context.joinTreer.candidate_trees[left_table][left_col] else: left_tree=list(context.joinTreer.candidate_trees[left_table].values())[0] is_hyper=False if right_col in hypergraph.hyper_nodes[right_table]: right_tree=context.joinTreer.candidate_trees[right_table][right_col] else: right_tree=list(context.joinTreer.candidate_trees[right_table].values())[0] is_hyper=False left_cost=sum([left_tree.nid_node_dict[bid].node_size for bid in left_tree.query_single(left_query)]) right_cost=sum([right_tree.nid_node_dict[bid].node_size for bid in right_tree.query_single(right_query)]) join_cost+=left_cost+right_cost if is_hyper else (left_cost+3*right_cost) return join_cost def _should_skip_query(q_id): # skip query 7 in imdb and query 6 in tpch (invalid ids) return (settings.benchmark == 'imdb' and q_id == 7) or \ (settings.benchmark == 'tpch' and q_id == 6) or \ (settings.benchmark == 'tpcds' and q_id == -1) def processing_join_workload(hypergraph, group_type, join_strategy='traditional', tree_type='paw'): """ Process join workload and calculate the positions and total cost of hyper edges and shuffle edges in the current hypergraph Args: hypergraph: Hypergraph structure group_type: Group type (0 or 1) join_strategy: Join strategy ('traditional' or 'prim') tree_type: Tree type ('paw', 'mto' or 'pac') Returns: final_join_result: Table-level join result statistics query_cost_dict: Query-level cost statistics query_ratio_dict: Query-level ratio statistics opt_time: Optimization time statistics """ # Initialize statistics trackers shuffle_times, hyper_times = {}, {} join_cost_dict = {} cost_log, byte_log = {}, {} solved_tables = {} shuffle_group_time, hyper_group_time = [], [] tot_order_time = 0 # Select appropriate query set input_queries = context.paw_queries if tree_type == 'paw' else \ context.mto_queries if tree_type == 'mto' else context.pac_queries # ========== Process each query ========== for q_id, join_query in enumerate(input_queries): print('processing query:', q_id) if _should_skip_query(q_id) or join_query['vectors'] == {}: if join_query['vectors'] == {}: print(f'empty query : {settings.benchmark} {q_id}') continue query_cost_log, query_bytes_log = [], [] solved_tables[q_id] = set() # ---------- Process queries without join relations ---------- if not join_query['join_relations']: print('no join operation') for table, query_vector in join_query['vectors'].items(): solved_tables[q_id].add(table) # Select appropriate tree if tree_type == 'paw': tree = context.qdTreer.candidate_trees[table] else: tree = list(context.joinTreer.candidate_trees[table].values())[0] # Calculate scan cost b_ids = tree.query_single(query_vector) scan_cost = sum([tree.nid_node_dict[b_id].node_size for b_id in b_ids]) join_cost_dict[table] = join_cost_dict.get(table, 0) + scan_cost query_cost_log.append(scan_cost) query_bytes_log.append(context.table_metadata[table]['read_line'] * scan_cost) logger.info("Just scan table: %s, cost: %s", table, scan_cost) cost_log[q_id] = query_cost_log byte_log[q_id] = query_bytes_log continue # ---------- Process queries with join relations ---------- # 1. Determine join order start1_time = time.time() join_path = [] if join_strategy == 'traditional': join_path = TraditionalJoinOrder(join_query['join_relations'], context.metadata).paths else: # Build scan block dictionary for Prim algorithm scan_block_dict = {'card': {}, 'relation': []} for join_op in join_query['join_relations']: left_table, left_col_idx = list(join_op.items())[0] right_table, right_col_idx = list(join_op.items())[1] left_col = context.metadata[left_table]["used_cols"][left_col_idx] right_col = context.metadata[right_table]["used_cols"][right_col_idx] left_query = join_query['vectors'][left_table] right_query = join_query['vectors'][right_table] is_hyper = True # Process PAW tree type if tree_type == 'paw': if f'{left_table}.{left_col}' not in scan_block_dict["card"]: scan_block_dict["card"][f'{left_table}.{left_col}'] = len(context.qdTreer.candidate_trees[left_table].query_single(left_query)) if f'{right_table}.{right_col}' not in scan_block_dict["card"]: scan_block_dict["card"][f'{right_table}.{right_col}'] = len(context.qdTreer.candidate_trees[right_table].query_single(right_query)) is_hyper = False # Process MTO or PAC tree type else: # Process left table tree if left_col in hypergraph.hyper_nodes[left_table]: tree = context.joinTreer.candidate_trees[left_table][left_col] else: tree = list(context.joinTreer.candidate_trees[left_table].values())[0] is_hyper = False if f'{left_table}.{left_col}' not in scan_block_dict["card"]: scan_block_dict["card"][f'{left_table}.{left_col}'] = len(tree.query_single(left_query)) # Process right table tree if right_col in hypergraph.hyper_nodes[right_table]: tree = context.joinTreer.candidate_trees[right_table][right_col] else: tree = list(context.joinTreer.candidate_trees[right_table].values())[0] is_hyper = False if f'{right_table}.{right_col}' not in scan_block_dict["card"]: scan_block_dict["card"][f'{right_table}.{right_col}'] = len(tree.query_single(right_query)) scan_block_dict["relation"].append([ f'{left_table}.{left_col}', f'{right_table}.{right_col}', 'Hyper' if is_hyper else 'Shuffle' ]) # Generate MST using JoinGraph jg = JoinGraph(scan_block_dict,hypergraph.hyper_nodes) join_path = jg.generate_MST() tot_order_time += time.time() - start1_time # Calculate join reorder time # 2. Process join path temp_table_dfs = [] temp_joined_bytes = [] temp_shuffle_ops = [] # Iterate through each item in join path for order_id, item in enumerate(join_path): join_queryset, joined_cols, joined_trees, joined_tables = [], [], [], [] is_hyper = True if item[2] == 1 else False is_real_hyper = True join_ops = [ (item[0].table, item[0].adj_col[item[1]]), (item[1].table, item[1].adj_col[item[0]]) ] # Collect join operation information for join_table, join_col in join_ops: join_queryset.append(join_query['vectors'][join_table]) join_col_idx = context.metadata[join_table]["used_cols"].index(join_col) joined_cols.append(join_col_idx) # Select appropriate tree if tree_type == 'paw': tree = context.qdTreer.candidate_trees[join_table] is_real_hyper = False else: if join_col in hypergraph.hyper_nodes[join_table]: tree = context.joinTreer.candidate_trees[join_table][join_col] else: tree = list(context.joinTreer.candidate_trees[join_table].values())[0] is_real_hyper = False joined_trees.append(tree) joined_tables.append(join_table) # Process shuffle join if not is_hyper: cur_idx = 0 if order_id == 0 else 1 # Current is last path or next is not hyper if order_id == len(join_path) - 1 or join_path[order_id+1][2] == -1: tree = joined_trees[cur_idx] cur_table = joined_tables[cur_idx] solved_tables[q_id].add(cur_table) # Get table data b_ids = tree.query_single(join_queryset[cur_idx]) table_dataset = [] for b_id in b_ids: table_dataset += list(tree.nid_node_dict[b_id].dataset) # Set column names if hypergraph.benchmark == 'imdb': used_columns = [table_suffix[hypergraph.benchmark][cur_table]+"_"+col for col in tree.used_columns] else: used_columns = tree.used_columns temp_table_dfs.append(pd.DataFrame(table_dataset, columns=used_columns)) temp_shuffle_ops.append(( item[0].table, item[0].adj_col[item[1]], item[1].table, item[1].adj_col[item[0]] )) temp_joined_bytes.append(context.table_metadata[item[1].table]['read_line']) continue # Process hyper join # Ensure smaller table is on the left (optimization) if len(joined_trees[0].query_single(join_queryset[0])) > len(joined_trees[1].query_single(join_queryset[1])): join_queryset[0], join_queryset[1] = join_queryset[1], join_queryset[0] joined_cols[0], joined_cols[1] = joined_cols[1], joined_cols[0] joined_trees[0], joined_trees[1] = joined_trees[1], joined_trees[0] joined_tables[0], joined_tables[1] = joined_tables[1], joined_tables[0] left_table, right_table = joined_tables[0], joined_tables[1] # Update statistics if is_real_hyper: hyper_times[left_table] = hyper_times.get(left_table, 0) + 1 hyper_times[right_table] = hyper_times.get(right_table, 0) + 1 else: shuffle_times[left_table] = shuffle_times.get(left_table, 0) + 1 shuffle_times[right_table] = shuffle_times.get(right_table, 0) + 1 solved_tables[q_id].add(left_table) solved_tables[q_id].add(right_table) # Execute join evaluation join_eval = JoinEvaluator( join_queryset, joined_cols, joined_trees, joined_tables, settings.block_size, context.table_metadata, benchmark=hypergraph.benchmark ) hyper_shuffle_cost, hyper_shuffle_read_bytes, temp_joined_df, group_time = join_eval.compute_total_shuffle_hyper_cost(group_type, is_real_hyper) # Update statistics join_cost_dict[left_table] = join_cost_dict.get(left_table, 0) + hyper_shuffle_cost // 2 join_cost_dict[right_table] = join_cost_dict.get(right_table, 0) + hyper_shuffle_cost // 2 if is_real_hyper: hyper_group_time.append(group_time) else: shuffle_group_time.append(group_time) query_cost_log.append(hyper_shuffle_cost) query_bytes_log.append(hyper_shuffle_read_bytes) logger.info("Join %s and %s, %s hyper cost: %s", left_table, right_table, '' if is_real_hyper else 'fake',hyper_shuffle_cost) temp_table_dfs.append(temp_joined_df) temp_joined_bytes.append(context.table_metadata[left_table]['read_line'] +context.table_metadata[right_table]['read_line']) # 3. Process multi-table join if len(temp_table_dfs) > 1: query_cost_log.append(-1) # Mark start of multi-table join shuffle_cost = 0 shuffle_bytes = 0 last_temp_dfs = temp_table_dfs[0] last_temp_line = temp_joined_bytes[0] # Define hash join function def pandas_hash_join(df_A, join_col_A, df_B, join_col_B): if join_col_B in df_A.columns: return df_A if df_A.shape[0] > df_B.shape[0]: df_A, df_B = df_B, df_A join_col_A, join_col_B = join_col_B, join_col_A df_B = df_B.drop_duplicates(subset=[join_col_B], keep='first') merged_df = df_A.merge(df_B, how='inner', left_on=join_col_A, right_on=join_col_B) return merged_df # Execute multi-table join for i in range(1, len(temp_table_dfs)): join_table1, join_col1, join_table2, join_col2 = temp_shuffle_ops[i-1] sample_rate1,sample_rate2 = 1, 1 if join_col1 in context.joinTreer.candidate_trees[join_table1] and hasattr(context.joinTreer.candidate_trees[join_table1][join_col1], 'sample_rate'): sample_rate1=context.joinTreer.candidate_trees[join_table1][join_col1].sample_rate if join_col2 in context.joinTreer.candidate_trees[join_table2] and hasattr(context.joinTreer.candidate_trees[join_table2][join_col2], 'sample_rate'): sample_rate2=context.joinTreer.candidate_trees[join_table2][join_col2].sample_rate data_sample_rate=min(sample_rate1,sample_rate2) # Calculate cost shuffle_cost += int((temp_table_dfs[i].shape[0] * 3 + last_temp_dfs.shape[0]* 1)/data_sample_rate) shuffle_bytes += int((temp_table_dfs[i].shape[0] * 3 * last_temp_line + last_temp_dfs.shape[0]* 1 * temp_joined_bytes[i])/data_sample_rate) if hypergraph.benchmark == 'imdb': join_col1 = table_suffix[hypergraph.benchmark][join_table1] + "_" + join_col1 join_col2 = table_suffix[hypergraph.benchmark][join_table2] + "_" + join_col2 # Execute join last_temp_dfs = pandas_hash_join(last_temp_dfs, join_col1, temp_table_dfs[i], join_col2) last_temp_line = temp_joined_bytes[i] + last_temp_line # Update statistics shuffle_times[join_table1] = shuffle_times.get(join_table1, 0) + 1 shuffle_times[join_table2] = shuffle_times.get(join_table2, 0) + 1 join_cost_dict[join_table1] = join_cost_dict.get(join_table1, 0) + shuffle_cost // 2 join_cost_dict[join_table2] = join_cost_dict.get(join_table2, 0) + shuffle_cost // 2 logger.info("Join %s and %s, shuffle cost: %s", join_table1, join_table2, shuffle_cost) query_cost_log.append(shuffle_cost) query_bytes_log.append(shuffle_bytes) cost_log[q_id] = query_cost_log byte_log[q_id] = query_bytes_log # ========== Summarize results ========== # 1. Calculate query costs and ratios query_cost_dict = {} query_ratio_dict = {} for qid in cost_log.keys(): query_cost_dict[qid] = sum(list(cost_log[qid])) query_ratio_dict[qid] = sum(list(byte_log[qid])) / sum( [sum(context.table_metadata[table]['width'].values()) * context.table_metadata[table]['rows'] for table in solved_tables[qid]] ) print(f"Query {qid} cost log: {cost_log[qid]} | byte log: {byte_log[qid]}") print(f"Query ratio log: {query_ratio_dict}") if args.mode=='debug': print(f"Join cost distributions: {join_cost_dict}") if settings.benchmark not in dubug_logger.join_cost_distributions: dubug_logger.join_cost_distributions[settings.benchmark]={} dubug_logger.join_cost_distributions[settings.benchmark][tree_type] = join_cost_dict if settings.benchmark not in dubug_logger.query_cost_distributions: dubug_logger.query_cost_distributions[settings.benchmark]={} dubug_logger.query_cost_distributions[settings.benchmark][tree_type] = {} for qid in cost_log.keys(): # Find index of -1 in cost_log, join elements before it as string, join elements after it as string, separated by space if -1 in cost_log[qid]: idx = cost_log[qid].index(-1) dubug_logger.query_cost_distributions[settings.benchmark][tree_type][qid]=[round(query_ratio_dict[qid],4),' '.join([str(i) for i in cost_log[qid][:idx]]),''.join([str(i) for i in cost_log[qid][idx+1:]])] # 2. Summarize table-level results final_join_result = {} for table in join_cost_dict.keys(): final_join_result[table] = { 'shuffle times': shuffle_times.get(table, 0), 'hyper times': hyper_times.get(table, 0), 'join cost': join_cost_dict[table] } # 3. Calculate averages query_cost_dict['avg'] = sum(list(query_cost_dict.values())) / len(query_cost_dict.keys()) query_ratio_dict['avg'] = sum(list(query_ratio_dict.values())) / len(query_ratio_dict.keys()) table_num=len(final_join_result.keys()) shuffle_avg = sum(list(shuffle_times.values())) / table_num if shuffle_times else 0 hyper_avg = sum(list(hyper_times.values())) / table_num if hyper_times else 0 cost_avg = sum(list(join_cost_dict.values())) / table_num if join_cost_dict else 0 final_join_result['avg'] = { 'shuffle times': shuffle_avg, 'hyper times': hyper_avg, 'join cost': cost_avg } # 4. Calculate optimization time opt_time = [ tot_order_time, sum(hyper_group_time) / len(hyper_group_time) if len(hyper_group_time) > 0 else 0, sum(shuffle_group_time) / len(shuffle_group_time) if len(shuffle_group_time) > 0 else 0 ] return final_join_result, query_cost_dict, query_ratio_dict, opt_time def processing_join_workload_old2(hypergraph,group_type,join_strategy='traditional',tree_type='paw'): shuffle_times,hyper_times={},{} join_cost_dict={} cost_log={} byte_log={} solved_tables={} input_queries=context.paw_queries if tree_type=='paw' else context.mto_queries if tree_type=='mto' else context.pac_queries shuffle_group_time,hyper_group_time=[],[] tot_order_time=0 for q_id,join_query in enumerate(input_queries): print('processing query:',q_id) if _should_skip_query(q_id): continue if join_query['vectors']=={}: print(f'empty query : {settings.benchmark} {q_id}') continue query_cost_log,query_bytes_log=list(),list() solved_tables[q_id]=set() if join_query['join_relations']: start1_time=time.time() if join_strategy=='traditional': join_path=TraditionalJoinOrder(join_query['join_relations'],context.metadata).paths else: scan_block_dict={'card':{},'relation':[]} for join_op in join_query['join_relations']: left_table,left_col_idx=list(join_op.items())[0] right_table,right_col_idx=list(join_op.items())[1] left_col,right_col=context.metadata[left_table]["used_cols"][left_col_idx],context.metadata[right_table]["used_cols"][right_col_idx] left_query=join_query['vectors'][left_table] right_query=join_query['vectors'][right_table] is_hyper=True if tree_type=='paw': if f'{left_table}.{left_col}' not in scan_block_dict["card"]: scan_block_dict["card"][f'{left_table}.{left_col}']=len(context.qdTreer.candidate_trees[left_table].query_single(left_query)) if f'{right_table}.{right_col}' not in scan_block_dict["card"]: scan_block_dict["card"][f'{right_table}.{right_col}']=len(context.qdTreer.candidate_trees[right_table].query_single(right_query)) is_hyper=False else: if left_col in hypergraph.hyper_nodes[left_table]: tree=context.joinTreer.candidate_trees[left_table][left_col] else: tree=list(context.joinTreer.candidate_trees[left_table].values())[0] is_hyper=False if f'{left_table}.{left_col}' not in scan_block_dict["card"]: scan_block_dict["card"][f'{left_table}.{left_col}']=len(tree.query_single(left_query)) if right_col in hypergraph.hyper_nodes[right_table]: tree=context.joinTreer.candidate_trees[right_table][right_col] else: tree=list(context.joinTreer.candidate_trees[right_table].values())[0] is_hyper=False if f'{right_table}.{right_col}' not in scan_block_dict["card"]: scan_block_dict["card"][f'{right_table}.{right_col}']=len(tree.query_single(right_query)) scan_block_dict["relation"].append([f'{left_table}.{left_col}',f'{right_table}.{right_col}','Hyper' if is_hyper else 'Shuffle']) jg=JoinGraph(scan_block_dict) join_path=jg.generate_MST() tot_order_time+=time.time()-start1_time # 计算join reorder的时间 temp_table_dfs=[] temp_joined_bytes=[] temp_shuffle_ops=[] for order_id,item in enumerate(join_path): join_queryset,joined_cols,joined_trees,joined_tables=[],[],[],[] is_hyper=True if item[2]==1 else False is_real_hyper=True join_ops=[(item[0].table,item[0].adj_col[item[1]]),(item[1].table,item[1].adj_col[item[0]])] for join_table,join_col in join_ops: join_queryset.append(join_query['vectors'][join_table]) join_col_idx=context.metadata[join_table]["used_cols"].index(join_col) joined_cols.append(join_col_idx) if tree_type=='paw': tree=context.qdTreer.candidate_trees[join_table] is_real_hyper=False else: if join_col in hypergraph.hyper_nodes[join_table]: tree=context.joinTreer.candidate_trees[join_table][join_col] else: tree=list(context.joinTreer.candidate_trees[join_table].values())[0] is_real_hyper=False joined_trees.append(tree) joined_tables.append(join_table) if not is_hyper: tree=None cur_idx=0 if order_id==0 else 1 if order_id==len(join_path)-1 or join_path[order_id+1][2]==-1: tree=joined_trees[cur_idx] cur_table=joined_tables[cur_idx] solved_tables[q_id].add(cur_table) b_ids=tree.query_single(join_queryset[cur_idx]) table_dataset=[] for b_id in b_ids: table_dataset+=list(tree.nid_node_dict[b_id].dataset) if hypergraph.benchmark=='imdb': used_columns=[table_suffix[hypergraph.benchmark][cur_table]+"_"+col for col in tree.used_columns] else: used_columns=tree.used_columns temp_table_dfs.append(pd.DataFrame(table_dataset,columns=used_columns)) temp_shuffle_ops.append((item[0].table,item[0].adj_col[item[1]],item[1].table,item[1].adj_col[item[0]])) temp_joined_bytes.append(context.table_metadata[item[1].table]['read_line']) continue else: if len(joined_trees[0].query_single(join_queryset[0]))>len(joined_trees[1].query_single(join_queryset[1])): join_queryset[0],join_queryset[1]=join_queryset[1],join_queryset[0] joined_cols[0],joined_cols[1]=joined_cols[1],joined_cols[0] joined_trees[0],joined_trees[1]=joined_trees[1],joined_trees[0] joined_tables[0],joined_tables[1]=joined_tables[1],joined_tables[0] left_table,right_table=joined_tables[0],joined_tables[1] if is_real_hyper: hyper_times[left_table]=hyper_times.get(left_table,0)+1 hyper_times[right_table]=hyper_times.get(left_table,0)+1 else: shuffle_times[left_table]=shuffle_times.get(left_table,0)+1 shuffle_times[right_table]=shuffle_times.get(left_table,0)+1 solved_tables[q_id].add(left_table) solved_tables[q_id].add(right_table) join_eval = JoinEvaluator(join_queryset,joined_cols,joined_trees,joined_tables,settings.block_size,context.table_metadata,benchmark=hypergraph.benchmark) hyper_shuffle_cost,hyper_shuffle_read_bytes, temp_joined_df, group_time=join_eval.compute_total_shuffle_hyper_cost(group_type,is_real_hyper) join_cost_dict[left_table]=join_cost_dict.get(left_table,0)+hyper_shuffle_cost//2 join_cost_dict[right_table]=join_cost_dict.get(right_table,0)+hyper_shuffle_cost//2 if is_real_hyper: hyper_group_time.append(group_time) else: shuffle_group_time.append(group_time) query_cost_log.append(hyper_shuffle_cost) query_bytes_log.append(hyper_shuffle_read_bytes) temp_table_dfs.append(temp_joined_df) temp_joined_bytes.append(context.table_metadata[left_table]['read_line']+context.table_metadata[right_table]['read_line']) if len(temp_table_dfs)>1: query_cost_log.append(-1) shuffle_cost=0 shuffle_bytes=0 last_temp_dfs=temp_table_dfs[0] last_temp_line=temp_joined_bytes[0] def pandas_hash_join(df_A, join_col_A, df_B, join_col_B): if join_col_B in df_A.columns: return df_A if df_A.shape[0] > df_B.shape[0]: df_A, df_B = df_B, df_A join_col_A, join_col_B = join_col_B, join_col_A df_B.drop_duplicates(subset=[join_col_B], keep='first') merged_df = df_A.merge(df_B, how='inner', left_on=join_col_A, right_on=join_col_B) return merged_df for i in range(1,len(temp_table_dfs)): shuffle_cost+=(temp_table_dfs[i].shape[0]*3+last_temp_dfs.shape[0]) shuffle_bytes+=(temp_table_dfs[i].shape[0]*3*last_temp_line+last_temp_dfs.shape[0]*temp_joined_bytes[i]) join_table1,join_col1,join_table2,join_col2=temp_shuffle_ops[i-1] if hypergraph.benchmark=='imdb': join_col1=table_suffix[hypergraph.benchmark][join_table1]+"_"+join_col1 join_col2=table_suffix[hypergraph.benchmark][join_table2]+"_"+join_col2 last_temp_dfs=pandas_hash_join(last_temp_dfs,join_col1,temp_table_dfs[i],join_col2) last_temp_line=temp_joined_bytes[i]+last_temp_line shuffle_times[join_table1]=shuffle_times.get(join_table1,0)+1 shuffle_times[join_table2]=shuffle_times.get(join_table2,0)+1 join_cost_dict[join_table1]=join_cost_dict.get(join_table1,0)+shuffle_cost//2 join_cost_dict[join_table2]=join_cost_dict.get(join_table2,0)+shuffle_cost//2 query_cost_log.append(shuffle_cost) query_bytes_log.append(shuffle_bytes) else: print('no join operation') for table, query_vector in join_query['vectors'].items(): solved_tables[q_id].add(table) if tree_type=='paw': tree=context.qdTreer.candidate_trees[table] else: tree=list(context.joinTreer.candidate_trees[table].values())[0] b_ids=tree.query_single(query_vector) scan_cost=sum([tree.nid_node_dict[b_id].node_size for b_id in b_ids]) join_cost_dict[table]=join_cost_dict.get(table,0)+scan_cost query_cost_log.append(scan_cost) query_bytes_log.append(context.table_metadata[table]['read_line']*scan_cost) cost_log[q_id]=query_cost_log byte_log[q_id]=query_bytes_log query_cost_dict={} query_ratio_dict={} for qid in cost_log.keys(): query_cost_dict[qid]=sum(list(cost_log[qid])) query_ratio_dict[qid]=sum(list(byte_log[qid]))/sum([sum(context.table_metadata[table]['width'].values())*context.table_metadata[table]['rows'] for table in solved_tables[qid]]) print(f"Query {qid} cost log: {cost_log[qid]} | ratio log: {byte_log[qid]}") print(f"Query ratio log: {query_ratio_dict}") final_join_result={} for table in join_cost_dict.keys(): final_join_result[table]={'shuffle times':shuffle_times.get(table,0),'hyper times':hyper_times.get(table,0),'join cost':join_cost_dict[table]} query_cost_dict['avg']=sum(list(query_cost_dict.values()))/len(query_cost_dict.keys()) query_ratio_dict['avg']=sum(list(query_ratio_dict.values()))/len(query_ratio_dict.keys()) final_join_result['avg']={'shuffle times':sum(list(shuffle_times.values()))/len(shuffle_times.keys()) if shuffle_times else 0,'hyper times':sum(list(hyper_times.values()))/len(hyper_times.keys()) if hyper_times else 0,'join cost':sum(list(join_cost_dict.values()))/len(join_cost_dict.keys()) if join_cost_dict else 0} opt_time=[tot_order_time,sum(hyper_group_time)/len(hyper_group_time) if len(hyper_group_time)>0 else 0,sum(shuffle_group_time)/len(shuffle_group_time) if len(shuffle_group_time)>0 else 0] return final_join_result,query_cost_dict,query_ratio_dict,opt_time def select_columns_by_PAW(group_type=0, strategy='traditional'): # Build hypergraph structure hypergraph=None if settings.benchmark=='tpch': hypergraph = HypergraphTPC() elif settings.benchmark=='imdb': hypergraph = JobHypergraph() elif settings.benchmark=='tpcds': hypergraph = HypergraphTPCDS() total_opt_time=[] tree_opt_time=0 for table in hypergraph.hyper_nodes: tree_opt_time+=context.qdTreer.candidate_trees[table].build_time total_opt_time.append(tree_opt_time) final_result, query_costs, query_ratio, group_opt_time = processing_join_workload(hypergraph, group_type, strategy, tree_type='paw') # total_opt_time+=group_opt_time print("Final Column Selection Cost of PAW is: ",sum(list(query_costs.values()))/len(query_costs.keys())) print("Average Data Cost Ratio of PAW is: ",sum(list(query_ratio.values()))/len(query_ratio.keys())) print("Total Opt Time of PAW is: ",total_opt_time) hypergraph.show_selections() hypergraph_save_path=f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/join_key_selection/paw-hgraph.pkl' os.makedirs(os.path.dirname(hypergraph_save_path), exist_ok=True) with open(hypergraph_save_path,'wb') as f: pickle.dump(hypergraph,f) return final_result, query_costs,query_ratio, sum(total_opt_time) def select_columns_by_AD_MTO(group_type=0, strategy='traditional'): hypergraph=None if settings.benchmark=='tpch': hypergraph = HypergraphTPC() elif settings.benchmark=='imdb': hypergraph = JobHypergraph() elif settings.benchmark=='tpcds': hypergraph = HypergraphTPCDS() total_opt_time=[] tree_opt_time=0 for table in hypergraph.hyper_nodes: col=context.metadata[table]['used_cols'][0] if col in hypergraph.candidate_nodes[table]: hypergraph.hyper_nodes[table].append(col) tree_opt_time+=context.joinTreer.candidate_trees[table][col].build_time total_opt_time.append(tree_opt_time) final_result, query_costs, query_ratio, group_opt_time = processing_join_workload(hypergraph, group_type, strategy, tree_type='mto') # total_opt_time+=group_opt_time print("Final Column Selection Cost of AD-MTO is: ",sum(list(query_costs.values()))/len(query_costs.keys())) print("Average Data Cost Ratio of AD-MTO is: ",sum(list(query_ratio.values()))/len(query_ratio.keys())) print("Total Opt Time of AD-MTO is: ",total_opt_time) hypergraph.show_selections() hypergraph_save_path=f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/join_key_selection/ad-mto-hgraph.pkl' os.makedirs(os.path.dirname(hypergraph_save_path), exist_ok=True) with open(hypergraph_save_path,'wb') as f: pickle.dump(hypergraph,f) return final_result, query_costs, query_ratio, sum(total_opt_time) def select_columns_by_JT(group_type,strategy='traditional',disabled_joinkey_selection=False): adj_matrix = {} for query in context.pac_queries: for join_op in query["join_relations"]: join_relations=set() for table,col in join_op.items(): join_relations.add(f'{table}.{context.metadata[table]["used_cols"][col]}') for key1, key2 in itertools.combinations(join_relations, 2): if key1 not in adj_matrix: adj_matrix[key1] = {} if key2 not in adj_matrix: adj_matrix[key2] = {} adj_matrix[key1][key2] = adj_matrix[key1].get(key2, 0) + 1 adj_matrix[key2][key1] = adj_matrix[key2].get(key1, 0) + 1 hypergraph=None if settings.benchmark=='tpch': hypergraph = HypergraphTPC() elif settings.benchmark=='imdb': hypergraph = JobHypergraph() elif settings.benchmark=='tpcds': hypergraph = HypergraphTPCDS() total_opt_time=[] tree_opt_time=0 if disabled_joinkey_selection: hypergraph.clear_hyper_nodes() for table in hypergraph.hyper_nodes: col=context.metadata[table]['used_cols'][0] if col in hypergraph.candidate_nodes[table]: hypergraph.hyper_nodes[table].append(col) tree_opt_time+=context.joinTreer.candidate_trees[table][col].build_time else: for table in hypergraph.candidate_nodes: col_freq_dict={} for source_col in hypergraph.candidate_nodes[table]: if f'{table}.{source_col}' in adj_matrix: for key,freq in adj_matrix[f'{table}.{source_col}'].items(): targe_table=key.split('.')[0] tot_freq=freq*context.table_metadata[targe_table]['rows'] else: tot_freq=0 col_freq_dict[source_col]=tot_freq col_list=sorted(col_freq_dict.items(), key=lambda x: x[1], reverse=True) for i in range(1,settings.replication_factor+1): if i > len(col_list): break hypergraph.hyper_nodes[table].append(col_list[i-1][0]) for table in hypergraph.hyper_nodes: for col in hypergraph.hyper_nodes[table]: tree_opt_time+=context.joinTreer.candidate_trees[table][col].build_time total_opt_time.append(tree_opt_time) final_result, query_costs, query_ratio, group_opt_time = processing_join_workload(hypergraph, group_type, strategy, tree_type='pac') # total_opt_time+=group_opt_time print("Final Column Selection Cost of PAC-Tree is: ",sum(list(query_costs.values()))/len(query_costs.keys())) print("Average Data Cost Ratio of PAC-Tree is: ",sum(list(query_ratio.values()))/len(query_ratio.keys())) print("Total Opt Time (tree,order,group) of PAC-Tree is: ",total_opt_time) hypergraph.show_selections() hypergraph_save_path=f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/join_key_selection/jt-hgraph.pkl' os.makedirs(os.path.dirname(hypergraph_save_path), exist_ok=True) with open(hypergraph_save_path,'wb') as f: pickle.dump(hypergraph,f) return final_result, query_costs, query_ratio, sum(total_opt_time) def save_experiment_result(experiment_result,disabled_prim_reorder): import pandas as pd import os base_dir = os.path.dirname(os.path.abspath(__file__)) result_dir = f"{base_dir}/../experiment/result/sf={settings.scale_factor}/bs={settings.block_size}/rep={settings.replication_factor}/" os.makedirs(result_dir, exist_ok=True) query_list = [] table_list = [] time_overhead_list = [] for model_name, (table_info, query_info, query_ratio_info, opt_time) in experiment_result.items(): for q_id, cost_val in query_info.items(): query_list.append([model_name, q_id, cost_val,query_ratio_info[q_id]]) for table, cost_dict in table_info.items(): shuffle_times = cost_dict.get('shuffle times', 0) hyper_times = cost_dict.get('hyper times', 0) join_cost = cost_dict.get('join cost', 0) table_list.append([model_name, table, shuffle_times, hyper_times, join_cost]) time_overhead_list.append([model_name, opt_time]) df_query = pd.DataFrame(query_list, columns=["Model", "Query ID", "Join Cost", "Scan Ratio"]) df_table = pd.DataFrame(table_list, columns=["Model", "Table", "Shuffle Times", "Hyper Times", "Join Cost"]) df_time = pd.DataFrame(time_overhead_list, columns=["Model", "Opt Time"]) pivot_query_cost = df_query.pivot_table(index="Query ID", columns="Model", values="Join Cost", aggfunc="first").reset_index() pivot_ratio = df_query.pivot_table(index="Query ID", columns="Model", values="Scan Ratio", aggfunc="first").reset_index() pivot_query_cost.rename(columns={ "PAW": "PAW_Join Cost", "ADP": "AD-MTO_Join Cost", "JT": "PAC-Tree_Join Cost" }, inplace=True) pivot_ratio.rename(columns={ "PAW": "PAW_Scan Ratio", "ADP": "AD-MTO_Scan Ratio", "JT": "PAC-Tree_Scan Ratio" }, inplace=True) merged_query_df = pd.merge(pivot_query_cost, pivot_ratio, on="Query ID") pivot_shuffle = df_table.pivot_table(index="Table", columns="Model", values="Shuffle Times", aggfunc="first").reset_index() pivot_hyper = df_table.pivot_table(index="Table", columns="Model", values="Hyper Times", aggfunc="first").reset_index() pivot_cost = df_table.pivot_table(index="Table", columns="Model", values="Join Cost", aggfunc="first").reset_index() pivot_shuffle.rename( columns={ "PAW": "PAW_shuffle_times", "ADP": "AD-MTO_shuffle_times", "JT": "PAC-Tree_shuffle_times", }, inplace=True ) pivot_hyper.rename( columns={ "PAW": "PAW_hyper_times", "ADP": "AD-MTO_hyper_times", "JT": "PAC-Tree_hyper_times" }, inplace=True ) merged_df = pivot_shuffle.merge(pivot_hyper, on="Table") merged_df = merged_df.merge(pivot_cost, on="Table", suffixes=(None, "_cost")) merged_df.rename( columns={ "PAW": "PAW_Scan Block", "ADP": "AD-MTO_Scan Block", "JT": "PAC-Tree_Scan Block", }, inplace=True ) os.makedirs(result_dir, exist_ok=True) res_saved_path=f"{result_dir}/{settings.benchmark}_results.xlsx" with pd.ExcelWriter(res_saved_path) as writer: merged_query_df.to_excel(writer, sheet_name="Query_Cost", index=False) merged_df.to_excel(writer, sheet_name="Table_Cost", index=False) df_time.to_excel(writer, sheet_name="Time_Cost", index=False) def pretty_print_dubgger_query_join_info(): from rich.console import Console from rich.table import Table from rich.panel import Panel console = Console() benchmarks = list(dubug_logger.query_cost_distributions.keys()) for benchmark in benchmarks: table = Table(show_header=True, header_style="bold cyan") table.add_column("QueryID", style="bold magenta") methods = sorted(list(dubug_logger.query_cost_distributions[benchmark].keys())) for method in methods: table.add_column(f"{method} Ratio", justify="right", style=get_method_style(method)) table.add_column(f"{method} Hyper", justify="right", style=get_method_style(method)) table.add_column(f"{method} Shuffle", justify="right", style=get_method_style(method)) all_qids = set() for method in methods: all_qids.update(dubug_logger.query_cost_distributions[benchmark][method].keys()) for qid in sorted(list(all_qids)): row = [str(qid)] for method in methods: if qid in dubug_logger.query_cost_distributions[benchmark][method]: scan_ratio,hyper_cost, shuffle_cost = dubug_logger.query_cost_distributions[benchmark][method][qid] row.append(str(scan_ratio)) row.append(hyper_cost) row.append(shuffle_cost) else: row.append("N/A") row.append("N/A") row.append("N/A") table.add_row(*row) panel = Panel(table, title=f"[bold yellow]Query Cost Distributions for {benchmark}[/bold yellow]", border_style="blue", expand=False) console.print(panel) console.print("\n") def get_method_style(method): """Return the color style for a given method.""" styles = { "paw": "green", "mto": "blue", "pac": "yellow", "adp": "cyan" } return styles.get(method.lower(), "white") def pretty_print_dubgger_table_join_info(): from rich.console import Console from rich.table import Table console = Console() table = Table(show_header=True, header_style="bold cyan") benchmarks = list(dubug_logger.join_cost_distributions.keys()) table.add_column("Benchmark", style="bold magenta") table.add_column("Table", style="dim") all_methods = set() for benchmark in benchmarks: all_methods.update(dubug_logger.join_cost_distributions[benchmark].keys()) method_list = sorted(list(all_methods)) # Sort to ensure consistent order for method in method_list: table.add_column(method, justify="right", style=get_method_style(method)) for benchmark in benchmarks: all_tables = set() benchmark_methods = dubug_logger.join_cost_distributions[benchmark].keys() for method in benchmark_methods: all_tables.update(dubug_logger.join_cost_distributions[benchmark][method].keys()) primary_method = next(iter(benchmark_methods)) sorted_tables = sorted( [t for t in all_tables if t != "avg"], key=lambda t: dubug_logger.join_cost_distributions[benchmark][primary_method].get(t, 0), reverse=True ) if "avg" in all_tables: sorted_tables.append("avg") for i, table_name in enumerate(sorted_tables): row = [] if i == 0: row.append(benchmark) else: row.append("") row.append(table_name) for method in method_list: if method in benchmark_methods: cost = dubug_logger.join_cost_distributions[benchmark][method].get(table_name, "N/A") row.append(f"{cost:,}" if isinstance(cost, (int, float)) else cost) else: row.append("N/A") if table_name == "avg": table.add_row(*row, style="bold") else: table.add_row(*row) console.print("\n[bold yellow]Join Cost Distributions By Method[/bold yellow]\n") console.print(table) def get_method_style(method): styles = { "paw": "green", "mto": "blue", "pac": "yellow", "adp": "cyan" } return styles.get(method.lower(), "white") def run_evaluation(save_result=False, disabled_prim_reorder=False): """ disabled_prim_reorder: 是否使用 基于prim算法的join reorder。若为True,则不使用join reorder """ load_tree_context() experiment_data = {} # logging.basicConfig(filename=f'{base_dir}/../logs/column_selection_experiment.log', # level=logging.INFO, # format='%(asctime)s - %(levelname)s - %(message)s') using_join_order_strategy='traditional' if disabled_prim_reorder else 'prim' def run_PAW(): res_PAW, cost_PAW, ratio_PAW, total_opt_time = select_columns_by_PAW(strategy='traditional', group_type=0) return [res_PAW, cost_PAW, ratio_PAW, total_opt_time] def run_AD_MTO(): res_ADP, cost_ADP, ratio_ADP, total_opt_time = select_columns_by_AD_MTO(strategy='traditional', group_type=0) return [res_ADP, cost_ADP, ratio_ADP, total_opt_time] def run_JT(): res_JT, cost_JT,ratio_JT, total_opt_time = select_columns_by_JT(group_type=1, strategy=using_join_order_strategy, disabled_joinkey_selection=False) return [res_JT, cost_JT, ratio_JT, total_opt_time] experiment_data["PAW"]=run_PAW() experiment_data["ADP"]=run_AD_MTO() experiment_data["JT"]=run_JT() if args.mode=='debug': pretty_print_dubgger_table_join_info() pretty_print_dubgger_query_join_info() # exit(-1) with concurrent.futures.ThreadPoolExecutor() as executor: future_paw = executor.submit(run_PAW) future_adp = executor.submit(run_AD_MTO) future_jt = executor.submit(run_JT) result_paw = future_paw.result() result_adp = future_adp.result() result_jt = future_jt.result() experiment_data["PAW"]=result_paw experiment_data["ADP"]=result_adp experiment_data["JT"]=result_jt print(experiment_data) if save_result: save_experiment_result(experiment_data, disabled_prim_reorder) def scaling_block_size(): for item in [1000, 5000, 20000, 50000]: settings.block_size=item run_evaluation(save_result=False, disabled_prim_reorder=True) def scaling_replication(): for item in [1,2,3,4]: settings.replication_factor=item run_evaluation(save_result=False, disabled_prim_reorder=False) def test_model_overhead_by_scaling_scale(): for sf in [5,10,20,50]: settings.scale_factor=sf # settings.block_size=int(10000/sf) run_evaluation(save_result=False, disabled_prim_reorder=False) if args.command == 0: #basic experiments settings.replication_factor=3 #defult run_evaluation(save_result=False, disabled_prim_reorder=False) elif args.command == 1: # sensitivity analysis scaling_block_size() elif args.command == 2: scaling_replication() elif args.command == 3: test_model_overhead_by_scaling_scale() else: print("Invalid command")