# nlql/seq2seq/utils/dataset_loader.py
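"""Dataset and metric loading for the seq2seq text-to-SQL experiments.

``load_dataset`` selects the dataset family named in ``data_args.dataset``
(Spider, CoSQL, the Spider variants, the combined ``cosql+spider`` setting,
or a WorldCup configuration), loads the matching Hugging Face dataset and
metric, and returns the metric together with the prepared splits.
"""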
import json
from typing import Callable, Tuple
import logging
import datasets.load
from datasets.dataset_dict import DatasetDict
from datasets.metric import Metric
from datasets.arrow_dataset import Dataset, concatenate_datasets
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.training_args import TrainingArguments
from seq2seq.utils.args import ModelArguments
from seq2seq.utils.dataset import (
    DataArguments,
    DataTrainingArguments,
    DatasetSplits,
    TrainSplit,
    _prepare_train_split,
    prepare_splits,
)
from seq2seq.utils.spider import spider_add_serialized_schema, spider_pre_process_function
from seq2seq.utils.cosql import cosql_add_serialized_schema, cosql_pre_process_function
from seq2seq.utils.worldcup import worldcup_add_serialized_schema, worldcup_pre_process_function

logger = logging.getLogger(__name__)


def _log_duplicate_count(dataset: Dataset, dataset_name: str, split: str) -> None:
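    """Emit a warning if the given split of ``dataset_name`` contains exact duplicate examples."""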
    d = dataset.to_dict()
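    # Turn each example into a hashable tuple of (column, values) pairs so duplicates can be counted via a set.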
    d_t = [tuple((k, tuple(v)) for k, v in zip(d.keys(), vs)) for vs in zip(*d.values())]
    d_t_ = set(d_t)
    num_examples = len(d_t)
    duplicate_count = num_examples - len(d_t_)
    if duplicate_count > 0:
        logger.warning(
            f"The split ``{split}`` of the dataset ``{dataset_name}`` contains {duplicate_count} duplicates out of {num_examples} examples"
        )


def load_dataset(
    data_args: DataArguments,
    model_args: ModelArguments,
    data_training_args: DataTrainingArguments,
    training_args: TrainingArguments,
    tokenizer: PreTrainedTokenizerFast,
) -> Tuple[Metric, DatasetSplits]:
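    """Load the evaluation metric and prepare dataset splits for ``data_args.dataset``.

    Dispatches on the requested dataset name, loads the corresponding dataset and
    metric through the ``datasets`` library, and returns the metric together with
    the prepared train/eval/test splits.

    Illustrative call site (argument construction depends on the surrounding
    training script):

        metric, dataset_splits = load_dataset(
            data_args=data_args,
            model_args=model_args,
            data_training_args=data_training_args,
            training_args=training_args,
            tokenizer=tokenizer,
        )
    """
    # Lazily constructed dataset/metric loaders and pre-processing closures, one set per dataset family.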
    _spider_dataset_dict: Callable[[], DatasetDict] = lambda: datasets.load.load_dataset(
        path=data_args.dataset_paths["spider"], cache_dir=model_args.cache_dir
    )
    _spider_metric: Callable[[], Metric] = lambda: datasets.load.load_metric(
        path=data_args.metric_paths["spider"], config_name=data_args.metric_config, test_suite_db_dir=data_args.test_suite_db_dir
    )
    _spider_add_serialized_schema = lambda ex: spider_add_serialized_schema(
        ex=ex,
        data_training_args=data_training_args,
    )
    _spider_pre_process_function = lambda batch, max_source_length, max_target_length: spider_pre_process_function(
        batch=batch,
        max_source_length=max_source_length,
        max_target_length=max_target_length,
        data_training_args=data_training_args,
        tokenizer=tokenizer,
    )
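    # WorldCup: the full dataset string is also forwarded as the ``name`` (config) argument,
    # so variants sharing the ``worldcup`` prefix can select their own configuration.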
    _worldcup_dataset_dict: Callable[[], DatasetDict] = lambda: datasets.load.load_dataset(
        path=data_args.dataset_paths["worldcup"], name=data_args.dataset, cache_dir=model_args.cache_dir
    )
    _worldcup_metric: Callable[[], Metric] = lambda: datasets.load.load_metric(
        path=data_args.metric_paths["worldcup"], config_name=data_args.metric_config, test_suite_db_dir=data_args.test_suite_db_dir
    )
    _worldcup_add_serialized_schema = lambda ex: worldcup_add_serialized_schema(
        ex=ex,
        data_training_args=data_training_args,
    )
    _worldcup_pre_process_function = lambda batch, max_source_length, max_target_length: worldcup_pre_process_function(
        batch=batch,
        max_source_length=max_source_length,
        max_target_length=max_target_length,
        data_training_args=data_training_args,
        tokenizer=tokenizer,
    )

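    # CoSQL: conversational text-to-SQL; loaded with the same pattern as Spider.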
    _cosql_dataset_dict: Callable[[], DatasetDict] = lambda: datasets.load.load_dataset(
        path=data_args.dataset_paths["cosql"], cache_dir=model_args.cache_dir
    )
    _cosql_metric: Callable[[], Metric] = lambda: datasets.load.load_metric(
        path=data_args.metric_paths["cosql"], config_name=data_args.metric_config, test_suite_db_dir=data_args.test_suite_db_dir
    )
    _cosql_add_serialized_schema = lambda ex: cosql_add_serialized_schema(
        ex=ex,
        data_training_args=data_training_args,
    )
    _cosql_pre_process_function = lambda batch, max_source_length, max_target_length: cosql_pre_process_function(
        batch=batch,
        max_source_length=max_source_length,
        max_target_length=max_target_length,
        data_training_args=data_training_args,
        tokenizer=tokenizer,
    )
    # The Spider variants below (Spider-Realistic, Spider-Syn, Spider-DK) reuse Spider's
    # schema serialization and pre-processing functions as-is; only the dataset and metric paths differ.
    _spider_realistic_dataset_dict: Callable[[], DatasetDict] = lambda: datasets.load.load_dataset(
        path=data_args.dataset_paths["spider_realistic"], cache_dir=model_args.cache_dir
    )
    _spider_realistic_metric: Callable[[], Metric] = lambda: datasets.load.load_metric(
        path=data_args.metric_paths["spider_realistic"], config_name=data_args.metric_config, test_suite_db_dir=data_args.test_suite_db_dir
    )

    _spider_syn_dataset_dict: Callable[[], DatasetDict] = lambda: datasets.load.load_dataset(
        path=data_args.dataset_paths["spider_syn"], cache_dir=model_args.cache_dir
    )
    _spider_syn_metric: Callable[[], Metric] = lambda: datasets.load.load_metric(
        path=data_args.metric_paths["spider_syn"], config_name=data_args.metric_config, test_suite_db_dir=data_args.test_suite_db_dir
    )

    _spider_dk_dataset_dict: Callable[[], DatasetDict] = lambda: datasets.load.load_dataset(
        path=data_args.dataset_paths["spider_dk"], cache_dir=model_args.cache_dir
    )
    _spider_dk_metric: Callable[[], Metric] = lambda: datasets.load.load_metric(
        path=data_args.metric_paths["spider_dk"], config_name=data_args.metric_config, test_suite_db_dir=data_args.test_suite_db_dir
    )

    _prepare_splits_kwargs = {
        "data_args": data_args,
        "training_args": training_args,
        "data_training_args": data_training_args,
    }

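    # Dispatch on the requested dataset and build the metric and splits.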
    if data_args.dataset == "spider":
        metric = _spider_metric()
        dataset_splits = prepare_splits(
            dataset_dict=_spider_dataset_dict(),
            add_serialized_schema=_spider_add_serialized_schema,
            pre_process_function=_spider_pre_process_function,
            **_prepare_splits_kwargs,
        )
    elif data_args.dataset == "cosql":
        metric = _cosql_metric()
        dataset_splits = prepare_splits(
            dataset_dict=_cosql_dataset_dict(),
            add_serialized_schema=_cosql_add_serialized_schema,
            pre_process_function=_cosql_pre_process_function,
            **_prepare_splits_kwargs,
        )
    elif data_args.dataset == "spider_realistic":
        metric = _spider_realistic_metric()
        dataset_splits = prepare_splits(
            dataset_dict=_spider_realistic_dataset_dict(),
            add_serialized_schema=_spider_add_serialized_schema,
            pre_process_function=_spider_pre_process_function,
            **_prepare_splits_kwargs,
        )
    elif data_args.dataset == "spider_dk":
        metric = _spider_dk_metric()
        dataset_splits = prepare_splits(
            dataset_dict=_spider_dk_dataset_dict(),
            add_serialized_schema=_spider_add_serialized_schema,
            pre_process_function=_spider_pre_process_function,
            **_prepare_splits_kwargs,
        )
    elif data_args.dataset == "spider_syn":
        metric = _spider_syn_metric()
        dataset_splits = prepare_splits(
            dataset_dict=_spider_syn_dataset_dict(),
            add_serialized_schema=_spider_add_serialized_schema,
            pre_process_function=_spider_pre_process_function,
            **_prepare_splits_kwargs,
        )
    elif data_args.dataset == "cosql+spider":
        metric = _cosql_metric()
        cosql_dataset_splits = prepare_splits(
            dataset_dict=_cosql_dataset_dict(),
            add_serialized_schema=_cosql_add_serialized_schema,
            pre_process_function=_cosql_pre_process_function,
            **_prepare_splits_kwargs,
        )
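        # Spider contributes additional training data only; evaluation and test splits come from CoSQL.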
        spider_training_split = (
            _prepare_train_split(
                dataset=_spider_dataset_dict()["train"],
                data_training_args=data_training_args,
                add_serialized_schema=_spider_add_serialized_schema,
                pre_process_function=_spider_pre_process_function,
            )
            if training_args.do_train
            else None
        )
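        # Merge the CoSQL and Spider training splits (and their schema maps) when both are present.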
        if cosql_dataset_splits.train_split is None and spider_training_split is None:
            train_split = None
        elif cosql_dataset_splits.train_split is None:
            train_split = spider_training_split
        elif spider_training_split is None:
            train_split = cosql_dataset_splits.train_split
        else:
            dataset: Dataset = concatenate_datasets(
                dsets=[cosql_dataset_splits.train_split.dataset, spider_training_split.dataset]
            )
            train_split = TrainSplit(
                dataset=dataset,
                schemas={**spider_training_split.schemas, **cosql_dataset_splits.train_split.schemas},
            )
        schemas = {
            **cosql_dataset_splits.schemas,
            **(spider_training_split.schemas if spider_training_split is not None else {}),
        }
        dataset_splits = DatasetSplits(
            train_split=train_split,
            eval_split=cosql_dataset_splits.eval_split,
            test_splits=cosql_dataset_splits.test_splits,
            schemas=schemas,
        )
    elif data_args.dataset.startswith("worldcup"):
        metric = _worldcup_metric()
        dataset_splits = prepare_splits(
            dataset_dict=_worldcup_dataset_dict(),
            add_serialized_schema=_worldcup_add_serialized_schema,
            pre_process_function=_worldcup_pre_process_function,
            **_prepare_splits_kwargs,
        )
    else:
        raise NotImplementedError(f"Unsupported dataset: {data_args.dataset}")

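    # Warn about exact duplicates in every prepared split.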
    if dataset_splits.train_split is not None:
        _log_duplicate_count(dataset=dataset_splits.train_split.dataset, dataset_name=data_args.dataset, split="train")
    if dataset_splits.eval_split is not None:
        _log_duplicate_count(dataset=dataset_splits.eval_split.dataset, dataset_name=data_args.dataset, split="eval")
    if dataset_splits.test_splits is not None:
        for section, split in dataset_splits.test_splits.items():
            _log_duplicate_count(dataset=split.dataset, dataset_name=data_args.dataset, split=section)

    return metric, dataset_splits