# PAC-tree / model / join_key_selector.py
# Join-key selection over benchmark hypergraphs for PAC-tree partition layouts.
from partition_algorithm import PartitionAlgorithm
import pickle
import itertools
import time
from join_eval import JoinEvaluator
import os
from join_order import JoinGraph,TraditionalJoinOrder
from conf.context import Context
import pandas as pd
import concurrent.futures
from db.conf import table_suffix
import argparse
from rich.console import Console
from rich.logging import RichHandler
import logging

base_dir=os.path.dirname(os.path.abspath(__file__))


def _parse_bool(value):
    """Parse a boolean-ish CLI string into a real bool.

    BUG FIX: the original used ``type=bool`` for ``--init``; argparse then
    calls ``bool()`` on the raw string, so any non-empty value — including
    the literal string "False" — enabled the flag.  This parser accepts the
    usual true/false spellings and rejects everything else.
    """
    lowered = value.lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--command", type=int, default=0, help="choose the command to run")
parser.add_argument("--benchmark", type=str, default='tpcds', help="choose the evaluated benchmark")
parser.add_argument("--mode", type=str, default='debug', help="choose the mode to run")
# `--init True` still works as before; `--init False` is now actually False.
parser.add_argument("--init", type=_parse_bool, default=False, help="initialize base layouts and join queries")

args = parser.parse_args()
settings = Context()
settings.benchmark = args.benchmark  # 'tpch', 'imdb', 'tpcds'
# settings.block_size=200  # 5000, 20000,50000

class Debugger:
    """Collects cost distributions per benchmark for debug-mode reporting."""

    def __init__(self):
        # benchmark -> tree_type -> per-table join cost stats
        self.join_cost_distributions = dict()
        # benchmark -> tree_type -> per-query cost stats
        self.query_cost_distributions = dict()


class TwoLayerTree:
    """Two-layer layout bundle: per table, one partition tree per join column."""

    def __init__(self, trees):
        # table -> {join_column: partition_tree}
        self.candidate_trees = trees

class QdTree:
    """QD-tree layout bundle: one partition tree per table (PAW baseline)."""

    def __init__(self, trees):
        # table -> partition_tree
        self.candidate_trees = trees

# Hypergraph base class
class BaseHypergraph:
    """Models a benchmark schema as a hypergraph.

    Each table is a hypernode; `candidate_nodes` lists its candidate join-key
    columns and `hyper_nodes` tracks the columns currently selected (initially
    none for every table).
    """

    def __init__(self, benchmark, candidate_nodes):
        self.benchmark = benchmark
        # table -> list of candidate join columns
        self.candidate_nodes = candidate_nodes
        # table -> currently selected columns; every table starts empty
        self.hyper_nodes = {name: [] for name in candidate_nodes}

    def clear_hyper_nodes(self):
        """Reset every hypernode to an empty column selection."""
        self.hyper_nodes = {name: [] for name in self.candidate_nodes}

    def show_selections(self):
        """Print the currently selected columns of each hypernode."""
        for table, columns in self.hyper_nodes.items():
            print(f"Hypernode: {table}, Selected Columns: {columns}")

# TPC-H hypergraph class
class HypergraphTPC(BaseHypergraph):
    """TPC-H schema hypergraph: candidate join-key columns per table."""

    def __init__(self):
        super().__init__('tpch', {
            "nation": ["n_nationkey", "n_regionkey"],
            "region": ["r_regionkey"],
            "customer": ["c_custkey", "c_nationkey"],
            "orders": ["o_orderkey", "o_custkey"],
            "lineitem": ["l_partkey", "l_orderkey", "l_suppkey"],
            "part": ["p_partkey"],
            "supplier": ["s_suppkey", "s_nationkey"],
            "partsupp": ["ps_suppkey", "ps_partkey"],
        })

# JOB (IMDB) hypergraph class
class JobHypergraph(BaseHypergraph):
    """JOB (IMDB) schema hypergraph: candidate join-key columns per table."""

    def __init__(self):
        super().__init__('imdb', {
            "aka_name": ["id", "person_id"],
            "aka_title": ["id", "movie_id", "kind_id"],
            "cast_info": ["id", "person_id", "movie_id", "person_role_id", "role_id"],
            "char_name": ["id"],
            "comp_cast_type": ["id"],
            "company_name": ["id"],
            "company_type": ["id"],
            "complete_cast": ["id", "movie_id", "subject_id", "status_id"],
            "info_type": ["id"],
            "keyword": ["id"],
            "kind_type": ["id"],
            "link_type": ["id"],
            "movie_companies": ["id", "movie_id", "company_id", "company_type_id"],
            "movie_info": ["id", "movie_id", "info_type_id"],
            "movie_info_idx": ["id", "movie_id", "info_type_id"],
            "movie_keyword": ["id", "movie_id", "keyword_id"],
            "movie_link": ["link_type_id", "movie_id", "linked_movie_id"],
            "name": ["id"],
            "person_info": ["id", "person_id", "info_type_id"],
            "role_type": ["id"],
            "title": ["id", "kind_id"],
        })

# TPC-DS hypergraph class
class HypergraphTPCDS(BaseHypergraph):
    """TPC-DS schema hypergraph: candidate join-key columns per table.

    Tables with no joins in the current query set only list their most
    likely join keys (usually the primary key).
    """

    def __init__(self):
        super().__init__('tpcds', {
            # Tables that have joins in queries
            "catalog_sales": ["cs_bill_customer_sk", "cs_ship_customer_sk", "cs_sold_date_sk",
                              "cs_item_sk", "cs_bill_cdemo_sk"],
            "customer": ["c_customer_sk", "c_current_addr_sk", "c_current_cdemo_sk"],
            "customer_address": ["ca_address_sk"],
            "customer_demographics": ["cd_demo_sk"],
            "date_dim": ["d_date_sk"],
            "household_demographics": ["hd_demo_sk"],
            "inventory": ["inv_date_sk", "inv_item_sk"],
            "item": ["i_item_sk"],
            "promotion": ["p_promo_sk"],
            "store": ["s_store_sk"],
            "store_returns": ["sr_customer_sk", "sr_item_sk", "sr_ticket_number",
                              "sr_returned_date_sk"],
            "store_sales": ["ss_sold_date_sk", "ss_item_sk", "ss_customer_sk",
                            "ss_cdemo_sk", "ss_hdemo_sk", "ss_addr_sk",
                            "ss_store_sk", "ss_promo_sk", "ss_ticket_number"],
            "web_sales": ["ws_bill_customer_sk", "ws_item_sk", "ws_sold_date_sk"],
            # Tables without joins in the current query set
            "call_center": ["cc_call_center_sk"],  # Primary key
            "catalog_page": ["cp_catalog_page_sk"],  # Primary key
            "catalog_returns": ["cr_item_sk", "cr_order_number", "cr_returned_date_sk"],  # Composite primary key and common foreign keys
            "income_band": ["ib_income_band_sk"],  # Primary key
            "reason": ["r_reason_sk"],  # Primary key
            "ship_mode": ["sm_ship_mode_sk"],  # Primary key
            "time_dim": ["t_time_sk"],  # Primary key
            "warehouse": ["w_warehouse_sk"],  # Primary key
            "web_page": ["wp_web_page_sk"],  # Primary key
            "web_returns": ["wr_item_sk", "wr_order_number", "wr_returned_date_sk"],  # Composite primary key and common foreign keys
            "web_site": ["web_site_sk"],  # Primary key
        })

def create_join_trees(hypergraph):
    """Build one logical join tree per candidate join column of every table.

    Returns a TwoLayerTree mapping table -> {join_column: partition_tree}.
    """
    partitioner = PartitionAlgorithm(benchmark=hypergraph.benchmark, block_size=settings.block_size)
    partitioner.load_join_query()
    all_trees = {}
    for table_name, join_columns in hypergraph.candidate_nodes.items():
        partitioner.table_name = table_name
        partitioner.load_data()
        partitioner.load_query()
        per_column = {}
        for join_column in join_columns:
            partitioner.LogicalJoinTree(join_col=join_column)
            per_column[join_column] = partitioner.partition_tree
        all_trees[table_name] = per_column
    return TwoLayerTree(all_trees)

def create_color_logger(name='color_logger', level=logging.INFO):
    """Configure root logging with a RichHandler and return a named logger.

    NOTE(review): basicConfig configures the *root* logger; the named logger
    inherits the handler/level from it — repeated calls are no-ops.
    """
    rich_handler = RichHandler(console=Console(), rich_tracebacks=True)
    logging.basicConfig(
        level=level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[rich_handler],
    )
    return logging.getLogger(name)

