# PAC-tree / db / query_tooler.py
import re
from connector import Connector
import argparse
import os
import copy
import pickle
from collections import defaultdict
import random

# Command-line interface: which benchmark to process and which pipeline
# steps (formatting, MTO/PAW export, PAC export) to run.
parser = argparse.ArgumentParser()
parser.add_argument('--benchmark', type=str, default='tpch', help='The benchmark to use (tpch, imdb, or join-order)')
parser.add_argument('--format', action='store_true', help='Whether to format queries')
parser.add_argument('--export_mto_paw_queries', action='store_true', help='Whether to export paw and mto queries')
# The help text previously said "mto queries", a copy-paste slip from the flag above.
parser.add_argument('--export_pac_queries', action='store_true', help='Whether to export pac queries')


# Parse the command line once; the selected benchmark decides which
# queryset/dataset sub-directories (relative to this file) are used.
args = parser.parse_args()
benchmark = args.benchmark

current_path = os.path.dirname(os.path.abspath(__file__))
query_base_dir = f'{current_path}/../queryset/{benchmark}'
dataset_base_dir = f'{current_path}/../dataset/{benchmark}'



# A table/alias qualifier such as "lineitem." in "lineitem.l_discount".  It
# must start like an identifier and be followed by one, so the integer part of
# a decimal literal ("0." in "0.05") is never mistaken for a qualifier.
_QUALIFIER_RE = re.compile(r'\b[A-Za-z_]\w*\.(?=[A-Za-z_"])')
# Parameter placeholders such as $0 that cannot be re-executed stand-alone.
_PLACEHOLDER_RE = re.compile(r'\$\d+')
# "(SubPlan n) = something" fragments inside join conditions.
_SUBPLAN_EQ_RE = re.compile(r'\(SubPlan\s+\d+\)\s*=\s*\S+')
# "[table.]col = [table.]col" equalities.
_COLUMN_EQ_RE = re.compile(r'((?:\w+\.)?\w+)\s*=\s*((?:\w+\.)?\w+)')
# One statement of the seed file.  The trailing newline is only looked at, not
# consumed, so a statement that starts on the very next line is still found.
_SEED_QUERY_RE = re.compile(
    r'(?:^|\n)(?:select|create view).*?;(?=\n|$)', re.DOTALL | re.IGNORECASE
)


def _load_seed_queries(base_dir):
    """Return the raw SQL of every seed query (seed1.txt, else seed1/*.sql)."""
    seed_file = f'{base_dir}/seed1.txt'
    seed_dir = f'{base_dir}/seed1'
    if os.path.exists(seed_file):
        with open(seed_file, 'r') as fh:
            return _SEED_QUERY_RE.findall(fh.read())
    queries = []
    if os.path.exists(seed_dir):
        # Sorted so that the "Query <i>" numbering is reproducible.
        for name in sorted(os.listdir(seed_dir)):
            if name.endswith('.sql'):
                with open(f'{seed_dir}/{name}', 'r') as fh:
                    queries.append(fh.read())
    return queries


def _table_columns(schema_details, table):
    """Return the set of numeric and text column names recorded for *table*."""
    info = schema_details[table]
    return set(info.get('numeric_columns', [])) | set(info.get('text_columns', []))


def _is_usable_predicate(pred):
    """Return True for a non-blank predicate without parameter placeholders."""
    return bool(pred.strip()) and not _PLACEHOLDER_RE.search(pred)


def _is_column_equality(pred, known_columns):
    """Return True when *pred* is 'col = col' with both sides known columns."""
    compact = pred.strip().replace('(', '').replace(')', '').replace(' ', '')
    sides = compact.split('=')
    return len(sides) == 2 and sides[0] in known_columns and sides[1] in known_columns


def _collect_scans(plan_node, tables_set, aliases_map, table_predicates, known_columns):
    """Walk the plan, recording scanned tables, their aliases and scan filters."""
    table = plan_node.get('Relation Name')
    if table:
        tables_set.add(table)
        alias = plan_node.get('Alias')
        if alias:
            aliases_map[alias.lower()] = table
        filter_text = plan_node.get('Filter')
        # Internal predicates like '(NOT (hashed SubPlan 1))' are ignored.
        if filter_text and 'SubPlan' not in filter_text:
            pred = _QUALIFIER_RE.sub('', filter_text)
            # Column-to-column equalities are join conditions, not filters.
            if _is_usable_predicate(pred) and not _is_column_equality(pred, known_columns):
                table_predicates[table].append(pred)
    for child in plan_node.get('Plans', []):
        _collect_scans(child, tables_set, aliases_map, table_predicates, known_columns)


