################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################

import argparse
import decimal
import json
import os
import sys
import time
import traceback
from collections import Counter
from datetime import datetime

import psycopg2
import tqdm

# Must run before the project-local ``spider`` imports below.
sys.path.append('/src')

import pandas as pd
from spider.evaluation.process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql, \
    get_schema_from_json
from spider.evaluation.output_format import Format

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True

CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')

HARDNESS = {
    "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
    "component2": ('except', 'union', 'intersect')
}

# Difficulty buckets and partial-match component names shared by the scorer
# and the report printer.
_LEVELS = ['easy', 'medium', 'hard', 'extra', 'all']
_PARTIAL_TYPES = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                  'group', 'order', 'and/or', 'IUEN', 'keywords']

# Per-example context that spider_evaluation() publishes for eval_exec_match()
# and eval_all(). Initialised here so those helpers do not hit a NameError when
# they are called before spider_evaluation() has processed an example.
nl_q = None
hardness = None


def condition_has_or(conds):
    """Return True if any connector (odd positions) of the condition is 'or'."""
    return 'or' in conds[1::2]


def condition_has_like(conds):
    """Return True if any cond_unit (even positions) uses the LIKE operator."""
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    """Return True if any cond_unit compares against a nested SQL dict."""
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if isinstance(val1, dict) or isinstance(val2, dict):
            return True
    return False


def val_has_op(val_unit):
    """Return True if the val_unit carries a unit operator other than 'none'."""
    return val_unit[0] != UNIT_OPS.index('none')


def has_agg(unit):
    """Return True if the unit's first element is not the 'none' aggregation id."""
    return unit[0] != AGG_OPS.index('none')


def accuracy(count, total):
    """Return 1 when every item matched (count == total), else 0."""
    if count == total:
        return 1
    return 0


def recall(count, total):
    """Return 1 when every item matched (count == total), else 0."""
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    """Return the harmonic mean of acc and rec, or 0 when both are 0."""
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    """Return (acc, rec, f1) as all-or-nothing scores.

    The result is (1, 1, 1) only when prediction and label have the same
    number of items and all of them matched; otherwise (0, 0, 0).
    """
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0


