################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
import decimal
from collections import Counter
from datetime import datetime
import os, sys
import json
import psycopg2
from SPARQLWrapper import SPARQLWrapper, JSON, SmartWrapper, SPARQLWrapper2
import traceback
import argparse
import time

sys.path.append('/src')
import pandas as pd

from spider.evaluation.process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql, \
    get_schema_from_json
from spider.evaluation.output_format import Format

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True

CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')

HARDNESS = {
    "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
    "component2": ('except', 'union', 'intersect')
}


def condition_has_or(conds):
    """Return True if any connector in the condition list is 'or'."""
    return 'or' in conds[1::2]


def condition_has_like(conds):
    """Return True if any cond_unit in the condition uses the LIKE operator."""
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    """Return True if any cond_unit's value is itself a nested sql dict."""
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False


def val_has_op(val_unit):
    """Return True if the val_unit applies an arithmetic unit op (not 'none')."""
    return val_unit[0] != UNIT_OPS.index('none')


def has_agg(unit):
    """Return True if the unit applies an aggregation (not 'none')."""
    return unit[0] != AGG_OPS.index('none')


def accuracy(count, total):
    """All-or-nothing accuracy: 1 only when every item matched."""
    if count == total:
        return 1
    return 0


def recall(count, total):
    """All-or-nothing recall: 1 only when every item matched."""
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    """Harmonic mean of acc and rec; 0 when both are 0."""
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    """Return (acc, rec, f1) — all 1 only for an exact component match."""
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0


def eval_sel(pred, label):
    """Compare SELECT clauses; return (label_total, pred_total, cnt, cnt_wo_agg)."""
    pred_sel = pred['select'][1]
    label_sel = label['select'][1]
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0
    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])
    return label_total, pred_total, cnt, cnt_wo_agg


def eval_where(pred, label):
    """Compare WHERE cond_units; return (label_total, pred_total, cnt, cnt_wo_agg)."""
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0
    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])
    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    """Compare GROUP BY columns (table prefixes stripped); return totals and match count."""
    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    # Strip "table." prefixes so columns compare by bare name.
    pred_cols = [c.split(".")[1] if "." in c else c for c in pred_cols]
    label_cols = [c.split(".")[1] if "." in c else c for c in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    """Compare GROUP BY + HAVING as a unit; cnt is 1 only on a full match."""
    pred_total = label_total = cnt = 0
    if len(pred['groupBy']) > 0:
        pred_total = 1
    if len(label['groupBy']) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    if pred_total == label_total == 1 \
            and pred_cols == label_cols \
            and pred['having'] == label['having']:
        cnt = 1

    return label_total, pred_total, cnt


def eval_order(pred, label):
    """Compare ORDER BY (and LIMIT presence); cnt is 1 only on a full match."""
    pred_total = label_total = cnt = 0
    if len(pred['orderBy']) > 0:
        pred_total = 1
    if len(label['orderBy']) > 0:
        label_total = 1
    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
            ((pred['limit'] is None and label['limit'] is None) or (
                    pred['limit'] is not None and label['limit'] is not None)):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    """Compare the sets of and/or connectors in the WHERE clause."""
    pred_ao = pred['where'][1::2]
    label_ao = label['where'][1::2]
    pred_ao = set(pred_ao)
    label_ao = set(label_ao)

    if pred_ao == label_ao:
        return 1, 1, 1
    return len(pred_ao), len(label_ao), 0


def get_nestedSQL(sql):
    """Collect all nested sql dicts (condition values and IUEN branches)."""
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql['intersect'] is not None:
        nested.append(sql['intersect'])
    if sql['except'] is not None:
        nested.append(sql['except'])
    if sql['union'] is not None:
        nested.append(sql['union'])
    return nested


def eval_nested(pred, label):
    """Exact-match a pair of (possibly absent) nested queries."""
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        cnt += Evaluator().eval_exact_match(pred, label)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    """Evaluate the INTERSECT/UNION/EXCEPT branches of pred against label."""
    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    """Return the set of SQL keywords (where/group/order/not/in/like/...) used by sql."""
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        res.add(sql['orderBy'][0])
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    if sql['except'] is not None:
        res.add('except')
    if sql['union'] is not None:
        res.add('union')
    if sql['intersect'] is not None:
        res.add('intersect')

    # or keyword
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    if len([token for token in ao if token == 'or']) > 0:
        res.add('or')

    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add('not')

    # in keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
        res.add('in')

    # like keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
        res.add('like')

    return res


def eval_keywords(pred, label):
    """Compare keyword sets; return (label_total, pred_total, overlap count)."""
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    """Count units that carry an aggregation."""
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    """Count hardness component 1 features (where/group/order/limit/join/or/like)."""
    count = 0
    if len(sql['where']) > 0:
        count += 1
    if len(sql['groupBy']) > 0:
        count += 1
    if len(sql['orderBy']) > 0:
        count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0:  # JOIN
        count += len(sql['from']['table_units']) - 1

    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += len([token for token in ao if token == 'or'])
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])

    return count