def create_qdtrees(hypergraph):
    """Build one QD-tree layout per table (PAW workload) and wrap them in a QdTree."""
    algorithm = PartitionAlgorithm(benchmark=hypergraph.benchmark, block_size=settings.block_size)
    algorithm.load_join_query(join_indeuced='PAW')
    qd_trees = {}
    for table_name in hypergraph.candidate_nodes:
        algorithm.table_name = table_name
        algorithm.load_data()
        algorithm.load_query(join_indeuced='PAW')
        algorithm.InitializeWithQDT()
        qd_trees[table_name] = algorithm.partition_tree
    return QdTree(qd_trees)

def load_join_queries(hypergraph, is_join_indeuced):
    """Load the benchmark's join queries for the given join-induction strategy
    ('PAW', 'MTO' or 'PAC')."""
    algorithm = PartitionAlgorithm(benchmark=hypergraph.benchmark, block_size=settings.block_size)
    algorithm.load_join_query(join_indeuced=is_join_indeuced)
    return algorithm.join_queries


# Module-level debug collector and color logger shared by the functions below.
# NOTE(review): "dubug_logger" is a typo of "debug_logger"; kept as-is because
# it is referenced throughout this module.
dubug_logger=Debugger()
logger = create_color_logger()

def init_partitioner():
    """Build and persist the base partition layouts for the active benchmark.

    Writes two pickles under layouts/bs=<block_size>/<benchmark>/base_trees:
    join-trees.pkl (per-column join trees) and qd-trees.pkl (QD-tree baseline).

    Raises:
        ValueError: if settings.benchmark is not one of tpch/imdb/tpcds.
    """
    if settings.benchmark == 'tpch':
        hypergraph = HypergraphTPC()
    elif settings.benchmark == 'imdb':
        hypergraph = JobHypergraph()
    elif settings.benchmark == 'tpcds':
        hypergraph = HypergraphTPCDS()
    else:
        # Previously an unknown benchmark fell through and crashed later with
        # UnboundLocalError on `hypergraph`; fail fast with a clear message.
        raise ValueError(f"unsupported benchmark: {settings.benchmark}")

    # Both pickles go to the same base_trees directory.
    trees_dir = f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/base_trees'
    os.makedirs(trees_dir, exist_ok=True)

    # Create join trees for all tables
    joinTreer = create_join_trees(hypergraph)
    with open(f'{trees_dir}/join-trees.pkl', 'wb') as f:
        pickle.dump(joinTreer, f)

    qdTreer = create_qdtrees(hypergraph)
    with open(f'{trees_dir}/qd-trees.pkl', 'wb') as f:
        pickle.dump(qdTreer, f)

def init_workload():
    """Load and persist the join-query workloads for every tree strategy.

    Writes paw-/mto-/pac-queries.pkl under the benchmark's used_queries
    directory.

    Raises:
        ValueError: if settings.benchmark is not one of tpch/imdb/tpcds.
    """
    if settings.benchmark == 'tpch':
        hypergraph = HypergraphTPC()
    elif settings.benchmark == 'imdb':
        hypergraph = JobHypergraph()
    elif settings.benchmark == 'tpcds':
        hypergraph = HypergraphTPCDS()
    else:
        # Previously an unknown benchmark fell through and crashed later with
        # UnboundLocalError on `hypergraph`; fail fast with a clear message.
        raise ValueError(f"unsupported benchmark: {settings.benchmark}")

    queries_base_dir = f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/used_queries'
    os.makedirs(queries_base_dir, exist_ok=True)

    # One workload pickle per join-induction strategy (was three copy-pasted
    # stanzas doing exactly this).
    for file_tag, strategy in (('paw', 'PAW'), ('mto', 'MTO'), ('pac', 'PAC')):
        queries = load_join_queries(hypergraph, is_join_indeuced=strategy)
        with open(os.path.join(queries_base_dir, f'{file_tag}-queries.pkl'), 'wb') as f:
            pickle.dump(queries, f)


# ~~~~~~~~loading global variables~~~~~~~~
if args.init:
    init_workload()
    init_partitioner()
    # exit() is meant for interactive sessions (provided by `site`);
    # raise SystemExit to terminate reliably in any execution context.
    raise SystemExit(0)
    
# if args.mode=='debug':
#     init_workload()

class PartitionContext:
    """Container for the shared partitioning state populated by load_tree_context().

    All fields start as None and are filled in lazily from pickles on disk.
    """

    def __init__(self):
        # base layouts: per-column join trees and the QD-tree baseline
        self.joinTreer = None
        self.qdTreer = None
        # per-strategy query workloads
        self.paw_queries = None
        self.mto_queries = None
        self.pac_queries = None
        # metadata derived from the join trees / loaded from disk
        self.metadata = None
        self.table_metadata = None


# module-level singleton shared by the functions below
context = PartitionContext()