def eval_sel(pred, label):
    """Count SELECT units of ``pred`` that also occur in ``label``.

    Returns (label_total, pred_total, cnt, cnt_wo_agg) where ``cnt`` compares
    full (agg_id, val_unit) pairs and ``cnt_wo_agg`` compares val_units only.
    """
    pred_sel = pred['select'][1]
    # Work on a copy: matched units are removed below and the caller's parsed
    # SQL must not be altered as a side effect.
    label_sel = list(label['select'][1])
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_where(pred, label):
    """Count WHERE cond_units of ``pred`` that also occur in ``label``.

    Returns (label_total, pred_total, cnt, cnt_wo_agg) where ``cnt`` compares
    whole cond_units and ``cnt_wo_agg`` compares only their val_unit.
    """
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    """Count GROUP BY columns of ``pred`` that also occur in ``label``.

    Column ids containing a '.' are reduced to the part after the first dot
    before comparison, so the table prefix is ignored.
    Returns (label_total, pred_total, cnt).
    """
    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    pred_cols = [col.split(".")[1] if "." in col else col for col in pred_cols]
    label_cols = [col.split(".")[1] if "." in col else col for col in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    """Score GROUP BY together with HAVING as a single all-or-nothing unit.

    Returns (label_total, pred_total, cnt); each value is 0 or 1.
    """
    pred_total = label_total = cnt = 0
    if len(pred['groupBy']) > 0:
        pred_total = 1
    if len(label['groupBy']) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    if pred_total == label_total == 1 \
            and pred_cols == label_cols \
            and pred['having'] == label['having']:
        cnt = 1

    return label_total, pred_total, cnt


def eval_order(pred, label):
    """Score ORDER BY as a single unit.

    A match requires identical orderBy clauses and that LIMIT is either
    present in both queries or absent in both (the limit value is ignored).
    Returns (label_total, pred_total, cnt); each value is 0 or 1.
    """
    pred_total = label_total = cnt = 0
    if len(pred['orderBy']) > 0:
        pred_total = 1
    if len(label['orderBy']) > 0:
        label_total = 1
    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
            ((pred['limit'] is None and label['limit'] is None) or
             (pred['limit'] is not None and label['limit'] is not None)):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    """Compare the set of WHERE connectors ('and'/'or') of both queries.

    Returns (label_total, pred_total, cnt). Equal sets yield (1, 1, 1);
    otherwise the set sizes are returned with cnt 0.
    """
    pred_ao = set(pred['where'][1::2])
    label_ao = set(label['where'][1::2])

    if pred_ao == label_ao:
        return 1, 1, 1
    # Label first, prediction second, matching how the caller unpacks the
    # tuple (the two were previously returned in swapped order).
    return len(label_ao), len(pred_ao), 0


def get_nestedSQL(sql):
    """Collect nested SQL dicts from conditions and INTERSECT/EXCEPT/UNION."""
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if isinstance(cond_unit[3], dict):
            nested.append(cond_unit[3])
        if isinstance(cond_unit[4], dict):
            nested.append(cond_unit[4])
    if sql['intersect'] is not None:
        nested.append(sql['intersect'])
    if sql['except'] is not None:
        nested.append(sql['except'])
    if sql['union'] is not None:
        nested.append(sql['union'])
    return nested


def eval_nested(pred, label):
    """Score one nested query slot via exact match.

    Returns (label_total, pred_total, cnt); totals are 1 when the respective
    query is present, ``cnt`` is the exact-match result when both are present.
    """
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        cnt += Evaluator().eval_exact_match(pred, label)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    """Sum the nested-query scores of the INTERSECT, EXCEPT and UNION slots."""
    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    """Return the set of SQL keywords that the parsed query makes use of."""
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        # First element of orderBy is the direction ('asc'/'desc').
        res.add(sql['orderBy'][0])
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    if sql['except'] is not None:
        res.add('except')
    if sql['union'] is not None:
        res.add('union')
    if sql['intersect'] is not None:
        res.add('intersect')

    # or keyword
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    if len([token for token in ao if token == 'or']) > 0:
        res.add('or')

    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add('not')

    # in keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
        res.add('in')

    # like keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
        res.add('like')

    return res


def eval_keywords(pred, label):
    """Count keywords of ``pred`` that also appear in ``label``.

    Returns (label_total, pred_total, cnt).
    """
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    """Return how many of ``units`` satisfy has_agg()."""
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    """Count 'component 1' hardness features (see HARDNESS['component1'])."""
    count = 0
    if len(sql['where']) > 0:
        count += 1
    if len(sql['groupBy']) > 0:
        count += 1
    if len(sql['orderBy']) > 0:
        count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0:  # JOIN
        count += len(sql['from']['table_units']) - 1

    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += len([token for token in ao if token == 'or'])
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])

    return count


def count_component2(sql):
    """Count 'component 2' hardness features: the number of nested queries."""
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    """Count the remaining hardness features.

    One point each for: more than one aggregation, more than one select
    column, more than one where element, more than one group-by column.
    """
    count = 0
    # number of aggregation
    # NOTE: the where/having arguments below are passed through has_agg()
    # unchanged from the original implementation; the hardness buckets depend
    # on these exact counts.
    agg_count = count_agg(sql['select'][1])
    agg_count += count_agg(sql['where'][::2])
    agg_count += count_agg(sql['groupBy'])
    if len(sql['orderBy']) > 0:
        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                               [unit[2] for unit in sql['orderBy'][1] if unit[2]])
    agg_count += count_agg(sql['having'])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql['select'][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql['where']) > 1:
        count += 1

    # number of group by clauses
    if len(sql['groupBy']) > 1:
        count += 1

    return count