def count_component2(sql):
    """Count hardness component 2 features (nested / IUEN subqueries)."""
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    """Count 'other' hardness features: multiple aggs/columns/conditions/groupings."""
    count = 0
    # number of aggregation
    agg_count = count_agg(sql['select'][1])
    agg_count += count_agg(sql['where'][::2])
    agg_count += count_agg(sql['groupBy'])
    if len(sql['orderBy']) > 0:
        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                               [unit[2] for unit in sql['orderBy'][1] if unit[2]])
    agg_count += count_agg(sql['having'])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql['select'][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql['where']) > 1:
        count += 1

    # number of group by clauses
    if len(sql['groupBy']) > 1:
        count += 1

    return count


class Evaluator:
    """A simple evaluator"""

    def __init__(self):
        # Filled by eval_exact_match with the per-component partial scores.
        self.partial_scores = None

    def eval_hardness(self, sql):
        """Classify a parsed sql dict as easy/medium/hard/extra (Spider scheme)."""
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"

    # FIX: these two methods were commented out, but eval_nested() calls
    # Evaluator().eval_exact_match(), which raised AttributeError for any
    # query with an INTERSECT/UNION/EXCEPT branch. Restored verbatim.
    def eval_exact_match(self, pred, label):
        """Return 1 (or True) when pred matches label in every partial component."""
        partial_scores = self.eval_partial_match(pred, label)
        self.partial_scores = partial_scores

        for _, score in partial_scores.items():
            if score['f1'] != 1:
                return 0
        if len(label['from']['table_units']) > 0:
            label_tables = sorted(label['from']['table_units'])
            pred_tables = sorted(pred['from']['table_units'])
            return label_tables == pred_tables
        return 1

    def eval_partial_match(self, pred, label):
        """Return a dict of per-component {'acc','rec','f1','label_total','pred_total'}."""
        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['select'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['where'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total,
                                   'pred_total': pred_total}

        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['order'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        return res


def print_scores(scores, etype, tb_writer, training_step, print_stdout):
    """Print/log the aggregated scores and return {level: "score (count)"}."""
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']

    if etype in ["all", "match"]:
        matching_accuracy_with_counts = {level: "{} ({})".format(scores[level]['exact'], scores[level]['count'])
                                         for level in levels}
        matching_accuracy = {level: scores[level]['exact'] for level in levels}
    else:
        matching_accuracy_with_counts = {level: "{} ({})".format(scores[level]['exec'], scores[level]['count'])
                                         for level in levels}
        matching_accuracy = {level: scores[level]['exec'] for level in levels}

    if tb_writer:
        tb_writer.add_scalars('spider_evaluation_acc', matching_accuracy, training_step)
        # FIX: wandb was referenced without ever being imported, raising NameError
        # whenever a tb_writer was supplied. Log best-effort if wandb is available.
        try:
            import wandb
            wandb.log(matching_accuracy, step=training_step)
        except ImportError:
            pass

    if print_stdout:
        print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
        counts = [scores[level]['count'] for level in levels]
        print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))

        if etype in ["all", "exec"]:
            print('===================== EXECUTION ACCURACY =====================')
            this_scores = [scores[level]['exec'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))

        if etype in ["all", "match"]:
            print('\n====================== EXACT MATCHING ACCURACY =====================')
            exact_scores = [scores[level]['exact'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
            print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
            for type_ in partial_types:
                this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
                print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

            print('---------------------- PARTIAL MATCHING RECALL ----------------------')
            for type_ in partial_types:
                this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
                print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

            print('---------------------- PARTIAL MATCHING F1 --------------------------')
            for type_ in partial_types:
                this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
                print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

    return matching_accuracy_with_counts


def spider_evaluation(gold, predict_Sparql, etype, kmaps, tb_writer=None, training_step=None, print_stdout=True,):
    """Evaluate predicted SPARQL queries against gold SQL by execution.

    gold: path to a TSV of (gold SQL, db name, natural-language question);
    predict_Sparql: path to a file with one SPARQL query per line;
    etype: 'all' / 'exec' / 'match'; kmaps: foreign-key maps per db_id.
    Returns the dict produced by print_scores().
    """
    with open(gold, encoding='utf-8') as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict_Sparql, encoding='utf-8') as f:
        sparql_list = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    evaluator = Evaluator()

    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']
    entries = []
    scores = {}

    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0., 'acc_count': 0, 'rec_count': 0}

    # Queries known to be false negatives of the execution comparison are
    # whitelisted via this spreadsheet.
    db_exception = pd.read_excel('src/experiments/SparqlPrediction/false_negatives.xlsx', sheet_name=None)
    false_neg_list = db_exception['false_negatives'].iloc[:, 0].values

    schemas = {}
    for sparql, g in zip(sparql_list, glist):
        # nl_q / db_name / hardness are module globals consumed by eval_all()
        # and eval_exec_match() for reporting.
        global nl_q
        sparql_str = sparql[0]
        g_str, db, nl_q = g
        global db_name
        db_name = db
        if db_name not in schemas.keys():
            schemas[db_name] = Schema(get_schema(db))
        g_sql = get_sql(schemas[db_name], g_str)
        global hardness
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        # NOTE: exact-match evaluation of the prediction is disabled here because
        # the predictions are SPARQL, not SQL, and cannot be parsed by get_sql().

        # rebuild sql for value evaluation
        kmap = kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schemas[db_name])
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)

        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db_name, sparql_str, g_str, false_neg_list)
            if exec_score:
                scores[hardness]['exec'] += 1
                scores['all']['exec'] += 1
            else:
                if print_stdout:
                    print("{} gold: {}".format(hardness, g_str))
                    print("")

    Format.exportDf()
    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                                                             scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                                                             scores[level]['partial'][type_]['rec_count'] * 1.0
                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                                scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    return print_scores(scores, etype, tb_writer, training_step, print_stdout)