def _load_pickle(path):
    """Load a single pickle file, closing the handle afterwards.

    BUG FIX: the original used ``pickle.load(open(path,'rb'))`` six times,
    leaking a file descriptor per call.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def load_tree_context():
    """Populate the module-level `context` from the on-disk layout pickles.

    Loads the base trees, the three query workloads, derives per-table
    metadata (row count + used columns) from the join trees, and finally
    loads the dataset-level table metadata.
    """
    layout_dir = f'{base_dir}/../layouts/bs={settings.block_size}/{settings.benchmark}'
    context.joinTreer = _load_pickle(f'{layout_dir}/base_trees/join-trees.pkl')
    context.qdTreer = _load_pickle(f'{layout_dir}/base_trees/qd-trees.pkl')
    context.paw_queries = _load_pickle(f'{layout_dir}/used_queries/paw-queries.pkl')
    context.mto_queries = _load_pickle(f'{layout_dir}/used_queries/mto-queries.pkl')
    context.pac_queries = _load_pickle(f'{layout_dir}/used_queries/pac-queries.pkl')

    # ~~~~~~~~Load partial table metadata information~~~~~~~~
    context.metadata = {}
    for table, tree in context.joinTreer.candidate_trees.items():
        # any per-column tree of the table carries the same size/columns info
        sample_tree = list(tree.values())[0]
        context.metadata[table] = {'row_count': sample_tree.pt_root.node_size,
                                   'used_cols': sample_tree.used_columns}
    context.table_metadata = _load_pickle(f'{base_dir}/../dataset/{settings.benchmark}/metadata.pkl')
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def estimate_join_cost(hypergraph, current_table):
    """Estimate the block-scan cost of all MTO join operations touching `current_table`.

    For each join relation whose join column is a selected hyper-node column
    on both sides, the join is "hyper" and costs left + right scan; otherwise
    the probe side is shuffled and weighted 3x.

    Args:
        hypergraph: hypergraph whose `hyper_nodes` holds the selected columns.
        current_table: table whose join cost is being estimated.

    Returns:
        int: accumulated scan cost over all joins involving `current_table`.
    """
    join_cost = 0
    for q_id, join_query in enumerate(context.mto_queries):
        for join_op in join_query['join_relations']:
            left_table, left_col_idx = list(join_op.items())[0]
            right_table, right_col_idx = list(join_op.items())[1]
            # BUG FIX: the original condition used `or`, which skipped every
            # join except self-joins (it required BOTH sides to equal
            # current_table, so the estimate was effectively always 0).
            # The intent is to skip joins touching current_table on NEITHER side.
            if left_table != current_table and right_table != current_table:
                continue
            left_col = context.metadata[left_table]["used_cols"][left_col_idx]
            right_col = context.metadata[right_table]["used_cols"][right_col_idx]
            left_query, right_query = join_query['vectors'][left_table], join_query['vectors'][right_table]
            is_hyper = True
            if left_col in hypergraph.hyper_nodes[left_table]:
                left_tree = context.joinTreer.candidate_trees[left_table][left_col]
            else:
                # join key not clustered on this side: fall back to any base
                # tree and demote the join to a shuffle
                left_tree = list(context.joinTreer.candidate_trees[left_table].values())[0]
                is_hyper = False
            if right_col in hypergraph.hyper_nodes[right_table]:
                right_tree = context.joinTreer.candidate_trees[right_table][right_col]
            else:
                right_tree = list(context.joinTreer.candidate_trees[right_table].values())[0]
                is_hyper = False
            left_cost = sum(left_tree.nid_node_dict[bid].node_size for bid in left_tree.query_single(left_query))
            right_cost = sum(right_tree.nid_node_dict[bid].node_size for bid in right_tree.query_single(right_query))
            # shuffled joins pay a 3x penalty on the probe (right) side
            join_cost += (left_cost + right_cost) if is_hyper else (left_cost + 3 * right_cost)

    return join_cost

def _should_skip_query(q_id):
    """Return True for query ids known to be invalid for the active benchmark
    (query 7 in imdb, query 6 in tpch; none for tpcds)."""
    invalid_query_ids = {'imdb': 7, 'tpch': 6, 'tpcds': -1}
    return invalid_query_ids.get(settings.benchmark) == q_id

def processing_join_workload(hypergraph, group_type, join_strategy='traditional', tree_type='paw'):
    """
    Process join workload and calculate the positions and total cost of hyper edges and shuffle edges in the current hypergraph
    
    Args:
        hypergraph: Hypergraph structure
        group_type: Group type (0 or 1)
        join_strategy: Join strategy ('traditional' or 'prim')
        tree_type: Tree type ('paw', 'mto' or 'pac')
        
    Returns:
        final_join_result: Table-level join result statistics
        query_cost_dict: Query-level cost statistics
        query_ratio_dict: Query-level ratio statistics
        opt_time: Optimization time statistics

    Note:
        In cost_log, a -1 entry separates the per-edge join costs from the
        trailing multi-table shuffle cost of a query.
    """
    # Initialize statistics trackers
    shuffle_times, hyper_times = {}, {}
    join_cost_dict = {}
    cost_log, byte_log = {}, {}
    solved_tables = {}
    shuffle_group_time, hyper_group_time = [], []
    tot_order_time = 0
    
    # Select appropriate query set
    input_queries = context.paw_queries if tree_type == 'paw' else \
                  context.mto_queries if tree_type == 'mto' else context.pac_queries
    
    # ========== Process each query ==========
    for q_id, join_query in enumerate(input_queries):
        print('processing query:', q_id)

        if _should_skip_query(q_id) or join_query['vectors'] == {}:
            if join_query['vectors'] == {}:
                print(f'empty query : {settings.benchmark} {q_id}')
            continue
        
        query_cost_log, query_bytes_log = [], []
        solved_tables[q_id] = set()
        
        # ---------- Process queries without join relations ----------
        if not join_query['join_relations']:
            print('no join operation')
            for table, query_vector in join_query['vectors'].items():
                solved_tables[q_id].add(table)
                
                # Select appropriate tree (PAW uses the QD-tree baseline;
                # MTO/PAC fall back to an arbitrary per-column join tree)
                if tree_type == 'paw':
                    tree = context.qdTreer.candidate_trees[table]
                else:
                    tree = list(context.joinTreer.candidate_trees[table].values())[0]
                
                # Calculate scan cost
                b_ids = tree.query_single(query_vector)
                scan_cost = sum([tree.nid_node_dict[b_id].node_size for b_id in b_ids])
                join_cost_dict[table] = join_cost_dict.get(table, 0) + scan_cost
                query_cost_log.append(scan_cost)
                query_bytes_log.append(context.table_metadata[table]['read_line'] * scan_cost)
                logger.info("Just scan table: %s, cost: %s", table, scan_cost)
                
            cost_log[q_id] = query_cost_log
            byte_log[q_id] = query_bytes_log
            continue
        
        # ---------- Process queries with join relations ----------
        # 1. Determine join order
        start1_time = time.time()
        join_path = []
        
        if join_strategy == 'traditional':
            join_path = TraditionalJoinOrder(join_query['join_relations'], context.metadata).paths
        else:
            # Build scan block dictionary for Prim algorithm
            scan_block_dict = {'card': {}, 'relation': []}
            
            for join_op in join_query['join_relations']:
                left_table, left_col_idx = list(join_op.items())[0]
                right_table, right_col_idx = list(join_op.items())[1]
                left_col = context.metadata[left_table]["used_cols"][left_col_idx]
                right_col = context.metadata[right_table]["used_cols"][right_col_idx]
                left_query = join_query['vectors'][left_table]
                right_query = join_query['vectors'][right_table]
                is_hyper = True
                
                # Process PAW tree type
                if tree_type == 'paw':
                    if f'{left_table}.{left_col}' not in scan_block_dict["card"]:
                        scan_block_dict["card"][f'{left_table}.{left_col}'] = len(context.qdTreer.candidate_trees[left_table].query_single(left_query))
                    
                    if f'{right_table}.{right_col}' not in scan_block_dict["card"]:
                        scan_block_dict["card"][f'{right_table}.{right_col}'] = len(context.qdTreer.candidate_trees[right_table].query_single(right_query))
                    is_hyper = False
                # Process MTO or PAC tree type
                else:
                    # Process left table tree
                    if left_col in hypergraph.hyper_nodes[left_table]:
                        tree = context.joinTreer.candidate_trees[left_table][left_col]
                    else:
                        tree = list(context.joinTreer.candidate_trees[left_table].values())[0]
                        is_hyper = False
                    
                    if f'{left_table}.{left_col}' not in scan_block_dict["card"]:
                        scan_block_dict["card"][f'{left_table}.{left_col}'] = len(tree.query_single(left_query))
                    
                    # Process right table tree
                    if right_col in hypergraph.hyper_nodes[right_table]:
                        tree = context.joinTreer.candidate_trees[right_table][right_col]
                    else:
                        tree = list(context.joinTreer.candidate_trees[right_table].values())[0]
                        is_hyper = False
                    
                    if f'{right_table}.{right_col}' not in scan_block_dict["card"]:
                        scan_block_dict["card"][f'{right_table}.{right_col}'] = len(tree.query_single(right_query))
                
                scan_block_dict["relation"].append([
                    f'{left_table}.{left_col}',
                    f'{right_table}.{right_col}',
                    'Hyper' if is_hyper else 'Shuffle'
                ])
            
            # Generate MST using JoinGraph
            jg = JoinGraph(scan_block_dict,hypergraph.hyper_nodes)
            join_path = jg.generate_MST()
        
        tot_order_time += time.time() - start1_time  # Calculate join reorder time
        
        # 2. Process join path
        temp_table_dfs = []
        temp_joined_bytes = []
        temp_shuffle_ops = []
        
        # Iterate through each item in join path
        # NOTE(review): item[2] == 1 appears to mark a hyper edge and -1 a
        # shuffle edge (see the join_path[order_id+1][2] == -1 check below)
        # — confirm against JoinGraph / TraditionalJoinOrder.
        for order_id, item in enumerate(join_path):
            join_queryset, joined_cols, joined_trees, joined_tables = [], [], [], []
            is_hyper = True if item[2] == 1 else False
            
            is_real_hyper = True
            join_ops = [
                (item[0].table, item[0].adj_col[item[1]]),
                (item[1].table, item[1].adj_col[item[0]])
            ]
            
            # Collect join operation information
            for join_table, join_col in join_ops:
                join_queryset.append(join_query['vectors'][join_table])
                join_col_idx = context.metadata[join_table]["used_cols"].index(join_col)
                joined_cols.append(join_col_idx)
                
                # Select appropriate tree
                if tree_type == 'paw':
                    tree = context.qdTreer.candidate_trees[join_table]
                    is_real_hyper = False
                else:
                    if join_col in hypergraph.hyper_nodes[join_table]:
                        tree = context.joinTreer.candidate_trees[join_table][join_col]
                    else:
                        tree = list(context.joinTreer.candidate_trees[join_table].values())[0]
                        is_real_hyper = False
                
                joined_trees.append(tree)
                joined_tables.append(join_table)
            
            # Process shuffle join
            if not is_hyper:
                # first path item scans its left side; later items scan the right
                cur_idx = 0 if order_id == 0 else 1
                
                # Current is last path or next is not hyper
                if order_id == len(join_path) - 1 or join_path[order_id+1][2] == -1:
                    tree = joined_trees[cur_idx]
                    cur_table = joined_tables[cur_idx]
                    solved_tables[q_id].add(cur_table)
                    
                    # Get table data
                    b_ids = tree.query_single(join_queryset[cur_idx])
                    table_dataset = []
                    for b_id in b_ids:
                        table_dataset += list(tree.nid_node_dict[b_id].dataset)
                    
                    # Set column names (imdb columns get a table-suffix prefix)
                    if hypergraph.benchmark == 'imdb':
                        used_columns = [table_suffix[hypergraph.benchmark][cur_table]+"_"+col for col in tree.used_columns]
                    else:
                        used_columns = tree.used_columns
                    
                    temp_table_dfs.append(pd.DataFrame(table_dataset, columns=used_columns))
                
                temp_shuffle_ops.append((
                    item[0].table, item[0].adj_col[item[1]], 
                    item[1].table, item[1].adj_col[item[0]]
                ))
                temp_joined_bytes.append(context.table_metadata[item[1].table]['read_line'])
                continue
            
            # Process hyper join
            # Ensure smaller table is on the left (optimization)
            if len(joined_trees[0].query_single(join_queryset[0])) > len(joined_trees[1].query_single(join_queryset[1])):
                join_queryset[0], join_queryset[1] = join_queryset[1], join_queryset[0]
                joined_cols[0], joined_cols[1] = joined_cols[1], joined_cols[0]
                joined_trees[0], joined_trees[1] = joined_trees[1], joined_trees[0]
                joined_tables[0], joined_tables[1] = joined_tables[1], joined_tables[0]
            
            left_table, right_table = joined_tables[0], joined_tables[1]
            
            # Update statistics
            if is_real_hyper:
                hyper_times[left_table] = hyper_times.get(left_table, 0) + 1
                hyper_times[right_table] = hyper_times.get(right_table, 0) + 1
            else:
                shuffle_times[left_table] = shuffle_times.get(left_table, 0) + 1
                shuffle_times[right_table] = shuffle_times.get(right_table, 0) + 1
            
            solved_tables[q_id].add(left_table)
            solved_tables[q_id].add(right_table)
            
            # Execute join evaluation
            join_eval = JoinEvaluator(
                join_queryset, joined_cols, joined_trees, joined_tables,
                settings.block_size, context.table_metadata, benchmark=hypergraph.benchmark
            )
            hyper_shuffle_cost, hyper_shuffle_read_bytes, temp_joined_df, group_time = join_eval.compute_total_shuffle_hyper_cost(group_type, is_real_hyper)
            
            # Update statistics (cost split evenly between the two tables)
            join_cost_dict[left_table] = join_cost_dict.get(left_table, 0) + hyper_shuffle_cost // 2
            join_cost_dict[right_table] = join_cost_dict.get(right_table, 0) + hyper_shuffle_cost // 2
            if is_real_hyper:
                hyper_group_time.append(group_time)
            else:
                shuffle_group_time.append(group_time)
            query_cost_log.append(hyper_shuffle_cost)
            query_bytes_log.append(hyper_shuffle_read_bytes)
            logger.info("Join %s and %s, %s hyper cost: %s", left_table, right_table, '' if is_real_hyper else 'fake',hyper_shuffle_cost)
            temp_table_dfs.append(temp_joined_df)
            temp_joined_bytes.append(context.table_metadata[left_table]['read_line'] +context.table_metadata[right_table]['read_line'])
        
        # 3. Process multi-table join
        if len(temp_table_dfs) > 1:
            query_cost_log.append(-1)  # Mark start of multi-table join
            shuffle_cost = 0
            shuffle_bytes = 0
            last_temp_dfs = temp_table_dfs[0]
            last_temp_line = temp_joined_bytes[0]
            
            # Define hash join function
            def pandas_hash_join(df_A, join_col_A, df_B, join_col_B):
                # already joined on this column in an earlier step: no-op
                if join_col_B in df_A.columns:
                    return df_A
                # probe the smaller frame against the larger one
                if df_A.shape[0] > df_B.shape[0]:
                    df_A, df_B = df_B, df_A
                    join_col_A, join_col_B = join_col_B, join_col_A
                df_B = df_B.drop_duplicates(subset=[join_col_B], keep='first')
                merged_df = df_A.merge(df_B, how='inner', left_on=join_col_A, right_on=join_col_B)
                return merged_df
            
            # Execute multi-table join
            for i in range(1, len(temp_table_dfs)):
                
                join_table1, join_col1, join_table2, join_col2 = temp_shuffle_ops[i-1]
                
                sample_rate1,sample_rate2 = 1, 1

                # scale costs back up if the trees were built on a sample
                if join_col1 in context.joinTreer.candidate_trees[join_table1] and hasattr(context.joinTreer.candidate_trees[join_table1][join_col1], 'sample_rate'):
                    sample_rate1=context.joinTreer.candidate_trees[join_table1][join_col1].sample_rate

                if join_col2 in context.joinTreer.candidate_trees[join_table2] and hasattr(context.joinTreer.candidate_trees[join_table2][join_col2], 'sample_rate'):
                    sample_rate2=context.joinTreer.candidate_trees[join_table2][join_col2].sample_rate
                
                data_sample_rate=min(sample_rate1,sample_rate2)
                
                # Calculate cost
                # cost model: new (shuffled) side weighted 3x vs 1x for the
                # accumulated side — presumably shuffle vs local read; confirm
                # against the JoinEvaluator cost model.
                shuffle_cost += int((temp_table_dfs[i].shape[0] * 3 + last_temp_dfs.shape[0]* 1)/data_sample_rate)
                shuffle_bytes += int((temp_table_dfs[i].shape[0] * 3 * last_temp_line + 
                                 last_temp_dfs.shape[0]* 1 * temp_joined_bytes[i])/data_sample_rate)
                
                if hypergraph.benchmark == 'imdb':
                    join_col1 = table_suffix[hypergraph.benchmark][join_table1] + "_" + join_col1
                    join_col2 = table_suffix[hypergraph.benchmark][join_table2] + "_" + join_col2
                
                # Execute join
                last_temp_dfs = pandas_hash_join(last_temp_dfs, join_col1, temp_table_dfs[i], join_col2)
                last_temp_line = temp_joined_bytes[i] + last_temp_line
                
                # Update statistics
                shuffle_times[join_table1] = shuffle_times.get(join_table1, 0) + 1
                shuffle_times[join_table2] = shuffle_times.get(join_table2, 0) + 1
                join_cost_dict[join_table1] = join_cost_dict.get(join_table1, 0) + shuffle_cost // 2
                join_cost_dict[join_table2] = join_cost_dict.get(join_table2, 0) + shuffle_cost // 2
                
                logger.info("Join %s and %s, shuffle cost: %s", join_table1, join_table2, shuffle_cost)
            
            query_cost_log.append(shuffle_cost)
            query_bytes_log.append(shuffle_bytes)
        
        cost_log[q_id] = query_cost_log
        byte_log[q_id] = query_bytes_log
    
    # ========== Summarize results ==========
    # 1. Calculate query costs and ratios
    query_cost_dict = {}
    query_ratio_dict = {}
    
    for qid in cost_log.keys():
        query_cost_dict[qid] = sum(list(cost_log[qid]))
        # ratio = bytes read / total bytes of every table the query touched
        query_ratio_dict[qid] = sum(list(byte_log[qid])) / sum(
            [sum(context.table_metadata[table]['width'].values()) * context.table_metadata[table]['rows'] 
             for table in solved_tables[qid]]
        )
        print(f"Query {qid} cost log: {cost_log[qid]}  |  byte log: {byte_log[qid]}")
    
    print(f"Query ratio log: {query_ratio_dict}")
    
    if args.mode=='debug':
        print(f"Join cost distributions: {join_cost_dict}")
        if settings.benchmark not in dubug_logger.join_cost_distributions:
            dubug_logger.join_cost_distributions[settings.benchmark]={}
        dubug_logger.join_cost_distributions[settings.benchmark][tree_type] = join_cost_dict
        
        
        if settings.benchmark not in dubug_logger.query_cost_distributions:
            dubug_logger.query_cost_distributions[settings.benchmark]={}
        dubug_logger.query_cost_distributions[settings.benchmark][tree_type] = {}
        for qid in cost_log.keys():
            # Find index of -1 in cost_log, join elements before it as string, join elements after it as string, separated by space
            # NOTE(review): the second join uses ''.join (no separator) while
            # the first uses ' '.join — inconsistent; confirm which is intended.
            if -1 in cost_log[qid]:
                idx = cost_log[qid].index(-1)
                dubug_logger.query_cost_distributions[settings.benchmark][tree_type][qid]=[round(query_ratio_dict[qid],4),' '.join([str(i) for i in cost_log[qid][:idx]]),''.join([str(i) for i in cost_log[qid][idx+1:]])]
        
    # 2. Summarize table-level results
    final_join_result = {}
    for table in join_cost_dict.keys():
        final_join_result[table] = {
            'shuffle times': shuffle_times.get(table, 0),
            'hyper times': hyper_times.get(table, 0),
            'join cost': join_cost_dict[table]
        }
    
    # 3. Calculate averages (note: the 'avg' keys are added to the same dicts
    # that are returned to the caller)
    query_cost_dict['avg'] = sum(list(query_cost_dict.values())) / len(query_cost_dict.keys())
    query_ratio_dict['avg'] = sum(list(query_ratio_dict.values())) / len(query_ratio_dict.keys())
    
    table_num=len(final_join_result.keys())
    
    shuffle_avg = sum(list(shuffle_times.values())) / table_num if shuffle_times else 0
    hyper_avg = sum(list(hyper_times.values())) / table_num if hyper_times else 0
    cost_avg = sum(list(join_cost_dict.values())) / table_num if join_cost_dict else 0
    
    final_join_result['avg'] = {
        'shuffle times': shuffle_avg,
        'hyper times': hyper_avg,
        'join cost': cost_avg
    }
    
    # 4. Calculate optimization time
    # [total join-reorder time, mean hyper group time, mean shuffle group time]
    opt_time = [
        tot_order_time,
        sum(hyper_group_time) / len(hyper_group_time) if len(hyper_group_time) > 0 else 0,
        sum(shuffle_group_time) / len(shuffle_group_time) if len(shuffle_group_time) > 0 else 0
    ]
    
    return final_join_result, query_cost_dict, query_ratio_dict, opt_time


def processing_join_workload_old2(hypergraph,group_type,join_strategy='traditional',tree_type='paw'):
    """Legacy (v2) simulation of the join workload against a hypergraph layout.

    For every input query this replays the join plan, classifying each join
    as a hyper join (both sides partitioned on the join key) or a shuffle
    join, and accumulates per-table and per-query cost/byte statistics.

    Args:
        hypergraph: benchmark hypergraph holding the selected join columns.
        group_type: grouping strategy forwarded to JoinEvaluator.
        join_strategy: 'traditional' uses TraditionalJoinOrder; anything else
            builds a JoinGraph from scanned-block cardinalities and uses its MST.
        tree_type: 'paw' (one qd-tree per table), 'mto' or 'pac' (per-column
            join trees) — selects both the query set and the trees used.

    Returns:
        (final_join_result, query_cost_dict, query_ratio_dict, opt_time) where
        opt_time = [total reorder time, avg hyper group time, avg shuffle group time].
    """
    shuffle_times,hyper_times={},{}
    join_cost_dict={}
    cost_log={}
    byte_log={}
    solved_tables={}
    input_queries=context.paw_queries if tree_type=='paw' else context.mto_queries if tree_type=='mto' else context.pac_queries
    shuffle_group_time,hyper_group_time=[],[]
    tot_order_time=0
    for q_id,join_query in enumerate(input_queries):
        print('processing query:',q_id)
        if _should_skip_query(q_id):
            continue
        if join_query['vectors']=={}:
            print(f'empty query : {settings.benchmark} {q_id}')
            continue
        query_cost_log,query_bytes_log=list(),list()
        solved_tables[q_id]=set()
        if join_query['join_relations']:
            start1_time=time.time()
            if join_strategy=='traditional':
                join_path=TraditionalJoinOrder(join_query['join_relations'],context.metadata).paths
            else:
                # Collect per-join-key scanned-block cardinalities to weight the join graph.
                scan_block_dict={'card':{},'relation':[]}
                for join_op in join_query['join_relations']:
                    left_table,left_col_idx=list(join_op.items())[0]
                    right_table,right_col_idx=list(join_op.items())[1]
                    left_col,right_col=context.metadata[left_table]["used_cols"][left_col_idx],context.metadata[right_table]["used_cols"][right_col_idx]
                    left_query=join_query['vectors'][left_table]
                    right_query=join_query['vectors'][right_table]
                    is_hyper=True
                    if tree_type=='paw':
                        if f'{left_table}.{left_col}' not in scan_block_dict["card"]:
                            scan_block_dict["card"][f'{left_table}.{left_col}']=len(context.qdTreer.candidate_trees[left_table].query_single(left_query))
                        
                        if f'{right_table}.{right_col}' not in scan_block_dict["card"]:
                            scan_block_dict["card"][f'{right_table}.{right_col}']=len(context.qdTreer.candidate_trees[right_table].query_single(right_query))
                        is_hyper=False
                    else:
                        # A side is "hyper" only if its join column was selected for the table.
                        if left_col in hypergraph.hyper_nodes[left_table]:
                            tree=context.joinTreer.candidate_trees[left_table][left_col]
                        else:
                            tree=list(context.joinTreer.candidate_trees[left_table].values())[0]
                            is_hyper=False
                        if f'{left_table}.{left_col}' not in scan_block_dict["card"]:
                            scan_block_dict["card"][f'{left_table}.{left_col}']=len(tree.query_single(left_query))
                        
                        if right_col in hypergraph.hyper_nodes[right_table]:
                            tree=context.joinTreer.candidate_trees[right_table][right_col]
                        else:
                            tree=list(context.joinTreer.candidate_trees[right_table].values())[0]
                            is_hyper=False
                        if f'{right_table}.{right_col}' not in scan_block_dict["card"]:
                            scan_block_dict["card"][f'{right_table}.{right_col}']=len(tree.query_single(right_query))
                    
                    scan_block_dict["relation"].append([f'{left_table}.{left_col}',f'{right_table}.{right_col}','Hyper' if is_hyper else 'Shuffle'])
                jg=JoinGraph(scan_block_dict)
                join_path=jg.generate_MST()
            tot_order_time+=time.time()-start1_time   # time spent on join reordering
            temp_table_dfs=[]
            temp_joined_bytes=[]
            temp_shuffle_ops=[]
            
            for order_id,item in enumerate(join_path):
                join_queryset,joined_cols,joined_trees,joined_tables=[],[],[],[]
                is_hyper=True if item[2]==1 else False
                
                is_real_hyper=True
                join_ops=[(item[0].table,item[0].adj_col[item[1]]),(item[1].table,item[1].adj_col[item[0]])]
                
                for join_table,join_col in join_ops:
                    join_queryset.append(join_query['vectors'][join_table])
                    join_col_idx=context.metadata[join_table]["used_cols"].index(join_col)
                    joined_cols.append(join_col_idx)
                    if tree_type=='paw':
                        tree=context.qdTreer.candidate_trees[join_table]
                        is_real_hyper=False
                    else:
                        if join_col in hypergraph.hyper_nodes[join_table]:
                            tree=context.joinTreer.candidate_trees[join_table][join_col]
                        else:
                            tree=list(context.joinTreer.candidate_trees[join_table].values())[0]
                            is_real_hyper=False
                    joined_trees.append(tree)
                    joined_tables.append(join_table)
                
                if not is_hyper:
                    # Shuffle edge: defer the actual join, but materialize the side's
                    # scanned rows when this is the last edge of its chain.
                    tree=None
                    cur_idx=0 if order_id==0 else 1
                    if order_id==len(join_path)-1 or join_path[order_id+1][2]==-1:
                        tree=joined_trees[cur_idx]
                        cur_table=joined_tables[cur_idx]
                        solved_tables[q_id].add(cur_table)
                        b_ids=tree.query_single(join_queryset[cur_idx])
                        table_dataset=[]
                        for b_id in b_ids:
                            table_dataset+=list(tree.nid_node_dict[b_id].dataset)
                        if hypergraph.benchmark=='imdb':
                            used_columns=[table_suffix[hypergraph.benchmark][cur_table]+"_"+col for col in tree.used_columns]
                        else:
                            used_columns=tree.used_columns
                        temp_table_dfs.append(pd.DataFrame(table_dataset,columns=used_columns))
                    temp_shuffle_ops.append((item[0].table,item[0].adj_col[item[1]],item[1].table,item[1].adj_col[item[0]]))
                    temp_joined_bytes.append(context.table_metadata[item[1].table]['read_line'])
                    continue
                else:
                    # Put the smaller side (by scanned block count) first.
                    if len(joined_trees[0].query_single(join_queryset[0]))>len(joined_trees[1].query_single(join_queryset[1])):
                        join_queryset[0],join_queryset[1]=join_queryset[1],join_queryset[0]
                        joined_cols[0],joined_cols[1]=joined_cols[1],joined_cols[0]
                        joined_trees[0],joined_trees[1]=joined_trees[1],joined_trees[0]
                        joined_tables[0],joined_tables[1]=joined_tables[1],joined_tables[0]
                    left_table,right_table=joined_tables[0],joined_tables[1]

                    # BUGFIX: the right-table counters previously read the LEFT
                    # table's count (copy-paste), inflating/corrupting stats.
                    if is_real_hyper:
                        hyper_times[left_table]=hyper_times.get(left_table,0)+1
                        hyper_times[right_table]=hyper_times.get(right_table,0)+1
                    else:
                        shuffle_times[left_table]=shuffle_times.get(left_table,0)+1
                        shuffle_times[right_table]=shuffle_times.get(right_table,0)+1
                    
                    solved_tables[q_id].add(left_table)
                    solved_tables[q_id].add(right_table)
                    
                    join_eval = JoinEvaluator(join_queryset,joined_cols,joined_trees,joined_tables,settings.block_size,context.table_metadata,benchmark=hypergraph.benchmark)
                    hyper_shuffle_cost,hyper_shuffle_read_bytes, temp_joined_df, group_time=join_eval.compute_total_shuffle_hyper_cost(group_type,is_real_hyper)
                    # Split the pairwise join cost evenly between the two tables.
                    join_cost_dict[left_table]=join_cost_dict.get(left_table,0)+hyper_shuffle_cost//2
                    join_cost_dict[right_table]=join_cost_dict.get(right_table,0)+hyper_shuffle_cost//2
                    if is_real_hyper:
                        hyper_group_time.append(group_time)
                    else:
                        shuffle_group_time.append(group_time)
                    query_cost_log.append(hyper_shuffle_cost)
                    query_bytes_log.append(hyper_shuffle_read_bytes)
                    temp_table_dfs.append(temp_joined_df)
                    temp_joined_bytes.append(context.table_metadata[left_table]['read_line']+context.table_metadata[right_table]['read_line'])
            
            if len(temp_table_dfs)>1:
                # -1 marks the boundary between hyper-join and shuffle-join cost entries.
                query_cost_log.append(-1)
                shuffle_cost=0
                shuffle_bytes=0
                last_temp_dfs=temp_table_dfs[0]
                last_temp_line=temp_joined_bytes[0]
                def pandas_hash_join(df_A, join_col_A, df_B, join_col_B):
                    # Already merged in a previous step — nothing to do.
                    if join_col_B in df_A.columns:
                        return df_A
                    if df_A.shape[0] > df_B.shape[0]:
                        df_A, df_B = df_B, df_A
                        join_col_A, join_col_B = join_col_B, join_col_A
                    # BUGFIX: drop_duplicates is not in-place; the old code
                    # discarded its result, making the dedup a no-op.
                    df_B = df_B.drop_duplicates(subset=[join_col_B], keep='first')
                    merged_df = df_A.merge(df_B, how='inner', left_on=join_col_A, right_on=join_col_B)
                    return merged_df
                
                for i in range(1,len(temp_table_dfs)):
                    shuffle_cost+=(temp_table_dfs[i].shape[0]*3+last_temp_dfs.shape[0])
                    shuffle_bytes+=(temp_table_dfs[i].shape[0]*3*last_temp_line+last_temp_dfs.shape[0]*temp_joined_bytes[i])
                    join_table1,join_col1,join_table2,join_col2=temp_shuffle_ops[i-1]
                    if hypergraph.benchmark=='imdb':
                        join_col1=table_suffix[hypergraph.benchmark][join_table1]+"_"+join_col1
                        join_col2=table_suffix[hypergraph.benchmark][join_table2]+"_"+join_col2
                    last_temp_dfs=pandas_hash_join(last_temp_dfs,join_col1,temp_table_dfs[i],join_col2)
                    last_temp_line=temp_joined_bytes[i]+last_temp_line

                    shuffle_times[join_table1]=shuffle_times.get(join_table1,0)+1
                    shuffle_times[join_table2]=shuffle_times.get(join_table2,0)+1
                    join_cost_dict[join_table1]=join_cost_dict.get(join_table1,0)+shuffle_cost//2
                    join_cost_dict[join_table2]=join_cost_dict.get(join_table2,0)+shuffle_cost//2
                query_cost_log.append(shuffle_cost)
                query_bytes_log.append(shuffle_bytes)
        else:
            print('no join operation')
            # Pure scan query: cost is the total size of blocks hit per table.
            for table, query_vector in join_query['vectors'].items():
                solved_tables[q_id].add(table)
                if tree_type=='paw':
                    tree=context.qdTreer.candidate_trees[table]
                else:
                    tree=list(context.joinTreer.candidate_trees[table].values())[0]
                b_ids=tree.query_single(query_vector)
                scan_cost=sum([tree.nid_node_dict[b_id].node_size for b_id in b_ids])
                
                join_cost_dict[table]=join_cost_dict.get(table,0)+scan_cost
                query_cost_log.append(scan_cost)
                query_bytes_log.append(context.table_metadata[table]['read_line']*scan_cost)
        cost_log[q_id]=query_cost_log
        byte_log[q_id]=query_bytes_log
    
    # Per-query totals and scan ratio vs. full table scans of the touched tables.
    query_cost_dict={}
    query_ratio_dict={}
    for qid in cost_log.keys():
        query_cost_dict[qid]=sum(list(cost_log[qid]))
        query_ratio_dict[qid]=sum(list(byte_log[qid]))/sum([sum(context.table_metadata[table]['width'].values())*context.table_metadata[table]['rows'] for table in solved_tables[qid]])
        print(f"Query {qid} cost log: {cost_log[qid]}  |  ratio log: {byte_log[qid]}")    
    print(f"Query ratio log: {query_ratio_dict}")    
    final_join_result={}
    for table in join_cost_dict.keys():
        final_join_result[table]={'shuffle times':shuffle_times.get(table,0),'hyper times':hyper_times.get(table,0),'join cost':join_cost_dict[table]}
        
    query_cost_dict['avg']=sum(list(query_cost_dict.values()))/len(query_cost_dict.keys())
    query_ratio_dict['avg']=sum(list(query_ratio_dict.values()))/len(query_ratio_dict.keys())
    final_join_result['avg']={'shuffle times':sum(list(shuffle_times.values()))/len(shuffle_times.keys()) if shuffle_times else 0,'hyper times':sum(list(hyper_times.values()))/len(hyper_times.keys()) if hyper_times else 0,'join cost':sum(list(join_cost_dict.values()))/len(join_cost_dict.keys()) if join_cost_dict else 0}
    
    opt_time=[tot_order_time,sum(hyper_group_time)/len(hyper_group_time) if len(hyper_group_time)>0 else 0,sum(shuffle_group_time)/len(shuffle_group_time) if len(shuffle_group_time)>0 else 0]
    return final_join_result,query_cost_dict,query_ratio_dict,opt_time


def select_columns_by_PAW(group_type=0, strategy='traditional'):
    """Evaluate the PAW baseline (one qd-tree per table, no join-key trees).

    Builds the benchmark hypergraph, sums the qd-tree build times as the
    optimization cost, runs the join workload with tree_type='paw', prints
    the summary, and pickles the hypergraph under the layouts directory.

    Returns:
        (final_result, query_costs, query_ratio, total optimization time).
    """
    graph_classes = {
        'tpch': HypergraphTPC,
        'imdb': JobHypergraph,
        'tpcds': HypergraphTPCDS,
    }
    graph_cls = graph_classes.get(settings.benchmark)
    hypergraph = graph_cls() if graph_cls is not None else None

    opt_times = []
    tree_build_time = 0
    for tbl in hypergraph.hyper_nodes:
        tree_build_time += context.qdTreer.candidate_trees[tbl].build_time
    opt_times.append(tree_build_time)

    final_result, query_costs, query_ratio, group_opt_time = processing_join_workload(
        hypergraph, group_type, strategy, tree_type='paw')

    print("Final Column Selection Cost of PAW is: ", sum(query_costs.values()) / len(query_costs))
    print("Average Data Cost Ratio of PAW is: ", sum(query_ratio.values()) / len(query_ratio))
    print("Total Opt Time of PAW is: ", opt_times)
    hypergraph.show_selections()

    save_path = f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/join_key_selection/paw-hgraph.pkl'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, 'wb') as fh:
        pickle.dump(hypergraph, fh)

    return final_result, query_costs, query_ratio, sum(opt_times)
    

def select_columns_by_AD_MTO(group_type=0, strategy='traditional'):
    """Evaluate the AD-MTO baseline.

    Each table is partitioned on its first used column, provided that column
    is one of the table's join-key candidates. Runs the workload with
    tree_type='mto', prints the summary, and pickles the hypergraph.

    Returns:
        (final_result, query_costs, query_ratio, total optimization time).
    """
    graph_classes = {
        'tpch': HypergraphTPC,
        'imdb': JobHypergraph,
        'tpcds': HypergraphTPCDS,
    }
    graph_cls = graph_classes.get(settings.benchmark)
    hypergraph = graph_cls() if graph_cls is not None else None

    opt_times = []
    tree_build_time = 0
    for tbl in hypergraph.hyper_nodes:
        first_col = context.metadata[tbl]['used_cols'][0]
        if first_col in hypergraph.candidate_nodes[tbl]:
            hypergraph.hyper_nodes[tbl].append(first_col)
            tree_build_time += context.joinTreer.candidate_trees[tbl][first_col].build_time
    opt_times.append(tree_build_time)

    final_result, query_costs, query_ratio, group_opt_time = processing_join_workload(
        hypergraph, group_type, strategy, tree_type='mto')

    print("Final Column Selection Cost of AD-MTO is: ", sum(query_costs.values()) / len(query_costs))
    print("Average Data Cost Ratio of AD-MTO is: ", sum(query_ratio.values()) / len(query_ratio))
    print("Total Opt Time of AD-MTO is: ", opt_times)
    hypergraph.show_selections()

    save_path = f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/join_key_selection/ad-mto-hgraph.pkl'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, 'wb') as fh:
        pickle.dump(hypergraph, fh)

    return final_result, query_costs, query_ratio, sum(opt_times)

def select_columns_by_JT(group_type,strategy='traditional',disabled_joinkey_selection=False):
    """Select join columns for PAC-Tree (JT) and evaluate the workload.

    Builds a join-key adjacency matrix from the PAC query set, then ranks
    each table's candidate join columns by join frequency weighted by the
    partner table's row count and keeps the top `settings.replication_factor`
    columns. Runs the workload with tree_type='pac' and pickles the hypergraph.

    Args:
        group_type: grouping strategy forwarded to the workload evaluator.
        strategy: join-order strategy ('traditional' or prim-based).
        disabled_joinkey_selection: ablation switch — when True, skip the
            frequency-based selection and use each table's first used column.

    Returns:
        (final_result, query_costs, query_ratio, total optimization time).
    """
    # 1. Count how often each pair of `table.column` join keys appears together.
    adj_matrix = {}
    for query in context.pac_queries:
        for join_op in query["join_relations"]:
            join_relations = set()
            for table, col in join_op.items():
                join_relations.add(f'{table}.{context.metadata[table]["used_cols"][col]}')
            for key1, key2 in itertools.combinations(join_relations, 2):
                adj_matrix.setdefault(key1, {})
                adj_matrix.setdefault(key2, {})
                adj_matrix[key1][key2] = adj_matrix[key1].get(key2, 0) + 1
                adj_matrix[key2][key1] = adj_matrix[key2].get(key1, 0) + 1

    # 2. Build the benchmark-specific hypergraph.
    hypergraph = None
    if settings.benchmark == 'tpch':
        hypergraph = HypergraphTPC()
    elif settings.benchmark == 'imdb':
        hypergraph = JobHypergraph()
    elif settings.benchmark == 'tpcds':
        hypergraph = HypergraphTPCDS()

    total_opt_time = []
    tree_opt_time = 0
    if disabled_joinkey_selection:
        # Ablation: fall back to the first used column per table (same as AD-MTO).
        hypergraph.clear_hyper_nodes()
        for table in hypergraph.hyper_nodes:
            col = context.metadata[table]['used_cols'][0]
            if col in hypergraph.candidate_nodes[table]:
                hypergraph.hyper_nodes[table].append(col)
                tree_opt_time += context.joinTreer.candidate_trees[table][col].build_time
    else:
        # 3. Score every candidate column and keep the top-k per table.
        for table in hypergraph.candidate_nodes:
            col_freq_dict = {}
            for source_col in hypergraph.candidate_nodes[table]:
                # BUGFIX: accumulate across ALL adjacent join keys; the old code
                # overwrote tot_freq each iteration, so only the last partner
                # key's weight was kept.
                tot_freq = 0
                for key, freq in adj_matrix.get(f'{table}.{source_col}', {}).items():
                    target_table = key.split('.')[0]
                    tot_freq += freq * context.table_metadata[target_table]['rows']
                col_freq_dict[source_col] = tot_freq

            col_list = sorted(col_freq_dict.items(), key=lambda x: x[1], reverse=True)
            # Keep the k highest-scoring columns (k = replication factor).
            for col, _score in col_list[:settings.replication_factor]:
                hypergraph.hyper_nodes[table].append(col)

        for table in hypergraph.hyper_nodes:
            for col in hypergraph.hyper_nodes[table]:
                tree_opt_time += context.joinTreer.candidate_trees[table][col].build_time

    total_opt_time.append(tree_opt_time)
    final_result, query_costs, query_ratio, group_opt_time = processing_join_workload(hypergraph, group_type, strategy, tree_type='pac')
    print("Final Column Selection Cost of PAC-Tree is: ",sum(list(query_costs.values()))/len(query_costs.keys()))
    print("Average Data Cost Ratio of PAC-Tree is: ",sum(list(query_ratio.values()))/len(query_ratio.keys()))
    print("Total Opt Time (tree,order,group) of PAC-Tree is: ",total_opt_time)
    hypergraph.show_selections()
    hypergraph_save_path=f'{base_dir}/../layouts/bs={settings.block_size}/{hypergraph.benchmark}/join_key_selection/jt-hgraph.pkl'
    os.makedirs(os.path.dirname(hypergraph_save_path), exist_ok=True)
    with open(hypergraph_save_path,'wb') as f:
        pickle.dump(hypergraph,f)
    return final_result, query_costs, query_ratio, sum(total_opt_time)

def save_experiment_result(experiment_result,disabled_prim_reorder):
    """Persist per-query, per-table, and timing results to an Excel workbook.

    Args:
        experiment_result: mapping model name -> (table_info, query_info,
            query_ratio_info, opt_time) as produced by the select_columns_* runs.
        disabled_prim_reorder: currently unused here; kept so callers that
            pass it keep working.

    Writes {result_dir}/{benchmark}_results.xlsx with three sheets:
    Query_Cost, Table_Cost, and Time_Cost.
    """
    # NOTE: the old body re-imported pandas/os locally (both are module-level
    # imports already) and called os.makedirs(result_dir) twice; the module
    # level `base_dir` is reused instead of recomputing it.
    result_dir = f"{base_dir}/../experiment/result/sf={settings.scale_factor}/bs={settings.block_size}/rep={settings.replication_factor}/"
    os.makedirs(result_dir, exist_ok=True)

    # Flatten the nested result structure into row lists for DataFrames.
    query_list = []
    table_list = []
    time_overhead_list = []
    for model_name, (table_info, query_info, query_ratio_info, opt_time) in experiment_result.items():
        for q_id, cost_val in query_info.items():
            query_list.append([model_name, q_id, cost_val, query_ratio_info[q_id]])
        for table, cost_dict in table_info.items():
            shuffle_times = cost_dict.get('shuffle times', 0)
            hyper_times = cost_dict.get('hyper times', 0)
            join_cost = cost_dict.get('join cost', 0)
            table_list.append([model_name, table, shuffle_times, hyper_times, join_cost])
        time_overhead_list.append([model_name, opt_time])

    df_query = pd.DataFrame(query_list, columns=["Model", "Query ID", "Join Cost", "Scan Ratio"])
    df_table = pd.DataFrame(table_list, columns=["Model", "Table", "Shuffle Times", "Hyper Times", "Join Cost"])
    df_time = pd.DataFrame(time_overhead_list, columns=["Model", "Opt Time"])

    # One column per model for query-level cost and scan ratio.
    pivot_query_cost = df_query.pivot_table(index="Query ID", columns="Model", values="Join Cost", aggfunc="first").reset_index()
    pivot_ratio = df_query.pivot_table(index="Query ID", columns="Model", values="Scan Ratio", aggfunc="first").reset_index()
    pivot_query_cost.rename(columns={
        "PAW": "PAW_Join Cost",
        "ADP": "AD-MTO_Join Cost",
        "JT": "PAC-Tree_Join Cost"
    }, inplace=True)
    pivot_ratio.rename(columns={
        "PAW": "PAW_Scan Ratio",
        "ADP": "AD-MTO_Scan Ratio",
        "JT": "PAC-Tree_Scan Ratio"
    }, inplace=True)
    merged_query_df = pd.merge(pivot_query_cost, pivot_ratio, on="Query ID")

    # One column per model for table-level shuffle/hyper counts and join cost.
    pivot_shuffle = df_table.pivot_table(index="Table", columns="Model", values="Shuffle Times", aggfunc="first").reset_index()
    pivot_hyper = df_table.pivot_table(index="Table", columns="Model", values="Hyper Times", aggfunc="first").reset_index()
    pivot_cost = df_table.pivot_table(index="Table", columns="Model", values="Join Cost", aggfunc="first").reset_index()
    pivot_shuffle.rename(
        columns={
            "PAW": "PAW_shuffle_times",
            "ADP": "AD-MTO_shuffle_times",
            "JT": "PAC-Tree_shuffle_times",
        },
        inplace=True
    )
    pivot_hyper.rename(
        columns={
            "PAW": "PAW_hyper_times",
            "ADP": "AD-MTO_hyper_times",
            "JT": "PAC-Tree_hyper_times"
        },
        inplace=True
    )
    merged_df = pivot_shuffle.merge(pivot_hyper, on="Table")
    merged_df = merged_df.merge(pivot_cost, on="Table", suffixes=(None, "_cost"))
    merged_df.rename(
        columns={
            "PAW": "PAW_Scan Block",
            "ADP": "AD-MTO_Scan Block",
            "JT": "PAC-Tree_Scan Block",
        },
        inplace=True
    )

    res_saved_path = f"{result_dir}/{settings.benchmark}_results.xlsx"
    with pd.ExcelWriter(res_saved_path) as writer:
        merged_query_df.to_excel(writer, sheet_name="Query_Cost", index=False)
        merged_df.to_excel(writer, sheet_name="Table_Cost", index=False)
        df_time.to_excel(writer, sheet_name="Time_Cost", index=False)


def pretty_print_dubgger_query_join_info():
    """Render per-query cost logs (scan ratio, hyper part, shuffle part)
    collected in dubug_logger, one rich panel per benchmark."""
    from rich.console import Console
    from rich.table import Table
    from rich.panel import Panel

    console = Console()
    distributions = dubug_logger.query_cost_distributions
    for benchmark in list(distributions.keys()):
        grid = Table(show_header=True, header_style="bold cyan")
        grid.add_column("QueryID", style="bold magenta")

        # Three columns (ratio / hyper / shuffle) per method, sorted for stability.
        methods = sorted(distributions[benchmark].keys())
        for method in methods:
            color = get_method_style(method)
            grid.add_column(f"{method} Ratio", justify="right", style=color)
            grid.add_column(f"{method} Hyper", justify="right", style=color)
            grid.add_column(f"{method} Shuffle", justify="right", style=color)

        # Union of query ids across methods, printed in sorted order.
        qids = set()
        for method in methods:
            qids.update(distributions[benchmark][method].keys())
        for qid in sorted(qids):
            cells = [str(qid)]
            for method in methods:
                if qid in distributions[benchmark][method]:
                    ratio, hyper_part, shuffle_part = distributions[benchmark][method][qid]
                    cells.append(str(ratio))
                    cells.append(hyper_part)
                    cells.append(shuffle_part)
                else:
                    cells.extend(["N/A", "N/A", "N/A"])
            grid.add_row(*cells)

        console.print(Panel(grid,
                            title=f"[bold yellow]Query Cost Distributions for {benchmark}[/bold yellow]",
                            border_style="blue",
                            expand=False))
        console.print("\n")

def get_method_style(method):
    """Return the rich color style associated with *method* (case-insensitive);
    unknown methods render in white."""
    key = method.lower()
    color_map = {
        "paw": "green",
        "mto": "blue",
        "pac": "yellow",
        "adp": "cyan",
    }
    return color_map[key] if key in color_map else "white"


def pretty_print_dubgger_table_join_info():
    """Render per-table join-cost distributions for every method, all
    benchmarks stacked in a single rich table."""
    from rich.console import Console
    from rich.table import Table

    console = Console()
    grid = Table(show_header=True, header_style="bold cyan")
    grid.add_column("Benchmark", style="bold magenta")
    grid.add_column("Table", style="dim")

    distributions = dubug_logger.join_cost_distributions
    benchmark_names = list(distributions.keys())

    # One column per method, collected across all benchmarks (sorted order).
    methods = set()
    for bm in benchmark_names:
        methods.update(distributions[bm].keys())
    ordered_methods = sorted(methods)
    for method in ordered_methods:
        grid.add_column(method, justify="right", style=get_method_style(method))

    for bm in benchmark_names:
        bm_methods = distributions[bm].keys()
        tables_seen = set()
        for method in bm_methods:
            tables_seen.update(distributions[bm][method].keys())

        # Sort tables by the first method's cost, descending; 'avg' goes last.
        anchor = next(iter(bm_methods))
        ordered_tables = sorted(
            (t for t in tables_seen if t != "avg"),
            key=lambda t: distributions[bm][anchor].get(t, 0),
            reverse=True,
        )
        if "avg" in tables_seen:
            ordered_tables.append("avg")

        for pos, table_name in enumerate(ordered_tables):
            # Only the first row of a benchmark carries its name.
            cells = [bm if pos == 0 else "", table_name]
            for method in ordered_methods:
                if method in bm_methods:
                    cost = distributions[bm][method].get(table_name, "N/A")
                    cells.append(f"{cost:,}" if isinstance(cost, (int, float)) else cost)
                else:
                    cells.append("N/A")
            if table_name == "avg":
                grid.add_row(*cells, style="bold")
            else:
                grid.add_row(*cells)

    console.print("\n[bold yellow]Join Cost Distributions By Method[/bold yellow]\n")
    console.print(grid)

def get_method_style(method):
    """Return the rich color style for *method* (white when unrecognized).

    NOTE(review): this is a byte-for-byte behavioral duplicate of the
    get_method_style defined earlier in this module; being later, it
    shadows the first definition. Consider removing one of them.
    """
    if method.lower() == "paw":
        return "green"
    if method.lower() == "mto":
        return "blue"
    if method.lower() == "pac":
        return "yellow"
    if method.lower() == "adp":
        return "cyan"
    return "white"


def run_evaluation(save_result=False, disabled_prim_reorder=False):
    """Run the PAW, AD-MTO, and PAC-Tree (JT) evaluations end to end.

    Args:
        save_result: when True, write the collected results to Excel via
            save_experiment_result.
        disabled_prim_reorder: whether to disable the Prim-algorithm-based
            join reordering; if True, the traditional join order is used
            and no reordering happens.

    NOTE(review): each pipeline currently executes TWICE — first
    sequentially below (which also feeds the debug pretty-printers), then
    again inside the ThreadPoolExecutor, whose results overwrite the first
    pass in experiment_data. Confirm whether the duplicate run is
    intentional (e.g. timing) — the worker functions mutate shared module
    state (hypergraph pickles, dubug_logger), so the threaded pass is not
    obviously safe either.
    """
    load_tree_context()
    
    experiment_data = {}
    # logging.basicConfig(filename=f'{base_dir}/../logs/column_selection_experiment.log',
    #                     level=logging.INFO,
    #                     format='%(asctime)s - %(levelname)s - %(message)s')

    using_join_order_strategy='traditional' if disabled_prim_reorder else 'prim'
    def run_PAW():
        # PAW baseline: fixed 'traditional' order regardless of the flag.
        res_PAW, cost_PAW, ratio_PAW, total_opt_time = select_columns_by_PAW(strategy='traditional', group_type=0) 
        return [res_PAW, cost_PAW, ratio_PAW, total_opt_time]
    
    def run_AD_MTO():
        # AD-MTO baseline: also always 'traditional' order.
        res_ADP, cost_ADP, ratio_ADP, total_opt_time = select_columns_by_AD_MTO(strategy='traditional', group_type=0)
        return [res_ADP, cost_ADP, ratio_ADP, total_opt_time]
        
    def run_JT():
        # PAC-Tree: the only pipeline that honors the reorder strategy flag.
        res_JT, cost_JT,ratio_JT, total_opt_time = select_columns_by_JT(group_type=1, strategy=using_join_order_strategy, disabled_joinkey_selection=False) 
        return [res_JT, cost_JT, ratio_JT, total_opt_time]
    
    # First (sequential) pass — populates dubug_logger for the debug output.
    experiment_data["PAW"]=run_PAW()
    experiment_data["ADP"]=run_AD_MTO()
    experiment_data["JT"]=run_JT()
    
    if args.mode=='debug':
        pretty_print_dubgger_table_join_info()
        pretty_print_dubgger_query_join_info()
    
    # exit(-1)
    # Second (threaded) pass — results overwrite the sequential pass above.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_paw = executor.submit(run_PAW)
        future_adp = executor.submit(run_AD_MTO)
        future_jt  = executor.submit(run_JT)

        result_paw = future_paw.result()
        result_adp = future_adp.result()
        result_jt  = future_jt.result()
        
        experiment_data["PAW"]=result_paw
        experiment_data["ADP"]=result_adp
        experiment_data["JT"]=result_jt
    print(experiment_data)
    
    if save_result:
        save_experiment_result(experiment_data, disabled_prim_reorder) 
        

def scaling_block_size():
    """Sensitivity analysis: rerun the evaluation across several block sizes."""
    for block_size in (1000, 5000, 20000, 50000):
        settings.block_size = block_size
        run_evaluation(save_result=False, disabled_prim_reorder=True)


def scaling_replication():
    """Sensitivity analysis: rerun the evaluation for replication factors 1 through 4."""
    for factor in (1, 2, 3, 4):
        settings.replication_factor = factor
        run_evaluation(save_result=False, disabled_prim_reorder=False)
        

def test_model_overhead_by_scaling_scale():
    """Measure model overhead while scaling the dataset scale factor."""
    for scale in (5, 10, 20, 50):
        settings.scale_factor = scale
        run_evaluation(save_result=False, disabled_prim_reorder=False)


# Entry point: dispatch on the --command argument.
def _run_basic_experiments():
    # Basic experiments with the default replication factor.
    settings.replication_factor = 3
    run_evaluation(save_result=False, disabled_prim_reorder=False)

_COMMAND_HANDLERS = {
    0: _run_basic_experiments,                # basic experiments
    1: scaling_block_size,                    # block-size sensitivity analysis
    2: scaling_replication,                   # replication-factor sensitivity
    3: test_model_overhead_by_scaling_scale,  # scale-factor overhead test
}

_handler = _COMMAND_HANDLERS.get(args.command)
if _handler is not None:
    _handler()
else:
    print("Invalid command")