import re
from connector import Connector
import argparse
import os
import copy
import pickle
from collections import defaultdict
import random
# Command-line interface: choose the benchmark and which export stages to run.
parser = argparse.ArgumentParser()
parser.add_argument('--benchmark', type=str, default='tpch', help='The benchmark to use (tpch, imdb, or join-order)')
parser.add_argument('--format', action='store_true', help='Whether to format queries')
parser.add_argument('--export_mto_paw_queries', action='store_true', help='Whether to export paw and mto queries')
parser.add_argument('--export_pac_queries', action='store_true', help='Whether to export mto queries')
args = parser.parse_args()

# Benchmark-specific data locations, resolved relative to this file.
current_path = os.path.dirname(os.path.abspath(__file__))
benchmark = args.benchmark
query_base_dir = f'{current_path}/../queryset/{benchmark}'
dataset_base_dir = f'{current_path}/../dataset/{benchmark}'
def formatting_query_seed():
    """Parse the benchmark's seed queries, EXPLAIN each one against the
    database, and record per query (a) the min/max of every numeric column of
    each scanned table under the query's own filter predicates, and (b) the
    equi-join structure extracted from the plan.

    Writes {query_base_dir}/formatted_queries.pkl mapping 'Query <i>' ->
    {'ranges': {table: {col: {'min', 'max'}}}, 'join_info': [...]}, plus a
    human-readable formatted_queries.txt copy.
    """
    # Part 2: For each query, get scanned rows and calculate accessed column ranges
    connector = Connector(benchmark)
    connector.connect()
    # Schema metadata; used below as table -> {'numeric_columns', 'text_columns', ...}.
    schema_details=pickle.load(open(f'{dataset_base_dir}/metadata.pkl', 'rb'))
    queries,query_column_ranges = [],{}
    if os.path.exists(f'{query_base_dir}/seed1.txt'):
        # Read the content of seed1.txt
        with open(f'{query_base_dir}/seed1.txt', 'r') as file:
            content = file.read()
        # Extract queries from the content (each statement starts with
        # SELECT/CREATE VIEW and ends with ';')
        pattern = r'(?:^|\n)(?:select|create view).*?;(?:\n|$)'
        queries = re.findall(pattern, content, re.DOTALL | re.IGNORECASE)
    elif os.path.exists(f'{query_base_dir}/seed1'):
        # Traverse all SQL files in the directory and store them in the queries variable
        for file in os.listdir(f'{query_base_dir}/seed1'):
            if file.endswith('.sql'):
                with open(f'{query_base_dir}/seed1/{file}', 'r') as f:
                    queries.append(f.read())
    for i, query in enumerate(queries, 1):
        # Remove comments
        query_clean = re.sub(r'--.*', '', query)
        query_clean = query_clean.strip()
        # Skip empty queries
        if not query_clean:
            continue
        # if i!=8:
        #     continue
        print(f"\nProcessing Query {i}:\n{query_clean}\n")
        # Use EXPLAIN (FORMAT JSON) to get the query plan
        explain_query = f"EXPLAIN (FORMAT JSON) {query_clean}"
        explain_result = connector.execute_query(explain_query)
        # The result is a list of tuples; get the first element
        plan_json = explain_result[0][0]
        # Parse the JSON plan (top-level 'Plan' node of the JSON document)
        plan = plan_json[0]['Plan']

        # Function to remove alias prefixes (any 'word.' prefix) from predicates
        def remove_alias_prefix(pred):
            return re.sub(r'\b\w+\.', '', pred)

        # Function to filter out predicates with parameter placeholders.
        # NOTE(review): never called in this function; superseded by
        # process_predicates below.
        def filter_predicates(preds):
            return [pred for pred in preds if not re.search(r'\$\d+', pred) and pred.strip()]

        # Function to replace aliases (strip 'alias.' prefixes) in predicates
        def replace_aliases(pred, aliases):
            for alias in aliases:
                pred = re.sub(r'\b{}\.'.format(re.escape(alias)), '', pred)
            return pred

        # Function to extract and process predicates: strips the table's alias
        # prefixes and drops predicates referencing plan parameters ($0, $1, ...).
        def process_predicates(table, preds, aliases):
            filtered_preds = []
            for pred in preds:
                pred = replace_aliases(pred, aliases)
                if not re.search(r'\$\d+', pred) and pred.strip():
                    filtered_preds.append(pred)
            return filtered_preds

        # True when pred is a column-to-column equality (e.g. 'a = b') whose
        # both sides are known column names in schema_details.
        def if_equal_pred(pred):
            pred=pred.strip().replace('(','').replace(')','').replace(' ','')
            if len(pred.split('='))!=2:
                return False
            col1,col2=pred.split('=') # Check if col1 and col2 are column names without '.'
            # Check if col1 and col2 are column names based on schema_details
            if_col1,if_col2=False,False
            for table in schema_details:
                for col in schema_details[table]['numeric_columns']+schema_details[table]['text_columns']:
                    if col1==col:
                        if_col1=True
                    if col2==col:
                        if_col2=True
            if if_col1 and if_col2:
                return True
            return False

        # Function to extract tables, aliases, and single-table filter
        # predicates from the plan (recursive over 'Plans' children).
        def extract_info(plan_node, tables_set, aliases_map, table_predicates):
            current_table = None
            current_alias = None
            # Get the relation name (table name) and alias if available
            if 'Relation Name' in plan_node:
                current_table = plan_node['Relation Name']
                tables_set.add(current_table)
            if 'Alias' in plan_node:
                current_alias = plan_node['Alias']
                aliases_map[current_alias.lower()] = current_table
            # Get the filter (predicate) if available
            if 'Filter' in plan_node and (current_table or current_alias):
                filter_text = plan_node['Filter']
                # Ignore internal predicates like '(NOT (hashed SubPlan 1))'
                if 'SubPlan' not in filter_text:
                    # Assign predicate to the current_table
                    if current_table:
                        # Remove alias prefix from predicates
                        pred = remove_alias_prefix(filter_text)
                        # Skip predicates containing parameter placeholders like $0 or two column condtions like a=b
                        if not re.search(r'\$\d+', pred) and not if_equal_pred(pred) and pred.strip():
                            table_predicates[current_table].append(pred)
            # Replace table aliases in predicates
            # (re-run on every visited node; idempotent once a predicate is clean)
            for table, preds in table_predicates.items():
                aliases = [alias for alias, tbl in aliases_map.items() if tbl == table]
                table_predicates[table] = process_predicates(table, preds, aliases)
            if 'Plans' in plan_node:
                for subplan in plan_node['Plans']:
                    extract_info(subplan, tables_set, aliases_map, table_predicates)

        def parse_join_condition(cond):
            """
            Extract specific columns from conditions like '(part.p_partkey = partsupp.ps_partkey)',
            ignoring content like (SubPlan x)=xxx.
            Returns a list of 'table.col=table.col' strings with aliases
            resolved through the enclosing aliases_map; unqualified columns are
            attributed to a table via find_table_for_column. Pairs where
            neither side can be attributed to a table are dropped.
            """
            import re
            join_keys = []
            if not cond:
                return join_keys
            # Remove '(SubPlan x) = ...' parts
            cond = re.sub(r'\(SubPlan\s+\d+\)\s*=\s*\S+', '', cond)
            # Remove outer parentheses and extra whitespace
            text = cond.strip('() ')
            # Find patterns like 'table.col = alias.col' or 'col = table.col'
            pattern = r'((?:\w+\.)?\w+)\s*=\s*((?:\w+\.)?\w+)'
            matches = re.findall(pattern, text)
            for left, right in matches:
                left_parts = left.split('.') if '.' in left else [None, left]
                right_parts = right.split('.') if '.' in right else [None, right]
                left_table = left_parts[0]
                left_col = left_parts[-1]
                right_table = right_parts[0]
                right_col = right_parts[-1]
                # Validate column names
                left_valid = is_valid_column(left_col, left_table, schema_details)
                right_valid = is_valid_column(right_col, right_table, schema_details)
                if not (left_valid and right_valid):
                    continue # Skip invalid columns
                # Handle case where both sides have table names
                if left_table and right_table:
                    # Handle alias mappings
                    if left_table in aliases_map:
                        left_table = aliases_map[left_table]
                    if right_table in aliases_map:
                        right_table = aliases_map[right_table]
                    join_keys.append(f"{left_table}.{left_col}={right_table}.{right_col}")
                # Handle case where left side has no table name
                elif not left_table and right_table:
                    # Handle alias mappings
                    if right_table in aliases_map:
                        right_table = aliases_map[right_table]
                    # Find table for left column
                    left_table = find_table_for_column(left_col, tables_set, schema_details)
                    if left_table:
                        join_keys.append(f"{left_table}.{left_col}={right_table}.{right_col}")
                # Handle case where right side has no table name
                elif left_table and not right_table:
                    # Handle alias mappings
                    if left_table in aliases_map:
                        left_table = aliases_map[left_table]
                    # Find table for right column
                    right_table = find_table_for_column(right_col, tables_set, schema_details)
                    if right_table:
                        join_keys.append(f"{left_table}.{left_col}={right_table}.{right_col}")
            return join_keys

        def is_valid_column(col_name, table_name=None, schema_details=None):
            """
            Validate if the column name exists in the schema.
            When table_name is given it must itself be a key of schema_details
            (so an unresolved alias makes the column invalid here).
            """
            if not schema_details:
                return True # If no schema information is provided, default to valid
            # If table name is provided, only check that table
            if table_name:
                if table_name in schema_details:
                    numeric_cols = schema_details[table_name].get('numeric_columns', [])
                    text_cols = schema_details[table_name].get('text_columns', [])
                    return col_name in numeric_cols or col_name in text_cols
                return False # Table does not exist in schema
            # If no table name is provided, check all tables
            for table in schema_details:
                numeric_cols = schema_details[table].get('numeric_columns', [])
                text_cols = schema_details[table].get('text_columns', [])
                if col_name in numeric_cols or col_name in text_cols:
                    return True
            return False # Column name does not exist in any table

        def find_table_for_column(column, tables_set, schema_details):
            """
            Find the table that owns the given column, searching only among
            the tables actually used by this query (tables_set).
            """
            candidate_tables = []
            for table in tables_set:
                if table in schema_details:
                    numeric_cols = schema_details[table].get('numeric_columns', [])
                    text_cols = schema_details[table].get('text_columns', [])
                    if column in numeric_cols or column in text_cols:
                        candidate_tables.append(table)
            # If only one table is found, return it
            if len(candidate_tables) == 1:
                return candidate_tables[0]
            # If multiple tables are found, select the most likely one based on context
            # Here we simply return the first one, but in practice more sophisticated logic might be needed
            elif len(candidate_tables) > 1:
                return candidate_tables[0]
            return None

        def extract_joins_from_plan(plan_node, join_info, visited_tables,table_predicates):
            """
            Recursively extract Join information, recording actually used tables and join_keys array.
            visited_tables is currently unused. Unqualified single-table
            'Index Cond' entries are recorded as extra filter predicates
            rather than joins.
            """
            node_type = plan_node.get('Node Type', '')
            if 'Join' in node_type:
                # Prefer Hash Cond, then Merge Cond, then Join Filter.
                cond = plan_node.get('Hash Cond') or plan_node.get('Merge Cond') or plan_node.get('Join Filter')
                join_keys = parse_join_condition(cond)
                if join_keys:
                    for key in join_keys:
                        join_info.append({
                            'join_type': node_type,
                            'join_keys': [key]
                        })
                # NOTE(review): if the node had no Hash/Merge Cond, cond above
                # already was the Join Filter, so its keys are appended twice.
                if 'Join Filter' in plan_node:
                    join_keys = parse_join_condition(plan_node['Join Filter'])
                    for key in join_keys:
                        join_info.append({
                            'join_type': node_type,
                            'join_keys': [key]
                        })
            if 'Index Cond' in plan_node:
                index_conds=plan_node['Index Cond'].split(' AND ')
                for index_cond in index_conds:
                    index_cond=index_cond.replace('((','(').replace('))',')')
                    if '.' not in index_cond:
                        # No table qualifier: treat as a plain filter predicate.
                        table_predicates[plan_node['Relation Name']].append(index_cond)
                    else:
                        join_keys = parse_join_condition(index_cond)
                        if join_keys:
                            join_info.append({
                                'join_type': 'Index Cond',
                                'join_keys': join_keys
                            })
            # Continue to child nodes
            if 'Plans' in plan_node:
                for child in plan_node['Plans']:
                    extract_joins_from_plan(child, join_info, visited_tables,table_predicates)

        # Sets to store tables, aliases, and a dictionary for table-specific predicates
        tables_set = set()
        aliases_map = {}
        table_predicates = defaultdict(list)
        # Extract information from the plan
        extract_info(plan, tables_set, aliases_map, table_predicates)
        table_list = list(tables_set)
        join_info = []
        visited = set()
        extract_joins_from_plan(plan, join_info, visited,table_predicates)
        # Build and execute SQL queries to get min and max for numeric columns
        ranges = {}
        for table in table_list:
            if table not in schema_details:
                continue
            numeric_cols = schema_details[table]['numeric_columns']
            if not numeric_cols:
                continue
            # Retrieve predicates specific to the current table
            if table not in table_predicates:
                continue
            predicates = table_predicates.get(table, [])
            # NOTE(review): if all predicates were filtered out, where_clause is
            # empty and the query below is malformed; the except branch catches it.
            where_clause = ' AND '.join(predicates)
            # Construct MIN and MAX queries for each numeric column
            min_max_selects = [f"MIN({col}) as min_{col}, MAX({col}) as max_{col}" for col in numeric_cols]
            min_max_query=f"SELECT {', '.join(min_max_selects)} FROM {table} WHERE {where_clause};"
            print(f"Executing min/max query on table '{table}':\n{min_max_query}\n")
            # Execute the query
            try:
                min_max_result = connector.execute_query(min_max_query)
                # NOTE(review): this truthiness guard also skips a table whose
                # first column's MIN is legitimately 0 — confirm intended.
                if min_max_result[0][0]:
                    col_ranges = {}
                    # Row layout is (min_c1, max_c1, min_c2, max_c2, ...).
                    result_row = min_max_result[0]
                    num_cols = len(numeric_cols)  # NOTE(review): unused
                    for idx, col in enumerate(numeric_cols):
                        min_val = result_row[idx * 2]
                        max_val = result_row[idx * 2 + 1]
                        if min_val is not None and max_val is not None:
                            col_ranges[col] = {'min': min_val, 'max': max_val}
                    ranges[table] = col_ranges
            except Exception as e:
                print(f"Error executing min/max query: {e} !!!!!!")
        # Queries contributing neither ranges nor joins are omitted entirely.
        if not ranges and not join_info:
            continue
        query_column_ranges[f'Query {i}'] = {
            'ranges': ranges,
            'join_info': join_info
        }
    # Save query_column_ranges results to .pkl file
    with open(f'{query_base_dir}/formatted_queries.pkl', 'wb') as f:
        pickle.dump(query_column_ranges, f)
    # Also write a copy to txt file
    with open(f'{query_base_dir}/formatted_queries.txt', 'w') as f:
        for query_name, tables in query_column_ranges.items():
            f.write(f"{query_name}:\n")
            for table, cols in tables['ranges'].items():
                f.write(f" Table: {table}\n")
                for col, rng in cols.items():
                    f.write(f" {col}: min={rng['min']}, max={rng['max']}\n")
            for join in tables['join_info']:
                f.write(f" Join Type: {join['join_type']}\n")
                f.write(f" Join Keys: {join['join_keys']}\n")
            f.write("\n")
    connector.close_all_connection()