def psql_conn(db_name, query):
    """Run `query` on the PostgreSQL Spider test bed and return all rows.

    Retries the connection up to 4 times before re-raising; the cursor and
    connection are always closed.
    """
    conn = None
    counter = 0
    while conn is None:
        try:
            # SECURITY NOTE(review): credentials are hard-coded in source; they
            # should be moved to environment variables / a secrets store.
            conn = psycopg2.connect(dbname="spider", user="postgres", password="vdS83DJSQz2xQ",
                                    host="testbed.inode.igd.fraunhofer.de", port="18001",
                                    options=f"-c search_path={db_name}")
        except Exception as e:
            counter = counter + 1
            if counter > 3:
                raise e
    q_values = []
    try:
        cursor = conn.cursor()
        cursor.execute(query)
        q_values = cursor.fetchall()
        cursor.close()
    finally:
        conn.close()
    return q_values


def get_sql_values(q_values):
    """Normalize SQL result rows for comparison.

    Decimals and floats are rounded to 12 places and datetimes rendered as
    '%Y-%m-%d %H:%M:%S' so they compare equal to the SPARQL-side values.
    """
    q_res = []
    for q_value in q_values:
        counter = 0
        q_value = list(q_value)
        for q in q_value:
            if isinstance(q, decimal.Decimal):
                q_value[counter] = round(float(q), 12)
            elif isinstance(q, float):
                q_value[counter] = round(q, 12)
            elif isinstance(q, datetime):
                q_value[counter] = datetime.strftime(q, '%Y-%m-%d %H:%M:%S')
            counter += 1
        q_res.append(tuple(q_value))
    return q_res


