nlql / seq2seq / utils / pipeline.py
pipeline.py
Raw
from dataclasses import dataclass
from typing import Union, List, Dict, Optional
from transformers.pipelines.text2text_generation import ReturnType, Text2TextGenerationPipeline
from transformers.tokenization_utils import TruncationStrategy
from transformers.tokenization_utils_base import BatchEncoding
from third_party.spider.preprocess.get_tables import dump_db_json_schema
from seq2seq.utils.dataset import serialize_schema
from seq2seq.utils.spider import spider_get_input
from seq2seq.utils.cosql import cosql_get_input


@dataclass
class Text2SQLInput(object):
    utterance: str
    db_id: str


class Text2SQLGenerationPipeline(Text2TextGenerationPipeline):
    """
    Pipeline for text-to-SQL generation using seq2seq models.

    The models that this pipeline can use are models that have been fine-tuned on the Spider text-to-SQL task.

    Usage::

        model = AutoModelForSeq2SeqLM.from_pretrained(...)
        tokenizer = AutoTokenizer.from_pretrained(...)
        db_path = ... path to "concert_singer" parent folder
        text2sql_generator = Text2SQLGenerationPipeline(
            model=model,
            tokenizer=tokenizer,
            db_path=db_path,
        )
        text2sql_generator(inputs=Text2SQLInput(utterance="How many singers do we have?", db_id="concert_singer"))
    """

    def __init__(self, *args, **kwargs):
        self.db_path: str = kwargs.pop("db_path")
        self.prefix: Optional[str] = kwargs.pop("prefix", None)
        self.normalize_query: bool = kwargs.pop("normalize_query", True)
        self.schema_serialization_type: str = kwargs.pop("schema_serialization_type", "peteshaw")
        self.schema_serialization_randomized: bool = kwargs.pop("schema_serialization_randomized", False)
        self.schema_serialization_with_db_id: bool = kwargs.pop("schema_serialization_with_db_id", True)
        self.schema_serialization_with_db_content: bool = kwargs.pop("schema_serialization_with_db_content", True)
        self.schema_cache: Dict[str, dict] = dict()
        super().__init__(*args, **kwargs)

    def __call__(self, inputs: Union[Text2SQLInput, List[Text2SQLInput]], *args, **kwargs):
        r"""
        Generate the output SQL expression(s) using text(s) given as inputs.

        Args:
            inputs (:obj:`Text2SQLInput` or :obj:`List[Text2SQLInput]`):
                Input text(s) for the encoder.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`):
                The truncation strategy for the tokenization within the pipeline.
                :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
                to truncate the input to fit the model's max_length instead of throwing an error down the line.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **generated_sql** (:obj:`str`, present when ``return_text=True``) -- The generated SQL.
            - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the generated SQL.
        """
        result = super().__call__(inputs, *args, **kwargs)
        if (
            isinstance(inputs, list)
            and all(isinstance(el, Text2SQLInput) for el in inputs)
            and all(len(res) == 1 for res in result)
        ):
            return [res[0] for res in result]
        return result

    def preprocess(
        self,
        inputs: Union[Text2SQLInput, List[Text2SQLInput]],
        *args,
        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
        **kwargs
    ) -> BatchEncoding:
        encodings = self._parse_and_tokenize(inputs, *args, truncation=truncation, **kwargs)
        return encodings

    def _parse_and_tokenize(
        self,
        inputs: Union[Text2SQLInput, List[Text2SQLInput]],
        *args,
        truncation: TruncationStrategy
    ) -> BatchEncoding:
        if isinstance(inputs, list):
            if self.tokenizer.pad_token_id is None:
                raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
            inputs = [self._pre_process(input=input) for input in inputs]
            padding = True
        elif isinstance(inputs, Text2SQLInput):
            inputs = self._pre_process(input=inputs)
            padding = False
        else:
            raise ValueError(
                f" `inputs`: {inputs} have the wrong format. The should be either of type `Text2SQLInput` or type `List[Text2SQLInput]`"
            )
        encodings = self.tokenizer(inputs, padding=padding, truncation=truncation, return_tensors=self.framework)
        # This is produced by tokenizers but is an invalid generate kwargs
        if "token_type_ids" in encodings:
            del encodings["token_type_ids"]
        return encodings

    def _pre_process(self, input: Text2SQLInput) -> str:
        prefix = self.prefix if self.prefix is not None else ""
        if input.db_id not in self.schema_cache:
            self.schema_cache[input.db_id] = get_schema(db_path=self.db_path, db_id=input.db_id)
        schema = self.schema_cache[input.db_id]
        if hasattr(self.model, "add_schema"):
            self.model.add_schema(db_id=input.db_id, db_info=schema)
        serialized_schema = serialize_schema(
            question=input.utterance,
            db_path=self.db_path,
            db_id=input.db_id,
            db_column_names=schema["db_column_names"],
            db_table_names=schema["db_table_names"],
            db_primary_keys=schema["db_primary_keys"],
            db_foreign_keys=schema["db_foreign_keys"],
            schema_serialization_type=self.schema_serialization_type,
            schema_serialization_randomized=self.schema_serialization_randomized,
            schema_serialization_with_db_id=self.schema_serialization_with_db_id,
            schema_serialization_with_db_content=self.schema_serialization_with_db_content,
            normalize_query=self.normalize_query,
        )
        return spider_get_input(question=input.utterance, serialized_schema=serialized_schema, prefix=prefix)

    def postprocess(self, model_outputs: dict, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
        records = []
        for output_ids in model_outputs["output_ids"][0]:
            if return_type == ReturnType.TENSORS:
                record = {f"{self.return_name}_token_ids": model_outputs}
            elif return_type == ReturnType.TEXT:
                record = {
                    f"{self.return_name}_text": self.tokenizer.decode(
                        output_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                    .split("|", 1)[-1]
                    .strip()
                }
            records.append(record)
        return records


@dataclass
class ConversationalText2SQLInput(object):
    utterances: List[str]
    db_id: str


class ConversationalText2SQLGenerationPipeline(Text2TextGenerationPipeline):
    """
    Pipeline for conversational text-to-SQL generation using seq2seq models.

    The models that this pipeline can use are models that have been fine-tuned on the CoSQL SQL-grounded dialogue state tracking task.

    Usage::

        model = AutoModelForSeq2SeqLM.from_pretrained(...)
        tokenizer = AutoTokenizer.from_pretrained(...)
        db_path = ... path to "concert_singer" parent folder
        conv_text2sql_generator = ConversationalText2SQLGenerationPipeline(
            model=model,
            tokenizer=tokenizer,
            db_path=db_path,
        )
        conv_text2sql_generator(inputs=ConversationalText2SQLInput(utterances=["How many singers do we have?"], db_id="concert_singer"))
    """

    def __init__(self, *args, **kwargs):
        self.db_path: str = kwargs.pop("db_path")
        self.prefix: Optional[str] = kwargs.pop("prefix", None)
        self.normalize_query: bool = kwargs.pop("normalize_query", True)
        self.schema_serialization_type: str = kwargs.pop("schema_serialization_type", "peteshaw")
        self.schema_serialization_randomized: bool = kwargs.pop("schema_serialization_randomized", False)
        self.schema_serialization_with_db_id: bool = kwargs.pop("schema_serialization_with_db_id", True)
        self.schema_serialization_with_db_content: bool = kwargs.pop("schema_serialization_with_db_content", True)
        self.schema_cache: Dict[str, dict] = dict()
        super().__init__(*args, **kwargs)

    def __call__(self, inputs: Union[ConversationalText2SQLInput, List[ConversationalText2SQLInput]], *args, **kwargs):
        r"""
        Generate the output SQL expression(s) using conversation(s) given as inputs.

        Args:
            inputs (:obj:`ConversationalText2SQLInput` or :obj:`List[ConversationalText2SQLInput]`):
                Input conversation(s) for the encoder.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`):
                The truncation strategy for the tokenization within the pipeline.
                :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
                to truncate the input to fit the model's max_length instead of throwing an error down the line.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework `here <./model.html#generative-models>`__).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **generated_sql** (:obj:`str`, present when ``return_text=True``) -- The generated SQL.
            - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the generated SQL.
        """
        result = super().__call__(inputs, *args, **kwargs)
        if (
            isinstance(inputs, list)
            and all(isinstance(el, ConversationalText2SQLInput) for el in inputs)
            and all(len(res) == 1 for res in result)
        ):
            return [res[0] for res in result]
        return result

    def preprocess(
        self,
        inputs: Union[ConversationalText2SQLInput, List[ConversationalText2SQLInput]],
        *args,
        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
        **kwargs,
    ) -> BatchEncoding:
        encodings = self._parse_and_tokenize(inputs, *args, truncation=truncation, **kwargs)
        return encodings

    def _parse_and_tokenize(
        self,
        inputs: Union[ConversationalText2SQLInput, List[ConversationalText2SQLInput]],
        *args,
        truncation: TruncationStrategy,
    ) -> BatchEncoding:
        if isinstance(inputs, list):
            if self.tokenizer.pad_token_id is None:
                raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
            inputs = [self._pre_process(input=input) for input in inputs]
            padding = True
        elif isinstance(inputs, ConversationalText2SQLInput):
            inputs = self._pre_process(input=inputs)
            padding = False
        else:
            raise ValueError(
                f" `inputs`: {inputs} have the wrong format. The should be either of type `ConversationalText2SQLInput` or type `List[ConversationalText2SQLInput]`"
            )
        encodings = self.tokenizer(inputs, padding=padding, truncation=truncation, return_tensors=self.framework)
        # This is produced by tokenizers but is an invalid generate kwargs
        if "token_type_ids" in encodings:
            del encodings["token_type_ids"]
        return encodings

    def _pre_process(self, input: ConversationalText2SQLInput) -> str:
        prefix = self.prefix if self.prefix is not None else ""
        if input.db_id not in self.schema_cache:
            self.schema_cache[input.db_id] = get_schema(db_path=self.db_path, db_id=input.db_id)
        schema = self.schema_cache[input.db_id]
        if hasattr(self.model, "add_schema"):
            self.model.add_schema(db_id=input.db_id, db_info=schema)
        serialized_schema = serialize_schema(
            question=" | ".join(input.utterances),
            db_path=self.db_path,
            db_id=input.db_id,
            db_column_names=schema["db_column_names"],
            db_table_names=schema["db_table_names"],
            db_primary_keys=schema["db_primary_keys"],
            db_foreign_keys=schema["db_foreign_keys"],
            schema_serialization_type=self.schema_serialization_type,
            schema_serialization_randomized=self.schema_serialization_randomized,
            schema_serialization_with_db_id=self.schema_serialization_with_db_id,
            schema_serialization_with_db_content=self.schema_serialization_with_db_content,
            normalize_query=self.normalize_query,
        )
        return cosql_get_input(utterances=input.utterances, serialized_schema=serialized_schema, prefix=prefix)

    def postprocess(self, model_outputs: dict, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
        records = []
        for output_ids in model_outputs["output_ids"][0]:
            if return_type == ReturnType.TENSORS:
                record = {f"{self.return_name}_token_ids": model_outputs}
            elif return_type == ReturnType.TEXT:
                record = {
                    f"{self.return_name}_text": self.tokenizer.decode(
                        output_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
                    .split("|", 1)[-1]
                    .strip()
                }
            records.append(record)
        return records


def get_schema(db_path: str, db_id: str) -> dict:
    schema = dump_db_json_schema(db_path + "/" + db_id + "/" + db_id + ".sqlite", db_id)
    return {
        "db_table_names": schema["table_names_original"],
        "db_column_names": {
            "table_id": [table_id for table_id, _ in schema["column_names_original"]],
            "column_name": [column_name for _, column_name in schema["column_names_original"]],
        },
        "db_column_types": schema["column_types"],
        "db_primary_keys": {"column_id": [column_id for column_id in schema["primary_keys"]]},
        "db_foreign_keys": {
            "column_id": [column_id for column_id, _ in schema["foreign_keys"]],
            "other_column_id": [other_column_id for _, other_column_id in schema["foreign_keys"]],
        },
    }