# ValueNet4SPARQL / src / main.py
import json
import os

import torch
from pytictoc import TicToc
from tqdm import tqdm

from model.model import IRNet
from config import read_arguments_train, write_config_to_file
from data_loader import get_data_loader
from evaluation import evaluate, transform_to_sql_and_evaluate_with_spider
from intermediate_representation import semQL
from optimizer import build_optimizer_encoder
from spider import spider_utils
from training import train
from utils import setup_device, set_seed_everywhere, save_model, create_experiment_folder

# Experiment tracking with Weights & Biases.
# This runs as soon as the module is loaded, i.e. before the command-line
# arguments are read. The "XXX" values appear to be placeholders for the real
# W&B entity and project names and must be filled in before a real run.
import wandb
wandb.init(entity="XXX", project="XXX")


def main():
    """Train an IRNet model and evaluate it on the dev-set after every epoch.

    All settings come from ``read_arguments_train()``. After each training epoch
    the model is evaluated on the dev-set (using ``args.beam_size``) and the
    results are written into the experiment folder:

    * ``predictions_sem_ql.json`` -- the SemQL predictions of the most recent
      epoch (the file is overwritten every epoch),
    * ``eval_results.log`` -- one accuracy line appended per epoch,
    * the model, stored via ``save_model`` whenever the dev accuracy is strictly
      higher than the best accuracy seen so far.

    If ``args.run_spider_evaluation_after_epoch`` is set, the predictions are
    additionally transformed to SQL and scored with the Spider evaluation.
    Sketch accuracy and accuracy are logged to Weights & Biases per epoch.
    """
    args = read_arguments_train()

    # log hyperparameters to Weights & Biases
    wandb.config.update(args)

    experiment_name, output_path = create_experiment_folder(args.model_output_dir, args.exp_name)
    print("Run experiment '{}'".format(experiment_name))

    write_config_to_file(args, output_path)

    device, n_gpu = setup_device()
    set_seed_everywhere(args.seed, n_gpu)

    # NOTE(review): the dev-set schemas are unpacked but never used; the dev
    # evaluation below is given ``table_data``. Presumably both come from the
    # same schema file -- confirm in spider_utils.load_dataset.
    sql_data, table_data, val_sql_data, _val_table_data = spider_utils.load_dataset(args.data_dir,
                                                                                     use_small=args.toy)
    # assumes the two positional flags mean (shuffle train, shuffle dev) -- TODO confirm in data_loader
    train_loader, dev_loader = get_data_loader(sql_data, val_sql_data, args.batch_size, True, False)

    grammar = semQL.Grammar()
    model = IRNet(args, device, grammar)
    model.to(device)

    # track the model
    wandb.watch(model, log='parameters')

    # Convert once so the optimizer's step budget and the training loop below
    # agree on the number of epochs.
    num_epochs = int(args.num_epochs)
    num_train_steps = len(train_loader) * num_epochs
    optimizer, scheduler = build_optimizer_encoder(model,
                                                   num_train_steps,
                                                   args.lr_transformer, args.lr_connection, args.lr_base,
                                                   args.scheduler_gamma)

    global_step = 0
    best_acc = 0.0

    print("Start training with {} epochs".format(num_epochs))
    t = TicToc()
    for epoch in tqdm(range(num_epochs)):
        # 1-based epoch number, used consistently in all log output and passed
        # to the Spider evaluation.
        epoch_number = epoch + 1

        # The sketch loss gets weight 1 for the first ``loss_epoch_threshold``
        # epochs, afterwards the configured ``sketch_loss_weight``.
        sketch_loss_weight = 1 if epoch < args.loss_epoch_threshold else args.sketch_loss_weight

        t.tic()
        global_step = train(global_step,
                            train_loader,
                            table_data,
                            model,
                            optimizer,
                            args.clip_grad,
                            sketch_loss_weight=sketch_loss_weight)
        train_time = t.tocvalue()

        tqdm.write("Training of epoch {0} finished after {1:.2f} seconds. Evaluate now on the dev-set"
                   .format(epoch_number, train_time))
        with torch.no_grad():
            sketch_acc, acc, _, predictions = evaluate(model,
                                                       dev_loader,
                                                       table_data,
                                                       args.beam_size)

        # Only the latest epoch's predictions are kept on disk.
        with open(os.path.join(output_path, 'predictions_sem_ql.json'), 'w', encoding='utf-8') as f:
            json.dump(predictions, f, indent=2)

        eval_results_string = "Epoch: {}    Sketch-Accuracy: {}     Accuracy: {}".format(epoch_number,
                                                                                          sketch_acc,
                                                                                          acc)
        tqdm.write(eval_results_string)

        if args.run_spider_evaluation_after_epoch:
            total_transformed, fail_transform, spider_eval_results = transform_to_sql_and_evaluate_with_spider(
                predictions,
                table_data,
                output_path,
                args.data_dir,
                epoch_number)

            tqdm.write("Successfully transformed {} of {} from SemQL to SQL.".format(
                total_transformed - fail_transform, total_transformed))
            tqdm.write("Results from Spider-Evaluation:")
            for key, value in spider_eval_results.items():
                tqdm.write("{}: {}".format(key, value))

        # Keep only the best model so far (strict improvement required).
        if acc > best_acc:
            save_model(model, os.path.join(output_path))
            tqdm.write("Accuracy of this epoch ({}) is higher than the so far best accuracy ({}). Save model."
                       .format(acc, best_acc))
            best_acc = acc

        # Append-only log: no read access needed, so plain "a" mode suffices.
        with open(os.path.join(output_path, "eval_results.log"), "a", encoding='utf-8') as writer:
            writer.write(eval_results_string + "\n")

        wandb.log({"Sketch-accuracy": sketch_acc, "accuracy": acc}, step=epoch_number)

        scheduler.step()  # Update learning rate schedule (once per epoch)


if __name__ == '__main__':
    main()