import torch
from more_itertools import flatten
from torch import nn, FloatTensor
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from transformers import BartModel, BartTokenizer, BartConfig

from model.encoder.input_features import encode_input


class TransformerEncoder(nn.Module):

    def __init__(self, pretrained_model, device, max_sequence_length, schema_embedding_size, decoder_hidden_size):
        super(TransformerEncoder, self).__init__()
        self.max_sequence_length = max_sequence_length
        self.device = device

        config_class, model_class, tokenizer_class = (BartConfig, BartModel, BartTokenizer)

        transformer_config: BartConfig = config_class.from_pretrained(pretrained_model)
        self.tokenizer: BartTokenizer = tokenizer_class.from_pretrained(pretrained_model)
        self.transformer_model: BartModel = model_class.from_pretrained(pretrained_model, config=transformer_config)

        self.pooling_head = PoolingHead(transformer_config)

        self.encoder_hidden_size = transformer_config.hidden_size

        # We don't want to do basic tokenization (splitting a sentence into words), as this is already done in pre-processing.
        # We still want the sub-word tokenization, though.
        self.tokenizer.do_basic_tokenize = False
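        # Illustrative example (the exact pieces depend on the tokenizer vocabulary): a pre-tokenized word
        # such as "concerts" may still be split into several sub-word tokens, e.g. ["conc", "erts"].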

        self.linear_layer_dimension_reduction_question = nn.Linear(transformer_config.hidden_size, decoder_hidden_size)

        self.column_encoder = nn.LSTM(transformer_config.hidden_size, schema_embedding_size // 2, bidirectional=True, batch_first=True)
        self.table_encoder = nn.LSTM(transformer_config.hidden_size, schema_embedding_size // 2, bidirectional=True, batch_first=True)
        self.value_encoder = nn.LSTM(transformer_config.hidden_size, schema_embedding_size // 2, bidirectional=True, batch_first=True)

        print("Successfully loaded pre-trained transformer '{}'".format(pretrained_model))

    def forward(self, question_tokens, column_names, table_names, values):
        input_ids_tensor, attention_mask_tensor, input_lengths = encode_input(question_tokens,
                                                                              column_names,
                                                                              table_names,
                                                                              values,
                                                                              self.tokenizer,
                                                                              self.max_sequence_length,
                                                                              self.device)

        # While "last_hidden_states" contains one hidden state per input token, the pooler output is the further
        # processed hidden state of the [CLS]-token. See e.g. the "BertModel" documentation for more information.

        outputs = self.transformer_model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
        last_hidden_states = outputs[0]

        pooling_output = self._pool_output(last_hidden_states)

        (all_question_span_lengths, all_column_token_lengths, all_table_token_lengths, all_value_token_lengths) = input_lengths

        # we get the relevant hidden states for the question tokens and average them if there are multiple tokens per word (e.g. ['table', 'college'])
        averaged_hidden_states_question, pointers_after_question = self._average_hidden_states_question(last_hidden_states, all_question_span_lengths)
        question_out = pad_sequence(averaged_hidden_states_question, batch_first=True)  # (batch_size * max_question_tokens_per_batch * hidden_size)
        # as the transformer normally uses a hidden size of 768 while the decoder only uses 300 per vector, we reduce dimensionality here with a linear layer.
        question_out = self.linear_layer_dimension_reduction_question(question_out)
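        # Shape sketch (assuming the common BART hidden size of 768 and a decoder hidden size of 300):
        # (batch_size, max_question_tokens, 768) --> (batch_size, max_question_tokens, 300)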

        column_hidden_states, pointers_after_columns = self._get_schema_hidden_states(last_hidden_states, all_column_token_lengths, pointers_after_question)
        table_hidden_states, pointers_after_tables = self._get_schema_hidden_states(last_hidden_states, all_table_token_lengths, pointers_after_columns)

        # in this scenario, we know the values upfront and encode them similarly to tables/columns.
        value_hidden_states, pointers_after_values = self._get_schema_hidden_states(last_hidden_states, all_value_token_lengths, pointers_after_tables)

        # This is simply a sanity check that the rather complex token concatenation happened correctly. It can be removed at some point.
        self._assert_all_elements_processed(all_question_span_lengths,
                                            all_column_token_lengths,
                                            all_table_token_lengths,
                                            all_value_token_lengths,
                                            pointers_after_values,
                                            last_hidden_states.shape[1])

        # "column_hidden_states" (and table_hidden_states/value_hidden_states) is here a list of examples, with each example a list of tensors (one tensor for each column). As a column can have multiple words, the tensor consists of multiple columns (e.g. 3 * 768)
        # With this line we first concat all examples to one huge list of tensors, independent of the example. Remember: we don't wanna use an RNN over a full example - but only over the tokens of ONE column! Therefore we can just build up a batch of each column - tensor.
        # With "pad_sequence" we pay attention to the fact that each column can have a different amount of tokens (e.g. a 3-word column vs. a 1 word column), so we have to pad the shorter inputs.
        column_hidden_states_padded = pad_sequence(list(flatten(column_hidden_states)), batch_first=True)
        column_lengths = [len(t) for t in flatten(column_hidden_states)]
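        # Illustrative example (hypothetical shapes): if the columns of a batch tokenize to 3, 1 and 2 sub-tokens,
        # flatten() yields three tensors of shapes (3, 768), (1, 768) and (2, 768); pad_sequence() stacks them into
        # one (3, 3, 768) tensor and column_lengths becomes [3, 1, 2].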

        table_hidden_states_padded = pad_sequence(list(flatten(table_hidden_states)), batch_first=True)
        table_lengths = [len(t) for t in flatten(table_hidden_states)]

        # create one embedding for each column by using an RNN.
        _, column_last_states, _ = self._rnn_wrapper(self.column_encoder, column_hidden_states_padded, column_lengths)

        # create one embedding for each table by using an RNN.
        _, table_last_states, _ = self._rnn_wrapper(self.table_encoder, table_hidden_states_padded, table_lengths)

        assert column_last_states.shape[0] == sum(map(lambda l: len(l), column_hidden_states))
        assert table_last_states.shape[0] == sum(map(lambda l: len(l), table_hidden_states))

        column_out = self._back_to_original_size(column_last_states, column_hidden_states)
        column_out_padded = pad_sequence(column_out, batch_first=True)

        table_out = self._back_to_original_size(table_last_states, table_hidden_states)
        table_out_padded = pad_sequence(table_out, batch_first=True)

        # in contrast to columns/tables, a batch may contain no values at all. In that case, return an empty tensor.
        if list(flatten(value_hidden_states)):
            value_hidden_states_padded = pad_sequence(list(flatten(value_hidden_states)), batch_first=True)
            value_lengths = [len(t) for t in flatten(value_hidden_states)]

            # create one embedding for each value by using an RNN.
            _, value_last_states, _ = self._rnn_wrapper(self.value_encoder, value_hidden_states_padded, value_lengths)

            assert value_last_states.shape[0] == sum(map(lambda l: len(l), value_hidden_states))

            value_out = self._back_to_original_size(value_last_states, value_hidden_states)
            value_out_padded = pad_sequence(value_out, batch_first=True)
        else:
            value_out_padded = torch.zeros(table_out_padded.shape[0], 0, table_out_padded.shape[2]).to(self.device)

        # we need to know how many tokens are question tokens to later create the mask when calculating
        # attention over the schema
        question_token_lengths = [sum(question_lengths) for question_lengths in all_question_span_lengths]
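        # e.g. question span lengths of [1, 1, 2] result in a question length of 4 tokens.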

        return question_out, column_out_padded, table_out_padded, value_out_padded, pooling_output, question_token_lengths

    @staticmethod
    def _average_hidden_states_question(last_hidden_states, all_question_span_lengths):
        """
        NOTE: Keep in mind that we might soon skip this whole step, as averaging is not used right now - question token length is always 0.
        As described in the IRNet-paper, we will just average over the sub-tokens of a question-span.
        """
        all_averaged_hidden_states = []
        last_pointers = []

        for batch_itr_idx, question_span_lengths in enumerate(all_question_span_lengths):
            pointer = 0
            averaged_hidden_states = []

            for idx in range(0, len(question_span_lengths)):
                span_length = question_span_lengths[idx]

                averaged_span = torch.mean(last_hidden_states[batch_itr_idx, pointer: pointer + span_length, :],
                                           keepdim=True, dim=0)
                averaged_hidden_states.append(averaged_span)
                pointer += span_length

            all_averaged_hidden_states.append(torch.cat(averaged_hidden_states, dim=0))
            last_pointers.append(pointer)

        return all_averaged_hidden_states, last_pointers

    @staticmethod
    def _get_schema_hidden_states(last_hidden_states, all_schema_token_lengths, initial_pointers):
        """
        We simply put together the tokens for each column/table and filter out the separators. No averaging or concatenation, as we will use an RNN later
        """
        all_schema_hidden_state = []
        last_pointers = []

        for batch_itr_idx, (schema_token_lengths, initial_pointer) in enumerate(zip(all_schema_token_lengths, initial_pointers)):
            hidden_states_schema = []
            pointer = initial_pointer
            for schema_token_length in schema_token_lengths:
                # the -1 accounts for the [SEP] at the end of the column, which we don't want to include.
                hidden_states_schema.append(last_hidden_states[batch_itr_idx, pointer: pointer + schema_token_length - 1, :])
                pointer += schema_token_length

            all_schema_hidden_state.append(hidden_states_schema)
            last_pointers.append(pointer)

        return all_schema_hidden_state, last_pointers

    def _rnn_wrapper(self, encoder, inputs, lengths):
        """
        This function abstracts from the technical details of the RNN. It handles the whole packing/unpacking of the values,
        handling zero-values and concatenating hidden/cell-states.
        """
        lengths = torch.tensor(lengths).to(self.device)

        # we need to sort the inputs by length due to the use of "pack_padded_sequence" which expects a sorted input.
        sorted_lens, sort_key = torch.sort(lengths, descending=True)
        # we temporarily remove empty inputs
        nonzero_index = torch.sum(sorted_lens > 0).item()
        sorted_inputs = torch.index_select(inputs, dim=0, index=sort_key[:nonzero_index])

        # Even though we already padded the inputs before "_rnn_wrapper", we still need to "pack/unpack" the sequences.
        # The reason is mostly performance, see: https://stackoverflow.com/a/55805785/1081551
        packed_inputs = pack_padded_sequence(sorted_inputs, sorted_lens[:nonzero_index].tolist(), batch_first=True)

        # forward it to the encoder network
        packed_out, (h, c) = encoder(packed_inputs)
        # unpack afterwards
        out, _ = pad_packed_sequence(packed_out, batch_first=True)

        # output dimensions:
        # out: (batch_size * max_sequence_length [padded] * dim). Example: (20 * 3 * 768)
        # h: (num_directions * batch_size * hidden_size). Example: (2 * 20 * 150). The 2 states per sequence come from the bidirectional LSTM.
        # c: (num_directions * batch_size * hidden_size). Example: (2 * 20 * 150). The 2 states per sequence come from the bidirectional LSTM.

        # as we removed the zero-length inputs above, we need to pad the results back to the full batch size with zeros.
        # we do it for the output
        pad_zeros = torch.zeros(sorted_lens.size(0) - out.size(0), out.size(1), out.size(2)).type_as(out).to(out.device)
        sorted_out = torch.cat([out, pad_zeros], dim=0)

        # and hidden state
        pad_hiddens = torch.zeros(h.size(0), sorted_lens.size(0) - h.size(1), h.size(2)).type_as(h).to(h.device)
        sorted_hiddens = torch.cat([h, pad_hiddens], dim=1)

        # and cell state
        pad_cells = torch.zeros(c.size(0), sorted_lens.size(0) - c.size(1), c.size(2)).type_as(c).to(c.device)
        sorted_cells = torch.cat([c, pad_cells], dim=1)

        # remember that we sorted the inputs by length above? Here we invert that sorting so the results are
        # returned in the same order as the original input.
        shape = list(sorted_out.size())
        out = torch.zeros_like(sorted_out).scatter_(0, sort_key.unsqueeze(-1).unsqueeze(-1).expand(*shape), sorted_out)

        shape = list(sorted_hiddens.size())
        hiddens = torch.zeros_like(sorted_hiddens).scatter_(1, sort_key.unsqueeze(0).unsqueeze(-1).expand(*shape), sorted_hiddens)

        shape = list(sorted_cells.size())
        cells = torch.zeros_like(sorted_cells).scatter_(1, sort_key.unsqueeze(0).unsqueeze(-1).expand(*shape), sorted_cells)
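        # Illustrative example (hypothetical values): with sort_key = [2, 0, 1], row i of the sorted tensors is
        # written back to position sort_key[i], which restores the original (unsorted) batch order.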

        # contiguous/non-contiguous is a memory-layout detail;
        # see https://discuss.pytorch.org/t/contigious-vs-non-contigious-tensor/30107/2 for more details.
        hiddens = hiddens.contiguous()
        cells = cells.contiguous()

        # here we concatenate the forward and backward hidden/cell states of the bidirectional LSTM
        hiddens_concated = torch.cat([hiddens[0], hiddens[1]], -1)
        cells_concated = torch.cat([cells[0], cells[1]], -1)
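        # e.g. two direction states of shape (batch_size, 150) become one (batch_size, 300) embedding per column/table/value.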

        return out, hiddens_concated, cells_concated

    @staticmethod
    def _back_to_original_size(elements_to_split, original_array):
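        """
        Split the flat batch of embeddings (one row per column/table/value) back into the per-example nesting of "original_array".
        """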
        original_split = []

        dimensions = map(lambda l: len(l), original_array)

        current_idx = 0
        for length in dimensions:
            original_split.append(elements_to_split[current_idx:current_idx + length])
            current_idx += length

        assert elements_to_split.shape[0] == current_idx

        return original_split

    def _pool_output(self, last_hidden_states: FloatTensor):
        """
        In contrast to BERT, BART does not come with a "pooler_output" out of the box. We therefore use the same mechanism
        as BART uses for classification - a BartClassificationHead with the right configuration.
        @param last_hidden_states:
        """

        pooling_output = self.pooling_head(last_hidden_states)
        return pooling_output

    @staticmethod
    def _assert_all_elements_processed(all_question_span_lengths, all_column_token_lengths, all_table_token_lengths, all_value_token_lengths, last_pointers, len_last_hidden_states):

        # the longest element in the batch determines the sequence length - therefore the maximum pointer equals the length of the hidden states.
        assert max(last_pointers) == len_last_hidden_states

        for question_span_lengths, column_token_lengths, table_token_lengths, value_token_length, last_pointer in zip(all_question_span_lengths, all_column_token_lengths, all_table_token_lengths, all_value_token_lengths, last_pointers):
            assert sum(question_span_lengths) + sum(column_token_lengths) + sum(table_token_lengths) + sum(value_token_length) == last_pointer


class PoolingHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dense(x)
        x = torch.tanh(x)

        return x
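

# Minimal usage sketch (illustrative only): the model name and the sizes below are assumptions and are not taken
# from the project configuration. It only builds the encoder; a forward pass additionally requires inputs in the
# format expected by encode_input().
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = TransformerEncoder(pretrained_model="facebook/bart-base",
                                 device=device,
                                 max_sequence_length=512,
                                 schema_embedding_size=300,
                                 decoder_hidden_size=300).to(device)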