def execute_sparql(query, db_name):
    """Execute `query` against the GraphDB repository for db_name.

    Returns a list of result rows, each a tuple of bindings in variable order.
    """
    sparql = SPARQLWrapper2('http://160.85.254.15:7200/repositories/' + db_name.lower())
    sparql.setReturnFormat(JSON)
    sparql.setQuery('PREFIX dbo: <http://valuenet/ontop/> ' + query)  # the previous query as a literal string
    results = sparql.query()
    vars = results.variables
    lists = [results.getValues(var) for var in vars]
    return list(zip(*lists))


def eval_exec_match(db_name, sparql_str, g_str, false_neg_list):
    """
    return 1 if the values between prediction and gold are matching
    in the corresponding index. Currently not support multiple col_unit(pairs).
    """
    exception = ''
    p_res = []
    q_res = []
    result = None
    try:
        # Execute the predicted SPARQL with up to 3 attempts.
        i = 0
        is_successful = False
        timeout = 5
        while i < 3 and not is_successful:
            try:
                i += 1
                results = execute_sparql(sparql_str, db_name)
                is_successful = True
            except Exception as e:
                # FIX: was the no-op `timeout = timeout`; back off before retrying.
                time.sleep(timeout)
                if i == 3:
                    raise e

        if is_successful:
            for result in results:
                result_length = len(result)
                # xsd:integer literals arrive as strings; coerce them to int.
                values = [int(result[i].value)
                          if result[i].datatype == 'http://www.w3.org/2001/XMLSchema#integer'
                          else result[i].value
                          for i in range(result_length)]
                counter = 0
                for value in values:
                    if isinstance(value, decimal.Decimal):
                        values[counter] = round(float(value), 12)
                    elif isinstance(value, float):
                        values[counter] = round(value, 12)
                    elif isinstance(value, datetime):
                        values[counter] = datetime.strftime(value, '%Y-%m-%d %H:%M:%S')
                    counter += 1
                p_res.append(tuple(values))

        # Execute the gold SQL on PostgreSQL for comparison.
        q_values = psql_conn(db_name, g_str)
        q_res = get_sql_values(q_values)
    except Exception as e:
        traceback.print_exc()
        print("could not execute query '\n\r{}\n\r'. Exception: {}. Database: {}".format(sparql_str, e, db_name))
        exception = e

    matched = eval_all(p_res, q_res, sparql_str, g_str, exception, false_neg_list)
    print("natural language Q: {}".format(nl_q))
    print("database: {}".format(db_name))
    print("sparql: {}".format(sparql_str))
    # FIX: previously returned None, so the caller's `if exec_score:` never
    # counted a hit and execution accuracy was always 0.
    return matched


def compare_lists(firstList, secondList):
    """Order-insensitive multiset comparison of two result lists."""
    return Counter(firstList) == Counter(secondList)


def eval_all(sparql_res, q_res, sparql_str, g_str, exception, false_neg_list):
    """Record the outcome of one prediction via Format.appendRow.

    Returns True for a match (or a whitelisted false negative), False otherwise.
    """
    try:
        if g_str in false_neg_list:
            Format.appendRow(db_name, nl_q, g_str, q_res, sparql_str, sparql_res, hardness,
                             isMatch=True, score=1, exception='false_negative')
            return True
        elif exception is not None and exception != '':
            Format.appendRow(db_name, nl_q, g_str, q_res, sparql_str, sparql_res, hardness,
                             exception=exception)
            return False
        elif compare_lists(q_res, sparql_res):
            Format.appendRow(db_name, nl_q, g_str, q_res, sparql_str, sparql_res, hardness,
                             isMatch=True, score=1)
            return True
        else:
            Format.appendRow(db_name, nl_q, g_str, q_res, sparql_str, sparql_res, hardness,
                             isMatch=False, score=0)
            return False
    except Exception as e:
        print("Type Error. p_res:{}, q_res:{}".format(type(sparql_res), type(q_res)))
        return False