def validate_queries(method):
    """Refine formatted queries and export them as {method}_queries.pkl.

    Three passes over each query from formatted_queries.pkl:
      1. Propagate column-range filters across equi-join keys ("join-induced
         predicates"); repeated 3 times so constraints can flow transitively
         through chains of joins.
      2. Deduplicate join items: trivial self-equalities and repeated
         table pairs are dropped; for method "mto" the join order is shuffled.
      3. Drop column ranges identical to the full table range recorded in the
         schema metadata (they carry no filtering information).

    Args:
        method: tag used in the output filename (e.g. "mto" or "pac");
            "mto" additionally shuffles the join order.
    """
    saved_queries=pickle.load(open(f'{query_base_dir}/formatted_queries.pkl', 'rb'))
    schema_details=pickle.load(open(f'{dataset_base_dir}/metadata.pkl', 'rb'))
    for query_name, query_item in saved_queries.items():
        # Step 1: Process join-induced filters. Loop several times so a range
        # introduced on one side of a join can reach tables further down a chain.
        cnt=0
        while cnt<3:
            for join_item in query_item['join_info']:
                # 'tbl1.col1=tbl2.col2' -> (tbl1, col1, tbl2, col2)
                left_table, left_col, right_table, right_col=join_item['join_keys'][0].replace('=','.').split('.')
                if left_table in query_item['ranges'] and right_table in query_item['ranges']:
                    if left_col in query_item['ranges'][left_table] and right_col in query_item['ranges'][right_table]:
                        # Both sides constrained: intersect the two ranges.
                        left_col_ranges=query_item['ranges'][left_table][left_col]
                        right_col_ranges=query_item['ranges'][right_table][right_col]
                        query_item['ranges'][left_table][left_col]['min']=max(left_col_ranges['min'], right_col_ranges['min'])
                        query_item['ranges'][right_table][right_col]['min']=max(left_col_ranges['min'], right_col_ranges['min'])
                        query_item['ranges'][left_table][left_col]['max']=min(left_col_ranges['max'], right_col_ranges['max'])
                        query_item['ranges'][right_table][right_col]['max']=min(left_col_ranges['max'], right_col_ranges['max'])
                    elif left_col in query_item['ranges'][left_table]:
                        # Only the left side is constrained: copy it to the right.
                        left_col_ranges=query_item['ranges'][left_table][left_col]
                        query_item['ranges'][right_table][right_col]=left_col_ranges
                    elif right_col in query_item['ranges'][right_table]:
                        # Only the right side is constrained: copy it to the left.
                        # Fixed: this was stored under right_col instead of
                        # left_col, so the left column never received the
                        # propagated constraint.
                        right_col_ranges=query_item['ranges'][right_table][right_col]
                        query_item['ranges'][left_table][left_col]=right_col_ranges
                elif left_table in query_item['ranges'] and right_table not in query_item['ranges']:
                    # Right table unseen so far: seed it from the left column's
                    # range; remaining numeric columns get their full range.
                    if left_col in query_item['ranges'][left_table]:
                        left_col_ranges=query_item['ranges'][left_table][left_col]
                        query_item['ranges'][right_table]={
                            right_col:left_col_ranges
                        }
                        for col in schema_details[right_table]['numeric_columns']:
                            if col!=right_col:
                                query_item['ranges'][right_table][col]=schema_details[right_table]['ranges'][col]
                elif left_table not in query_item['ranges'] and right_table in query_item['ranges']:
                    # Symmetric case: seed the unseen left table.
                    if right_col in query_item['ranges'][right_table]:
                        right_col_ranges=query_item['ranges'][right_table][right_col]
                        query_item['ranges'][left_table]={
                            left_col:right_col_ranges
                        }
                        for col in schema_details[left_table]['numeric_columns']:
                            if col!=left_col:
                                query_item['ranges'][left_table][col]=schema_details[left_table]['ranges'][col]
                else:
                    # Neither table has any recorded range: nothing to propagate.
                    continue
            cnt+=1
        # Step 2: deduplicate join items. Keep at most one join per table pair
        # and drop trivial self-equalities like 't.c=t.c'.
        new_join_items=[]
        join_record=dict()
        for join_item in query_item['join_info']:
            left_table_col, right_table_col=join_item['join_keys'][0].split('=')
            left_table, left_col=left_table_col.split('.')
            right_table, right_col=right_table_col.split('.')
            if left_table_col==right_table_col:
                continue
            if left_table in join_record and join_record[left_table]==right_table:
                continue
            join_record[left_table]=right_table
            join_record[right_table]=left_table
            new_join_items.append(join_item)
        if method=="mto":
            # Randomize join order for the MTO variant.
            random.shuffle(new_join_items)
        query_item['join_info']=new_join_items
        # Step 3: drop ranges equal to the full table range (no selectivity).
        ranges=copy.deepcopy(query_item['ranges'])
        for table, col_ranges in ranges.items():
            new_col_ranges={}
            for col, rng in col_ranges.items():
                if rng['min']!=schema_details[table]['ranges'][col]['min'] or rng['max']!=schema_details[table]['ranges'][col]['max']:
                    new_col_ranges[col]=rng
            query_item['ranges'][table]=new_col_ranges
    with open(f'{query_base_dir}/{method}_queries.pkl', 'wb') as f:
        pickle.dump(saved_queries, f)
