import copy
import json
import os
from pathlib import Path

import torch
import wandb
from tqdm import tqdm

from config import read_arguments_evaluation
from data_loader import get_data_loader
from intermediate_representation import semQL
from intermediate_representation.sem2sql.sem2SQL import transform_semQL_to_sql
from manual_inference.helper import _execute_query_postgresql
from model.model import IRNet
from spider import spider_utils
from spider.example_builder import build_example
from spider.test_suite_eval.exec_eval import result_eq
from utils import setup_device, set_seed_everywhere
import spider.test_suite_eval.evaluation as spider_evaluation


def evaluate(model, dev_loader, schema, beam_size):
    model.eval()

    sketch_correct, rule_label_correct, found_in_beams, not_all_values_found, total = 0, 0, 0, 0, 0
    predictions = []
    for batch in tqdm(dev_loader, desc="Evaluating"):
        for data_row in batch:
            original_row = copy.deepcopy(data_row)

            try:
                example = build_example(data_row, schema)
            except Exception as e:
                print("Exception while building example (evaluation): {}".format(e))
                continue

            with torch.no_grad():
                results_all = model.parse(example, beam_size=beam_size)

            results = results_all[0]
            all_predictions = []
            try:
                # here we assemble the predicted actions (including leaf nodes) as a string
                full_prediction = " ".join([str(x) for x in results[0].actions])
                for beam in results:
                    all_predictions.append(" ".join([str(x) for x in beam.actions]))
            except Exception as e:
                # print(e)
                full_prediction = ""

            prediction = original_row

            # here we assemble the predicted sketch actions as a string
            prediction['sketch_result'] = " ".join(str(x) for x in results_all[1])
            prediction['model_result'] = full_prediction

            truth_sketch = " ".join([str(x) for x in example.sketch])
            truth_rule_label = " ".join([str(x) for x in example.semql_actions])

            if prediction['all_values_found']:
                if truth_sketch == prediction['sketch_result']:
                    sketch_correct += 1

                if truth_rule_label == prediction['model_result']:
                    rule_label_correct += 1
                elif truth_rule_label in all_predictions:
                    found_in_beams += 1
            else:
                question = prediction['question']
                tqdm.write(f'Not all values found during pre-processing for question "{question}". '
                           f'Replacing values with dummies to make the query fail.')
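                # Replacing the extracted values with dummies ensures the generated query
                # cannot accidentally return the correct result during execution-based evaluation.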
                prediction['values'] = [1] * len(prediction['values'])
                not_all_values_found += 1

            total += 1
            predictions.append(prediction)

    print(f"In {found_in_beams} cases the correct result was found in another beam "
          f"(failing queries: {total - rule_label_correct}).")

    return float(sketch_correct) / float(total), \
        float(rule_label_correct) / float(total), \
        float(not_all_values_found) / float(total), \
        predictions


def transform_to_sql_and_evaluate_with_spider(predictions, table_data, experiment_dir, data_dir, training_step):
    total_count, failure_count = transform_semQL_to_sql(table_data, predictions, experiment_dir)

    spider_eval_results = spider_evaluation.evaluate(os.path.join(experiment_dir, 'ground_truth.txt'),
                                                     os.path.join(experiment_dir, 'output.txt'),
                                                     os.path.join(data_dir, "testsuite_databases"),
                                                     "exec",
                                                     None,
                                                     False,
                                                     False,
                                                     False,
                                                     training_step,
                                                     quickmode=False)

    return total_count, failure_count, spider_eval_results


def _remove_unnecessary_distinct(p, g):
    p_tokens = p.split(' ')
    g_tokens = g.split(' ')

    # Only strip DISTINCT from the prediction if the ground truth does not use it at all.
    if 'distinct' not in g_tokens and 'DISTINCT' not in g_tokens:
        return ' '.join([p_token for p_token in p_tokens if p_token != 'DISTINCT'])
    else:
        return p


def evaluate_cordis(groundtruth_path: Path, prediction_path: Path, database: str, connection_config: dict,
                    do_not_verify_distinct=True):
    with open(groundtruth_path, 'r', encoding='utf-8') as f:
        groundtruth_full_lines = f.readlines()

    groundtruth = [g.strip().split('\t')[0] for g in groundtruth_full_lines]
    questions = [g.strip().split('\t')[2] for g in groundtruth_full_lines]

    with open(prediction_path, 'r', encoding='utf-8') as f:
        prediction = f.readlines()

    n_success = 0
    for p, g, q in zip(prediction, groundtruth, questions):
        if do_not_verify_distinct:
            p = _remove_unnecessary_distinct(p, g)

        try:
            results_prediction = _execute_query_postgresql(p, database, connection_config)
        except Exception as ex:
            print("Prediction is not a valid SQL query:")
            print(f"Q: {q.strip()}")
            print(f"P: {p.strip()}")
            print(f"G: {g.strip()}")
            print()
            continue

        results_groundtruth = _execute_query_postgresql(g, database, connection_config)

        if result_eq(results_groundtruth, results_prediction, order_matters=True):
            n_success += 1
        else:
            print("Results not equal for:")
            print(f"Q: {q.strip()}")
            print(f"P: {p.strip()}")
            print(f"G: {g.strip()}")
            print()

    print(f"================== Total score: {n_success} of {len(groundtruth)} "
          f"({100 / len(groundtruth) * n_success}%) ==================")


def main():
    args = read_arguments_evaluation()

    device, n_gpu = setup_device()
    set_seed_everywhere(args.seed, n_gpu)

    sql_data, table_data, val_sql_data, val_table_data = spider_utils.load_dataset(args.data_dir, use_small=False)
    _, dev_loader = get_data_loader(sql_data, val_sql_data, args.batch_size, True, False)

    grammar = semQL.Grammar()
    model = IRNet(args, device, grammar)
    model.to(device)

    # load the pre-trained parameters
    model.load_state_dict(torch.load(args.model_to_load))
    print("Loaded pre-trained model from '{}'".format(args.model_to_load))

    sketch_acc, acc, not_all_values_found, predictions = evaluate(model, dev_loader, table_data, args.beam_size)

    print("Predicted {} examples. Now converting them to SQL. "
          "Sketch accuracy: {}, accuracy: {}, not all values found: {}".format(
              len(dev_loader), sketch_acc, acc, not_all_values_found))
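    # Persist the SemQL predictions so the SQL transformation / evaluation below can be
    # re-run from file without decoding again (see the commented-out reload snippet).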
    with open(os.path.join(args.prediction_dir, 'predictions_sem_ql.json'), 'w', encoding='utf-8') as f:
        json.dump(predictions, f, indent=2)

    # with open(os.path.join(args.prediction_dir, 'predictions_sem_ql.json'), 'r', encoding='utf-8') as json_file:
    #     predictions = json.load(json_file)

    count_success, count_failed = transform_semQL_to_sql(val_table_data, predictions, args.prediction_dir)
    print("Transformed {} samples successfully to SQL. {} samples failed. "
          "Generated the files 'ground_truth.txt' and 'output.txt'.".format(count_success, count_failed))

    if args.evaluation_type == 'spider':
        print('We now use the official Spider evaluation script to evaluate the generated and ground-truth files.')
        wandb.init(project="proton")
        spider_evaluation.evaluate(os.path.join(args.prediction_dir, 'ground_truth.txt'),
                                   os.path.join(args.prediction_dir, 'output.txt'),
                                   os.path.join(args.data_dir, "testsuite_databases"),
                                   'exec',
                                   None,
                                   False,
                                   False,
                                   False,
                                   1,
                                   quickmode=False)
    elif args.evaluation_type == 'cordis':
        connection_config = {k: v for k, v in vars(args).items() if k.startswith('database')}
        evaluate_cordis(Path(args.prediction_dir) / 'ground_truth.txt',
                        Path(args.prediction_dir) / 'output.txt',
                        args.database,
                        connection_config)
    else:
        raise NotImplementedError('Only Spider and CORDIS evaluation are implemented so far.')


if __name__ == '__main__':
    main()
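# Example invocation (a sketch only: the script name and the exact CLI flag names are
# assumptions; the real flags are defined in config.read_arguments_evaluation and may
# be spelled differently than the args attributes used above):
#
#   python evaluation.py \
#       --model_to_load experiments/best_model.pt \
#       --data_dir data/spider \
#       --prediction_dir predictions \
#       --beam_size 1 \
#       --evaluation_type spider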