def res_map(res, val_units):
    """
    Notes Ursin: The val_units are the columns in the select and "res" is the
    results from the select, one tuple per row. What we do now is building up a
    map with columns as key and a list of values from the sql-select. Then we
    simply compare the two maps. That way we don't care about the
    selection-order.
    """
    # zip & iterate over both, val_units and res and compare
    rmap = {}
    for idx, val_unit in enumerate(val_units):
        key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
        rmap[key] = [r[idx] for r in res]
    return rmap

# NOTE(review): dead remnants of upstream Spider's eval_exec_match followed
# res_map here (building p_val_units/q_val_units from undefined `pred`/`gold`
# and returning a map comparison). They referenced undefined names and were
# unreachable, so they have been removed.


# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    """Blank out literal values in a cond_unit (keep nested sqls, recursed)."""
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    """Apply rebuild_cond_unit_val to every cond_unit in a condition list."""
    if condition is None or not DISABLE_VALUE:
        return condition

    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    """Strip literal values from all conditions of a sql dict (in place)."""
    if sql is None or not DISABLE_VALUE:
        return sql

    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
    sql['having'] = rebuild_condition_val(sql['having'])
    sql['where'] = rebuild_condition_val(sql['where'])
    sql['intersect'] = rebuild_sql_val(sql['intersect'])
    sql['except'] = rebuild_sql_val(sql['except'])
    sql['union'] = rebuild_sql_val(sql['union'])

    return sql


# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
    """Return schema column ids belonging to the tables actually used in FROM."""
    col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
    prefixs = [col_id[:-2] for col_id in col_ids]
    valid_col_units = []
    for value in schema.idMap.values():
        if '.' in value and value[:value.index('.')] in prefixs:
            valid_col_units.append(value)
    return valid_col_units


def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    """Map a col_unit's column through the foreign-key map; drop DISTINCT flag."""
    if col_unit is None:
        return col_unit

    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    """Rebuild both col_units of a val_unit through the foreign-key map."""
    if val_unit is None:
        return val_unit

    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    """Rebuild a table_unit (only when its payload is a col_unit tuple)."""
    if table_unit is None:
        return table_unit

    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    """Rebuild the val_unit inside a cond_unit through the foreign-key map."""
    if cond_unit is None:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    """Rebuild every cond_unit in a condition list (in place)."""
    for idx in range(len(condition)):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    """Rebuild a SELECT clause's val_units; drop the DISTINCT flag."""
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    """Rebuild the FROM clause's table units and join conditions (in place)."""
    if from_ is None:
        return from_

    from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap)
                            for table_unit in from_['table_units']]
    from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    """Rebuild all GROUP BY col_units through the foreign-key map."""
    if group_by is None:
        return group_by

    return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    """Rebuild all ORDER BY val_units through the foreign-key map."""
    if order_by is None or len(order_by) == 0:
        return order_by

    direction, val_units = order_by
    new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    """Canonicalize all columns of a sql dict via the foreign-key map (in place)."""
    if sql is None:
        return sql

    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
    sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)

    return sql


def build_foreign_key_map(entry):
    """Build {column_key: canonical_column_key} from one tables.json entry.

    Columns connected (transitively) by foreign keys all map to the
    lexicographically smallest member of their group.
    """
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idmap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        # Return the existing group containing k1 or k2, or start a new one.
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]

    return foreign_key_map


def build_foreign_key_map_from_json(table):
    """Load tables.json and return {db_id: foreign_key_map}."""
    with open(table, encoding='utf-8') as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry['db_id']] = build_foreign_key_map(entry)
    return tables


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gold', dest='gold', type=str)
    parser.add_argument('--pred', dest='pred', type=str)
    parser.add_argument('--db', dest='db', type=str)
    parser.add_argument('--table', dest='table', type=str)
    parser.add_argument('--etype', dest='etype', type=str)
    args = parser.parse_args()

    gold = args.gold
    pred = args.pred
    db_dir = args.db
    table = args.table
    etype = args.etype

    # assert etype in ["all", "exec", "match"], "Unknown evaluation method"

    kmaps = build_foreign_key_map_from_json(table)

    # FIX: the previous call passed db_dir where etype belongs and referenced an
    # undefined name `schemaFile` (NameError). Call now matches the signature of
    # spider_evaluation(gold, predict_Sparql, etype, kmaps, ...).
    spider_evaluation(gold, pred, etype, kmaps)