class Evaluator:
    """A simple evaluator"""

    def __init__(self):
        # Partial scores of the most recent eval_exact_match() call.
        self.partial_scores = None

    def eval_hardness(self, sql):
        """Classify a parsed query as 'easy', 'medium', 'hard' or 'extra'."""
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"

    def eval_exact_match(self, pred, label):
        """Return a truthy value iff every partial component matches.

        The partial scores are stored on ``self.partial_scores``. When the
        label has FROM table units, the sorted table units of both queries
        must be equal as well (a bool is returned in that case).
        """
        partial_scores = self.eval_partial_match(pred, label)
        self.partial_scores = partial_scores

        for score in partial_scores.values():
            if score['f1'] != 1:
                return 0
        if len(label['from']['table_units']) > 0:
            label_tables = sorted(label['from']['table_units'])
            pred_tables = sorted(pred['from']['table_units'])
            return label_tables == pred_tables
        return 1

    def eval_partial_match(self, pred, label):
        """Return per-component scores as {name: {acc, rec, f1, label_total, pred_total}}."""

        def entry(cnt, pred_total, label_total):
            acc, rec, f1 = get_scores(cnt, pred_total, label_total)
            return {'acc': acc, 'rec': rec, 'f1': f1,
                    'label_total': label_total, 'pred_total': pred_total}

        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        res['select'] = entry(cnt, pred_total, label_total)
        res['select(no AGG)'] = entry(cnt_wo_agg, pred_total, label_total)

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        res['where'] = entry(cnt, pred_total, label_total)
        res['where(no OP)'] = entry(cnt_wo_agg, pred_total, label_total)

        label_total, pred_total, cnt = eval_group(pred, label)
        res['group(no Having)'] = entry(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_having(pred, label)
        res['group'] = entry(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_order(pred, label)
        res['order'] = entry(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_and_or(pred, label)
        res['and/or'] = entry(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        res['IUEN'] = entry(cnt, pred_total, label_total)

        label_total, pred_total, cnt = eval_keywords(pred, label)
        res['keywords'] = entry(cnt, pred_total, label_total)

        return res


def print_scores(scores, etype, tb_writer, training_step, print_stdout):
    """Report the aggregated scores and return accuracy strings per level.

    Args:
        scores: nested dict as built by spider_evaluation().
        etype: "all", "exec" or "match"; selects which tables are printed and
            which accuracy ('exact' for all/match, 'exec' otherwise) is returned.
        tb_writer: optional object with ``add_scalars``; receives the accuracies.
        training_step: step value forwarded to ``tb_writer.add_scalars``.
        print_stdout: when true, print the score tables.

    Returns:
        dict mapping each level to the string "<accuracy> (<count>)".
    """
    levels = _LEVELS
    partial_types = _PARTIAL_TYPES
    row_fmt = "{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}"

    metric = 'exact' if etype in ["all", "match"] else 'exec'
    matching_accuracy_with_counts = {level: "{} ({})".format(scores[level][metric], scores[level]['count'])
                                     for level in levels}
    matching_accuracy = {level: scores[level][metric] for level in levels}

    if tb_writer:
        tb_writer.add_scalars('spider_evaluation_acc', matching_accuracy, training_step)

    if print_stdout:
        print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
        counts = [scores[level]['count'] for level in levels]
        print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))

        if etype in ["all", "exec"]:
            print('=====================   EXECUTION ACCURACY     =====================')
            this_scores = [scores[level]['exec'] for level in levels]
            print(row_fmt.format("execution", *this_scores))

        if etype in ["all", "match"]:
            print('\n====================== EXACT MATCHING ACCURACY =====================')
            exact_scores = [scores[level]['exact'] for level in levels]
            print(row_fmt.format("exact match", *exact_scores))

            sections = (
                ('\n---------------------PARTIAL MATCHING ACCURACY----------------------', 'acc'),
                ('---------------------- PARTIAL MATCHING RECALL ----------------------', 'rec'),
                ('---------------------- PARTIAL MATCHING F1 --------------------------', 'f1'),
            )
            for header, key in sections:
                print(header)
                for type_ in partial_types:
                    this_scores = [scores[level]['partial'][type_][key] for level in levels]
                    print(row_fmt.format(type_, *this_scores))

    return matching_accuracy_with_counts


def _empty_sql():
    """Return a parsed-SQL dict with every clause empty (stand-in for an unparsable prediction)."""
    return {
        "except": None,
        "from": {
            "conds": [],
            "table_units": []
        },
        "groupBy": [],
        "having": [],
        "intersect": None,
        "limit": None,
        "orderBy": [],
        "select": [
            False,
            []
        ],
        "union": None,
        "where": []
    }


def spider_evaluation(gold, predict, db_dir, etype, kmaps, tb_writer=None, training_step=None, print_stdout=True):
    """Evaluate predicted SQL against gold SQL and report the scores.

    Args:
        gold: path to the gold file; each non-empty line must hold exactly
            three tab-separated fields: SQL, database name, question.
        predict: path to the prediction file; the first tab-separated field
            of each non-empty line is the predicted SQL.
        db_dir: not used by this function; kept for interface compatibility.
        etype: "all", "exec" or "match".
        kmaps: dict mapping database name to its foreign-key map.
        tb_writer, training_step, print_stdout: forwarded to print_scores().

    Returns:
        The return value of print_scores().

    Side effects:
        Sets the module globals ``nl_q`` and ``hardness`` for every example
        (read by eval_exec_match()/eval_all()) and calls
        ``Format.exportDf_SQL()`` once all examples are processed.
    """
    global nl_q, hardness

    with open(gold, encoding='utf-8') as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict, encoding='utf-8') as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    evaluator = Evaluator()

    levels = _LEVELS
    partial_types = _PARTIAL_TYPES
    scores = {}

    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0., 'acc_count': 0, 'rec_count': 0}

    schemas = {}
    eval_err_num = 0
    # NOTE: zip() stops at the shorter of the two files.
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db, nl_q = g
        db_name = db
        if db_name not in schemas:
            schemas[db_name] = Schema(get_schema(db))
        g_sql = get_sql(schemas[db_name], g_str)
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        try:
            p_sql = get_sql(schemas[db_name], p_str)
        except Exception:
            # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
            p_sql = _empty_sql()
            eval_err_num += 1
            print("eval_err_num:{}".format(eval_err_num))

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schemas[db_name])
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schemas[db_name])
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db_name, p_str, g_str, p_sql, g_sql)
            if exec_score:
                scores[hardness]['exec'] += 1
                scores['all']['exec'] += 1
            else:
                if print_stdout:
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print("")

        if etype in ["all", "match"]:
            # Compare the parsed prediction (previously an undefined name was
            # passed here, which raised NameError).
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                if print_stdout:
                    print("{} pred: {}".format(hardness, p_str))
                    print("{} gold: {}".format(hardness, g_str))
                    print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            for type_ in partial_types:
                for level in (hardness, 'all'):
                    bucket = scores[level]['partial'][type_]
                    if partial_scores[type_]['pred_total'] > 0:
                        bucket['acc'] += partial_scores[type_]['acc']
                        bucket['acc_count'] += 1
                    if partial_scores[type_]['label_total'] > 0:
                        bucket['rec'] += partial_scores[type_]['rec']
                        bucket['rec_count'] += 1
                    bucket['f1'] += partial_scores[type_]['f1']

    Format.exportDf_SQL()

    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                bucket = scores[level]['partial'][type_]
                if bucket['acc_count'] == 0:
                    bucket['acc'] = 0
                else:
                    bucket['acc'] = bucket['acc'] / bucket['acc_count'] * 1.0
                if bucket['rec_count'] == 0:
                    bucket['rec'] = 0
                else:
                    bucket['rec'] = bucket['rec'] / bucket['rec_count'] * 1.0
                if bucket['acc'] == 0 and bucket['rec'] == 0:
                    # Kept as in the original: F1 is reported as 1 when both
                    # accuracy and recall are 0.
                    bucket['f1'] = 1
                else:
                    bucket['f1'] = 2.0 * bucket['acc'] * bucket['rec'] / (bucket['rec'] + bucket['acc'])

    return print_scores(scores, etype, tb_writer, training_step, print_stdout)


