# nlql/tests/test_dataset.py
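"""Tests for dataset loading and preprocessing.

These tests load the Spider and CoSQL datasets (and their union), tokenize them
with a T5 tokenizer, and assert that the longest tokenized input and target
sequences match the hard-coded expected lengths. For a selection of databases
they also print Haskell SQLSchema literals via the Picard schema helper.

Example invocation (from the repository root, assuming the Spider/CoSQL data is
available to load_dataset):

    pytest tests/test_dataset.py -k "spider and train"
"""
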
from dataclasses import replace

import pytest
from transformers.models.auto import AutoTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.training_args import TrainingArguments

from seq2seq.utils.args import ModelArguments
from seq2seq.utils.dataset import DataArguments, DataTrainingArguments
from seq2seq.utils.dataset_loader import load_dataset
from seq2seq.utils.picard_model_wrapper import get_picard_schema


@pytest.fixture
def training_args(tmpdir, split) -> TrainingArguments:
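    # Map the parametrized split onto the HF TrainingArguments flags that
    # load_dataset presumably uses to decide which split(s) to preprocess.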
    if split == "train":
        do_train = True
        do_eval = False
    elif split == "eval":
        do_train = False
        do_eval = True
    else:
        raise NotImplementedError()
    return TrainingArguments(
        output_dir=str(tmpdir),
        do_train=do_train,
        do_eval=do_eval,
        do_predict=False,
    )


@pytest.fixture(params=[True, False])
def schema_serialization_with_db_content(request) -> bool:
    return request.param


@pytest.fixture
def data_training_args(schema_serialization_with_db_content: bool) -> DataTrainingArguments:
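    # Source/target max lengths are set well above the longest expected sequences,
    # so nothing is truncated and the asserted max lengths below reflect the full
    # tokenized inputs and targets.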
    return DataTrainingArguments(
        max_source_length=4096,
        max_target_length=4096,
        schema_serialization_type="peteshaw",
        schema_serialization_with_db_id=True,
        schema_serialization_with_db_content=schema_serialization_with_db_content,
        normalize_query=True,
        target_with_db_id=True,
    )


@pytest.fixture(params=["cosql", "spider", "cosql+spider"])
def data_args(request) -> DataArguments:
    return DataArguments(dataset=request.param)


@pytest.fixture(params=["train", "eval"])
def split(request) -> str:
    return request.param


@pytest.fixture
def expected_max_input_ids_len(data_args: DataArguments, split: str, schema_serialization_with_db_content: bool) -> int:
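    # Expected maximum tokenized input length for every dataset/split/serialization
    # combination; serves as a regression check on schema serialization and tokenization.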
    def _expected_max_input_ids_len(_data_args: DataArguments) -> int:
        if _data_args.dataset == "spider" and split == "train" and schema_serialization_with_db_content:
            return 1927
        elif _data_args.dataset == "spider" and split == "train" and not schema_serialization_with_db_content:
            return 1892
        elif _data_args.dataset == "spider" and split == "eval" and schema_serialization_with_db_content:
            return 468
        elif _data_args.dataset == "spider" and split == "eval" and not schema_serialization_with_db_content:
            return 468
        elif _data_args.dataset == "cosql" and split == "train" and schema_serialization_with_db_content:
            return 2227
        elif _data_args.dataset == "cosql" and split == "train" and not schema_serialization_with_db_content:
            return 1984
        elif _data_args.dataset == "cosql" and split == "eval" and schema_serialization_with_db_content:
            return 545
        elif _data_args.dataset == "cosql" and split == "eval" and not schema_serialization_with_db_content:
            return 545
        elif _data_args.dataset == "cosql+spider":
            return max(
                _expected_max_input_ids_len(_data_args=replace(_data_args, dataset="spider")),
                _expected_max_input_ids_len(_data_args=replace(_data_args, dataset="cosql")),
            )
        else:
            raise NotImplementedError()

    return _expected_max_input_ids_len(_data_args=data_args)


@pytest.fixture
def expected_max_labels_len(data_args: DataArguments, split: str, schema_serialization_with_db_content: bool) -> int:
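    # Expected maximum tokenized target (label) length, analogous to
    # expected_max_input_ids_len above.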
    def _expected_max_labels_len(_data_args: DataArguments) -> int:
        if _data_args.dataset == "spider" and split == "train" and schema_serialization_with_db_content:
            return 250
        elif _data_args.dataset == "spider" and split == "train" and not schema_serialization_with_db_content:
            return 250
        elif _data_args.dataset == "spider" and split == "eval" and schema_serialization_with_db_content:
            return 166
        elif _data_args.dataset == "spider" and split == "eval" and not schema_serialization_with_db_content:
            return 166
        elif _data_args.dataset == "cosql" and split == "train" and schema_serialization_with_db_content:
            return 250
        elif _data_args.dataset == "cosql" and split == "train" and not schema_serialization_with_db_content:
            return 250
        elif _data_args.dataset == "cosql" and split == "eval" and schema_serialization_with_db_content:
            return 210
        elif _data_args.dataset == "cosql" and split == "eval" and not schema_serialization_with_db_content:
            return 210
        elif _data_args.dataset == "cosql+spider":
            return max(
                _expected_max_labels_len(_data_args=replace(_data_args, dataset="spider")),
                _expected_max_labels_len(_data_args=replace(_data_args, dataset="cosql")),
            )
        else:
            raise NotImplementedError()

    return _expected_max_labels_len(_data_args=data_args)


@pytest.fixture
def model_name_or_path() -> str:
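    # A small checkpoint keeps the test lightweight; the tokenizer fixture below
    # loads its tokenizer.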
    return "t5-small"


@pytest.fixture
def model_args(model_name_or_path: str) -> ModelArguments:
    return ModelArguments(model_name_or_path=model_name_or_path)


@pytest.fixture
def tokenizer(model_args: ModelArguments) -> PreTrainedTokenizerFast:
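    # Load the pretrained T5 tokenizer from the local cache or the Hugging Face Hub.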
    return AutoTokenizer.from_pretrained(model_args.model_name_or_path)


def test_dataset_loader(
    data_args: DataArguments,
    split: str,
    expected_max_input_ids_len: int,
    expected_max_labels_len: int,
    training_args: TrainingArguments,
    data_training_args: DataTrainingArguments,
    model_args: ModelArguments,
    tokenizer: PreTrainedTokenizerFast,
) -> None:
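    """Load and preprocess the requested dataset split, then check that the longest
    tokenized input and target sequences match the expected lengths.
    """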

    _metric, dataset_splits = load_dataset(
        data_args=data_args,
        model_args=model_args,
        data_training_args=data_training_args,
        training_args=training_args,
        tokenizer=tokenizer,
    )

    if split == "train":
        dataset = dataset_splits.train_split.dataset
        schemas = dataset_splits.train_split.schemas
    elif split == "eval":
        dataset = dataset_splits.eval_split.dataset
        schemas = dataset_splits.eval_split.schemas
    elif split == "test":
        dataset = dataset_splits.test_split.dataset
        schemas = dataset_splits.test_split.schemas
    else:
        raise NotImplementedError()
    
    # Debug aid: list the databases present in this split.
    print(list(schemas.keys()))

    def c_type_str(c_type) -> str:
        # Render the column type as a Haskell constructor name
        # (replace the "." of the enum repr with "_").
        return str(c_type).replace(".", "_")

    def c_ids_str(c_ids) -> str:
        return ", ".join(f'"{c_id}"' for c_id in c_ids)

    # For a fixed set of databases, dump the schema as a Haskell SQLSchema record
    # literal, e.g. for pasting into Picard's Haskell test fixtures.
    for db_id, schema in schemas.items():
        if db_id in ["academic", "scholar", "geo", "imdb", "yelp", "restaurants", "orchestra", "wta_1", "singer", "cre_Doc_Template_Mgt", "dog_kennels", "poker_player", "museum_visit", "student_transcripts_tracking"]:
            sql_schema = get_picard_schema(**schema)
            c_names_str = ", ".join(f'("{c_id}", "{c_name}")' for c_id, c_name in sql_schema.columnNames.items())
            c_types_str = ", ".join(f'("{c_id}", {c_type_str(c_type)})' for c_id, c_type in sql_schema.columnTypes.items())
            t_names_str = ", ".join(f'("{t_id}", "{t_name}")' for t_id, t_name in sql_schema.tableNames.items())
            c_to_t_str = ", ".join(f'("{c_id}", "{t_id}")' for c_id, t_id in sql_schema.columnToTable.items())
            t_to_cs_str = ", ".join(f'("{t_id}", [{c_ids_str(c_ids)}])' for t_id, c_ids in sql_schema.tableToColumns.items())
            fks_str = ", ".join(f'("{c_id}", "{other_c_id}")' for c_id, other_c_id in sql_schema.foreignKeys.items())
            pks_str = ", ".join(f'"{pk}"' for pk in sql_schema.primaryKeys)
            schema_str = "SQLSchema {sQLSchema_columnNames = columnNames, sQLSchema_columnTypes = columnTypes, sQLSchema_tableNames = tableNames, sQLSchema_columnToTable = columnToTable, sQLSchema_tableToColumns = tableToColumns, sQLSchema_foreignKeys = foreignKeys, sQLSchema_primaryKeys = primaryKeys}"
            print(f"""
{db_id}:
  let columnNames = HashMap.fromList [{c_names_str}]
      columnTypes = HashMap.fromList [{c_types_str}]
      tableNames = HashMap.fromList [{t_names_str}]
      columnToTable = HashMap.fromList [{c_to_t_str}]
      tableToColumns = HashMap.fromList [{t_to_cs_str}]
      foreignKeys = HashMap.fromList [{fks_str}]
      primaryKeys = [{pks_str}]
   in {schema_str}
            """)

    max_input_ids_len = max(len(item["input_ids"]) for item in dataset)
    assert max_input_ids_len == expected_max_input_ids_len

    max_labels_len = max(len(item["labels"]) for item in dataset)
    assert max_labels_len == expected_max_labels_len