def _strip_alias_prefixes(table_predicates, aliases_map):
    """Remove each table's alias prefixes from its predicates, in place."""
    for table in list(table_predicates):
        aliases = [alias for alias, tbl in aliases_map.items() if tbl == table]
        cleaned = []
        for pred in table_predicates[table]:
            for alias in aliases:
                pred = re.sub(r'\b{}\.'.format(re.escape(alias)), '', pred)
            if _is_usable_predicate(pred):
                cleaned.append(pred)
        table_predicates[table] = cleaned


def _is_valid_column(col_name, table_name, schema_details):
    """Return True when the column exists in *table_name* (or in any table)."""
    if not schema_details:
        return True  # No schema information: nothing to validate against.
    if table_name:
        return (table_name in schema_details
                and col_name in _table_columns(schema_details, table_name))
    return any(col_name in _table_columns(schema_details, tbl) for tbl in schema_details)


def _find_table_for_column(column, tables_set, schema_details):
    """Return the first (alphabetically) scanned table owning *column*, or None."""
    for table in sorted(tables_set):
        if table in schema_details and column in _table_columns(schema_details, table):
            return table
    return None


def _split_column_ref(ref, aliases_map):
    """Split '[qualifier.]col' into (table or None, col), resolving aliases."""
    qualifier, _, column = ref.rpartition('.')
    if not qualifier:
        return None, column
    return aliases_map.get(qualifier.lower(), qualifier), column


def _parse_join_condition(cond, tables_set, aliases_map, schema_details):
    """Extract join keys from a plan condition.

    Turns e.g. '(part.p_partkey = partsupp.ps_partkey)' into
    ['part.p_partkey=partsupp.ps_partkey'].  '(SubPlan n) = ...' fragments and
    equalities whose sides are not schema columns are ignored.
    """
    if not cond:
        return []
    text = _SUBPLAN_EQ_RE.sub('', cond).strip('() ')
    join_keys = []
    for left, right in _COLUMN_EQ_RE.findall(text):
        # Aliases are resolved *before* validation; validating the raw alias
        # against the schema would reject every aliased join.
        left_table, left_col = _split_column_ref(left, aliases_map)
        right_table, right_col = _split_column_ref(right, aliases_map)
        if not (_is_valid_column(left_col, left_table, schema_details)
                and _is_valid_column(right_col, right_table, schema_details)):
            continue
        if left_table is None and right_table is None:
            continue  # Nothing to anchor the join on.
        if left_table is None:
            left_table = _find_table_for_column(left_col, tables_set, schema_details)
        elif right_table is None:
            right_table = _find_table_for_column(right_col, tables_set, schema_details)
        if left_table and right_table:
            join_keys.append(f"{left_table}.{left_col}={right_table}.{right_col}")
    return join_keys


def _collect_joins(plan_node, join_info, table_predicates, tables_set, aliases_map,
                   schema_details):
    """Walk the plan, recording join keys and unqualified index conditions."""
    node_type = plan_node.get('Node Type', '')
    conditions = []
    if 'Join' in node_type:
        conditions.append(plan_node.get('Hash Cond') or plan_node.get('Merge Cond'))
    # The join filter is parsed exactly once, whatever the node type.
    conditions.append(plan_node.get('Join Filter'))
    for cond in conditions:
        for key in _parse_join_condition(cond, tables_set, aliases_map, schema_details):
            join_info.append({'join_type': node_type, 'join_keys': [key]})

    index_cond = plan_node.get('Index Cond')
    if index_cond:
        # Bitmap Index Scan nodes carry an 'Index Cond' but no 'Relation Name'.
        relation = plan_node.get('Relation Name')
        for part in index_cond.split(' AND '):
            part = part.replace('((', '(').replace('))', ')')
            if _QUALIFIER_RE.search(part):
                # References another relation: treat it as a join condition.
                keys = _parse_join_condition(part, tables_set, aliases_map, schema_details)
                if keys:
                    join_info.append({'join_type': 'Index Cond', 'join_keys': keys})
            elif relation and _is_usable_predicate(part):
                table_predicates[relation].append(part)

    for child in plan_node.get('Plans', []):
        _collect_joins(child, join_info, table_predicates, tables_set, aliases_map,
                       schema_details)


def _query_column_ranges(connector, table, numeric_cols, predicates):
    """Run the MIN/MAX query for *table*; return {col: {'min', 'max'}} or None."""
    min_max_selects = [f"MIN({col}) as min_{col}, MAX({col}) as max_{col}" for col in numeric_cols]
    where_clause = ' AND '.join(predicates)
    min_max_query = f"SELECT {', '.join(min_max_selects)} FROM {table} WHERE {where_clause};"
    print(f"Executing min/max query on table '{table}':\n{min_max_query}\n")
    try:
        rows = connector.execute_query(min_max_query)
    except Exception as e:  # Best effort: a failing table must not abort the run.
        print(f"Error executing min/max query: {e} !!!!!!")
        return None
    if not rows:
        return None
    row = rows[0]
    col_ranges = {}
    for idx, col in enumerate(numeric_cols):
        min_val, max_val = row[idx * 2], row[idx * 2 + 1]
        # Explicit None test: a minimum of 0 is a perfectly valid bound.
        if min_val is not None and max_val is not None:
            col_ranges[col] = {'min': min_val, 'max': max_val}
    return col_ranges or None