def paw_queries():
    """Build the PAW variant: original (unpropagated) column ranges from
    formatted_queries.pkl combined with the validated/deduplicated join info
    from mto_queries.pkl; saved to paw_queries.pkl."""
    formatted = pickle.load(open(f'{query_base_dir}/formatted_queries.pkl', 'rb'))
    validated = pickle.load(open(f'{query_base_dir}/mto_queries.pkl', 'rb'))
    combined = {
        name: {
            'ranges': item['ranges'],
            'join_info': validated[name]['join_info'],
        }
        for name, item in formatted.items()
    }
    with open(f'{query_base_dir}/paw_queries.pkl', 'wb') as f:
        pickle.dump(combined, f)
# Dispatch on the parsed CLI flags.
# Fixed: the three flags were previously force-set to True right here (a debug
# leftover), which made the --format/--export_* command-line switches
# meaningless; the parsed arguments are now honored.
if args.format:
    formatting_query_seed()
if args.export_mto_paw_queries:
    # This function not only validates but also processes queries, such as handling join-induced filters
    # (which is crucial for subsequent join order algorithms. Traditional join-induced predicates help accurately
    # estimate potential join costs for each table and build better join trees)
    validate_queries(method="mto")
    paw_queries()
if args.export_pac_queries:
    validate_queries(method="pac")