def psql_conn(db_name, query):
    """Run ``query`` on PostgreSQL with ``db_name`` as search_path and return all rows.

    Connecting is attempted up to four times; the last connection error is
    re-raised. A failing query is reported on stdout and yields an empty
    list, so callers cannot tell a failed query from an empty result.

    Connection settings can be overridden through the environment variables
    SPIDER_DB_NAME, SPIDER_DB_USER, SPIDER_DB_PASSWORD, SPIDER_DB_HOST and
    SPIDER_DB_PORT; the defaults are the values this function always used.
    """
    # SECURITY: the fallback credentials below are committed to source control.
    # They are kept only for backward compatibility; supply them via the
    # environment and rotate/remove the hard-coded password.
    env = os.environ
    conn = None
    attempts = 0
    while conn is None:
        try:
            conn = psycopg2.connect(dbname=env.get("SPIDER_DB_NAME", "spider"),
                                    user=env.get("SPIDER_DB_USER", "postgres"),
                                    password=env.get("SPIDER_DB_PASSWORD", "vdS83DJSQz2xQ"),
                                    host=env.get("SPIDER_DB_HOST", "testbed.inode.igd.fraunhofer.de"),
                                    port=env.get("SPIDER_DB_PORT", "18001"),
                                    options=f"-c search_path={db_name}")
        except psycopg2.Error:
            attempts += 1
            if attempts > 3:
                raise
            # Brief pause so that the retries are not fired back-to-back.
            time.sleep(1)

    q_values = []
    try:
        with conn.cursor() as cursor:
            cursor.execute(query)
            q_values = cursor.fetchall()
    except Exception as e:
        print("could not execute query '{}'. Exception: {}. Database: {}".format(query, e, db_name))
    finally:
        conn.close()
    return q_values