def formatting_query_seed():
    """Derive per-query column ranges and join keys from the seed queries.

    For every seed query the PostgreSQL plan is obtained with
    EXPLAIN (FORMAT JSON).  Scan filters are re-executed per table as MIN/MAX
    queries over the numeric columns, and join conditions are collected as
    'table.col=table.col' keys.  The result is written to
    formatted_queries.pkl and formatted_queries.txt in the query directory.
    """
    connector = Connector(benchmark)
    connector.connect()
    try:
        # NOTE: pickle must only be used on trusted, locally generated files.
        with open(f'{dataset_base_dir}/metadata.pkl', 'rb') as f:
            schema_details = pickle.load(f)
        known_columns = set()
        for table in schema_details:
            known_columns |= _table_columns(schema_details, table)

        query_column_ranges = {}
        for i, query in enumerate(_load_seed_queries(query_base_dir), 1):
            # Remove comments and skip empty queries.
            query_clean = re.sub(r'--.*', '', query).strip()
            if not query_clean:
                continue

            print(f"\nProcessing Query {i}:\n{query_clean}\n")

            # The result is a list of tuples whose first cell is the JSON plan.
            explain_result = connector.execute_query(f"EXPLAIN (FORMAT JSON) {query_clean}")
            plan = explain_result[0][0][0]['Plan']

            tables_set = set()
            aliases_map = {}
            table_predicates = defaultdict(list)
            _collect_scans(plan, tables_set, aliases_map, table_predicates, known_columns)
            _strip_alias_prefixes(table_predicates, aliases_map)

            join_info = []
            _collect_joins(plan, join_info, table_predicates, tables_set, aliases_map,
                           schema_details)

            # MIN/MAX of every numeric column under the table's own predicates.
            ranges = {}
            for table in sorted(tables_set):
                if table not in schema_details:
                    continue
                numeric_cols = schema_details[table]['numeric_columns']
                predicates = table_predicates.get(table)
                if not numeric_cols or not predicates:
                    continue
                col_ranges = _query_column_ranges(connector, table, numeric_cols, predicates)
                if col_ranges:
                    ranges[table] = col_ranges

            if not ranges and not join_info:
                continue
            query_column_ranges[f'Query {i}'] = {
                'ranges': ranges,
                'join_info': join_info
            }

        # Save the results as a pickle plus a human-readable text copy.
        with open(f'{query_base_dir}/formatted_queries.pkl', 'wb') as f:
            pickle.dump(query_column_ranges, f)
        with open(f'{query_base_dir}/formatted_queries.txt', 'w') as f:
            for query_name, entry in query_column_ranges.items():
                f.write(f"{query_name}:\n")
                for table, cols in entry['ranges'].items():
                    f.write(f"  Table: {table}\n")
                    for col, rng in cols.items():
                        f.write(f"    {col}: min={rng['min']}, max={rng['max']}\n")
                for join in entry['join_info']:
                    f.write(f"  Join Type: {join['join_type']}\n")
                    f.write(f"  Join Keys: {join['join_keys']}\n")
                f.write("\n")
    finally:
        # Release the database connections even when a step above fails.
        connector.close_all_connection()


def _adopt_join_range(ranges, schema_details, table, col, source_range):
    """Give table.col a private copy of *source_range*.

    A table not yet present in *ranges* is first created; its remaining
    numeric columns are filled with copies of the schema-wide bounds.
    """
    if table not in ranges:
        ranges[table] = {col: dict(source_range)}
        for other in schema_details[table]['numeric_columns']:
            if other != col:
                # Copy: sharing the schema dict would let later tightening
                # silently rewrite the schema bounds themselves.
                ranges[table][other] = dict(schema_details[table]['ranges'][other])
    else:
        ranges[table][col] = dict(source_range)


def _propagate_join_range(ranges, schema_details, left_table, left_col, right_table, right_col):
    """Apply one equi-join to *ranges*: intersect, or copy the known side over."""
    left_known = left_col in ranges.get(left_table, {})
    right_known = right_col in ranges.get(right_table, {})
    if left_known and right_known:
        left_rng = ranges[left_table][left_col]
        right_rng = ranges[right_table][right_col]
        low = max(left_rng['min'], right_rng['min'])
        high = min(left_rng['max'], right_rng['max'])
        ranges[left_table][left_col] = {'min': low, 'max': high}
        ranges[right_table][right_col] = {'min': low, 'max': high}
    elif left_known:
        _adopt_join_range(ranges, schema_details, right_table, right_col,
                          ranges[left_table][left_col])
    elif right_known:
        _adopt_join_range(ranges, schema_details, left_table, left_col,
                          ranges[right_table][right_col])


def validate_queries(method, query_dir=None, dataset_dir=None):
    """Post-process the formatted queries and export them for *method*.

    Reads formatted_queries.pkl and metadata.pkl, then for every query:
      1. propagates column ranges across its equi-joins (join-induced filters),
      2. keeps one join per unordered table pair and drops self-equalities
         (the kept joins are shuffled when method == "mto"),
      3. drops column ranges equal to the schema-wide bounds.
    The result is written to '<method>_queries.pkl' in the query directory.

    Args:
        method: Name used for the output file; "mto" additionally shuffles joins.
        query_dir: Directory of the query pickles; defaults to query_base_dir.
        dataset_dir: Directory of metadata.pkl; defaults to dataset_base_dir.
    """
    query_dir = query_base_dir if query_dir is None else query_dir
    dataset_dir = dataset_base_dir if dataset_dir is None else dataset_dir

    with open(f'{query_dir}/formatted_queries.pkl', 'rb') as f:
        saved_queries = pickle.load(f)
    with open(f'{dataset_dir}/metadata.pkl', 'rb') as f:
        schema_details = pickle.load(f)

    for query_item in saved_queries.values():
        ranges = query_item['ranges']

        # Step 1: join-induced filters.  Several passes so that a range can
        # travel along a chain of joins regardless of their listing order.
        for _ in range(3):
            for join_item in query_item['join_info']:
                left_table, left_col, right_table, right_col = (
                    join_item['join_keys'][0].replace('=', '.').split('.')
                )
                _propagate_join_range(ranges, schema_details,
                                      left_table, left_col, right_table, right_col)

        # Step 2: one join per unordered table pair.
        new_join_items = []
        seen_pairs = set()
        for join_item in query_item['join_info']:
            left_table_col, right_table_col = join_item['join_keys'][0].split('=')
            if left_table_col == right_table_col:
                continue
            left_table, _ = left_table_col.split('.')
            right_table, _ = right_table_col.split('.')
            pair = frozenset((left_table, right_table))
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)
            new_join_items.append(join_item)
        if method == "mto":
            random.shuffle(new_join_items)
        query_item['join_info'] = new_join_items

        # Step 3: keep only ranges that actually restrict the column.
        for table, col_ranges in list(ranges.items()):
            bounds = schema_details[table]['ranges']
            ranges[table] = {
                col: rng for col, rng in col_ranges.items()
                if rng['min'] != bounds[col]['min'] or rng['max'] != bounds[col]['max']
            }

    with open(f'{query_dir}/{method}_queries.pkl', 'wb') as f:
        pickle.dump(saved_queries, f)

def paw_queries(query_dir=None):
    """Export the PAW query set.

    Each PAW query combines the column ranges from formatted_queries.pkl with
    the join list of the same query in mto_queries.pkl; the result is written
    to paw_queries.pkl.  mto_queries.pkl must therefore already exist.

    Args:
        query_dir: Directory holding the query pickles; defaults to the
            module-level query_base_dir.
    """
    query_dir = query_base_dir if query_dir is None else query_dir

    with open(f'{query_dir}/formatted_queries.pkl', 'rb') as f:
        formatted_queries = pickle.load(f)
    with open(f'{query_dir}/mto_queries.pkl', 'rb') as f:
        verified_queries = pickle.load(f)

    exported = {
        query_name: {
            'ranges': query_item['ranges'],
            'join_info': verified_queries[query_name]['join_info'],
        }
        for query_name, query_item in formatted_queries.items()
    }

    with open(f'{query_dir}/paw_queries.pkl', 'wb') as f:
        pickle.dump(exported, f)

# Run the whole pipeline when no step flag is given (the behavior of a plain
# invocation); otherwise honor the flags instead of overriding them.
if not (args.format or args.export_mto_paw_queries or args.export_pac_queries):
    args.format = True
    args.export_mto_paw_queries = True
    args.export_pac_queries = True

if args.format:
    formatting_query_seed()

if args.export_mto_paw_queries:
    # This function not only validates but also processes queries, such as handling join-induced filters
    # (which is crucial for subsequent join order algorithms. Traditional join-induced predicates help accurately
    # estimate potential join costs for each table and build better join trees)
    validate_queries(method="mto")
    # paw_queries() reads the mto_queries.pkl written by the call above.
    paw_queries()

if args.export_pac_queries:
    validate_queries(method="pac")