def get_sql_values(q_values):
    """Normalise result rows so that two result sets can be compared.

    Decimal and float values are rounded to 12 decimals (as float) and
    datetime values are rendered as '%Y-%m-%d %H:%M:%S' strings. Each row is
    returned as a tuple.
    """
    q_res = []
    for row in q_values:
        row = list(row)
        for idx, q in enumerate(row):
            if isinstance(q, decimal.Decimal):
                row[idx] = round(float(q), 12)
            elif isinstance(q, float):
                row[idx] = round(q, 12)
            elif isinstance(q, datetime):
                row[idx] = q.strftime('%Y-%m-%d %H:%M:%S')
        q_res.append(tuple(row))
    return q_res


def eval_exec_match(db_name, p_str, g_str, p_sql, g_sql):
    """Execute the predicted and the gold query and compare their results.

    Returns True when both result sets contain the same rows with the same
    multiplicities (row order is ignored), otherwise False. ``p_sql`` and
    ``g_sql`` are not used. The question text is read from the module global
    ``nl_q``.
    """
    exception = ''
    p_values = psql_conn(db_name, p_str)
    p_res = get_sql_values(p_values)

    q_values = psql_conn(db_name, g_str)
    q_res = get_sql_values(q_values)
    # The outcome has to be returned: the caller counts execution accuracy
    # from it (previously nothing was returned, so accuracy was always 0).
    return eval_all(db_name, nl_q, p_res, q_res, p_str, g_str, exception)


def compare_lists(firstList, secondList):
    """Return True if both lists hold the same elements with the same counts."""
    return Counter(firstList) == Counter(secondList)


def eval_all(db_name, nl_q, p_res, q_res, p_str, g_str, exception):
    """Record one execution comparison via ``Format`` and return whether it matched.

    Returns True only when ``exception`` is empty and the two result lists
    are equal as multisets; False otherwise, including when the comparison or
    the recording raises. The difficulty level is read from the module global
    ``hardness``.

    The column-keyed ``res_map`` comparison that used to trail this function
    referenced names that are not defined here and has been removed.
    """
    try:
        if exception is not None and exception != '':
            Format.appendRow_SQL(db_name, nl_q, g_str, q_res, p_str, p_res, hardness, exception=exception)
            return False
        if compare_lists(q_res, p_res):
            Format.appendRow_SQL(db_name, nl_q, g_str, q_res, p_str, p_res, hardness, isMatch=True, score=1)
            return True
        Format.appendRow_SQL(db_name, nl_q, g_str, q_res, p_str, p_res, hardness, isMatch=False, score=0)
        return False
    except Exception:
        print("Type Error. p_res:{}, q_res:{}".format(type(p_res), type(q_res)))
        return False


# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    """Blank out literal values of a cond_unit; nested SQL values are rebuilt recursively.

    Returns the cond_unit unchanged when it is None or DISABLE_VALUE is off.
    """
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    if not isinstance(val1, dict):
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if not isinstance(val2, dict):
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    """Apply rebuild_cond_unit_val() to every cond_unit (even positions) of a condition."""
    if condition is None or not DISABLE_VALUE:
        return condition

    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    """Strip literal values from all conditions of ``sql`` (modified in place) and return it."""
    if sql is None or not DISABLE_VALUE:
        return sql

    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
    sql['having'] = rebuild_condition_val(sql['having'])
    sql['where'] = rebuild_condition_val(sql['where'])
    sql['intersect'] = rebuild_sql_val(sql['intersect'])
    sql['except'] = rebuild_sql_val(sql['except'])
    sql['union'] = rebuild_sql_val(sql['union'])

    return sql


# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
    """Return the schema column ids that belong to the tables in ``table_units``.

    Only table units of type TABLE_TYPE['table_unit'] are considered. A value
    of ``schema.idMap`` is kept when it contains a '.' and the text before
    that '.' equals one of the table ids with its last two characters cut off.
    """
    col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
    prefixs = {col_id[:-2] for col_id in col_ids}
    valid_col_units = []
    for value in schema.idMap.values():
        if '.' in value and value[:value.index('.')] in prefixs:
            valid_col_units.append(value)
    return valid_col_units


def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    """Map the column of a col_unit through ``kmap`` when it is a valid column.

    The distinct flag is replaced by None while DISABLE_DISTINCT is set.
    Returns None unchanged.
    """
    if col_unit is None:
        return col_unit

    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    """Apply rebuild_col_unit_col() to both col_units of a val_unit."""
    if val_unit is None:
        return val_unit

    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    """Rebuild the payload of a table_unit when it is a col_unit tuple; other payloads pass through."""
    if table_unit is None:
        return table_unit

    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    """Rebuild the val_unit of a cond_unit; the other fields are left as they are."""
    if cond_unit is None:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    """Rebuild every cond_unit (even positions) of ``condition`` in place and return it.

    Returns None unchanged, in line with the other rebuild_*_col helpers.
    """
    if condition is None:
        return condition

    for idx, item in enumerate(condition):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(valid_col_units, item, kmap)
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    """Rebuild the val_units of a select clause.

    The clause-level distinct flag is replaced by None while DISABLE_DISTINCT
    is set. Returns (distinct, new_list), or None unchanged.
    """
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    """Rebuild the table units and join conditions of a from clause (modified in place)."""
    if from_ is None:
        return from_

    from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap)
                            for table_unit in from_['table_units']]
    from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    """Return a new list with every group-by col_unit rebuilt."""
    if group_by is None:
        return group_by

    return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    """Rebuild the val_units of an order-by clause; empty or None clauses pass through."""
    if order_by is None or len(order_by) == 0:
        return order_by

    direction, val_units = order_by
    new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    """Rebuild every clause of ``sql`` (modified in place), recursing into INTERSECT/EXCEPT/UNION."""
    if sql is None:
        return sql

    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
    sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)

    return sql


def build_foreign_key_map(entry):
    """Build a map from each foreign-key-linked column to a canonical column.

    Args:
        entry: dict with the keys "column_names_original" (list of
            [table_index, column_name]; a negative table index yields
            "__all__"), "table_names_original" and "foreign_keys" (list of
            column-index pairs).

    Returns:
        dict mapping "__table.column__" ids (lower-cased) to the id of the
        column with the smallest index in the same group of linked columns.
    """
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idmap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        # Reuse the first existing group that already holds either key,
        # otherwise open (and register) a new group.
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(key_set)
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]

    return foreign_key_map


def build_foreign_key_map_from_json(table):
    """Load the tables JSON at path ``table`` and return {db_id: foreign-key map}."""
    with open(table, encoding='utf-8') as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry['db_id']] = build_foreign_key_map(entry)
    return tables


def main(argv=None):
    """Parse the command line and run the evaluation.

    Args:
        argv: argument list; None makes argparse read ``sys.argv``.

    Returns:
        The return value of spider_evaluation().
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gold', dest='gold', type=str, required=True)
    parser.add_argument('--pred', dest='pred', type=str, required=True)
    # --db stays optional: its value is only passed through as db_dir.
    parser.add_argument('--db', dest='db', type=str)
    parser.add_argument('--table', dest='table', type=str, required=True)
    # Validated by argparse instead of an assert, which is stripped under -O.
    parser.add_argument('--etype', dest='etype', type=str, required=True,
                        choices=["all", "exec", "match"])
    args = parser.parse_args(argv)

    kmaps = build_foreign_key_map_from_json(args.table)
    return spider_evaluation(args.gold, args.pred, args.db, args.etype, kmaps)


if __name__ == "__main__":
    main()