# neural-question-generator / bidaf_lstm.py
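"""Seq2seq question generation: a BiDAF-style attention-flow encoder reads an
embedded (passage, answer) pair, and an LSTM decoder emits the question
(teacher-forced during training, sampled token-by-token at inference)."""
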
import torch
import torch.nn as nn
import constants

class LSTMContextLayer(nn.Module):
    """Extracts contextual information from inputs"""
    def __init__(self, embedding_dim, hidden_size, num_layers):
        super().__init__()
        
        self.pass_context = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.ans_context = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True, bidirectional=True)

    def forward(self, passage, answer):
        """
        passage - N x P x E
        answer - N x A x E
        """
        out_pass, _ = self.pass_context(passage) # N x P x 2*H
        out_ans, _ = self.ans_context(answer) # N x A x 2*H
        return out_pass, out_ans

class AttentionFlowLayer(nn.Module):
    """
    Computes C2Q and Q2C as per
    https://arxiv.org/pdf/1611.01603.pdf
    https://towardsdatascience.com/the-definitive-guide-to-bidaf-part-2-word-embedding-character-embedding-and-contextual-c151fc4f05bb
    https://towardsdatascience.com/the-definitive-guide-to-bidaf-part-3-attention-92352bbdcb07
    """
    def __init__(self, embedding_dim, hidden_size):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size

        self.weights = nn.Linear(6*self.hidden_size, 1, bias=False)

    def forward(self, H, U):
        """
        H: N x P x 2d -- passage representation
        U: N x A x 2d -- answer representation

        G (output): N x P x 8d
        Implementation based off of
        https://github.com/jojonki/BiDAF/blob/master/layers/bidaf.py
        """
        context = H.unsqueeze(2) # N x P x 1 x 2d
        ans = U.unsqueeze(1) # N x 1 x A x 2d

        cast_shape = (H.shape[0], H.shape[1], U.shape[1], 2*self.hidden_size) # N x P x A x 2d
        context = context.expand(cast_shape)
        ans = ans.expand(cast_shape)
        
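        # Trainable similarity from the BiDAF paper: alpha(h, u) = w^T [h; u; h * u],
        # scored for every (passage position, answer position) pair.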
        # Similarity matrix between passage and answer
        prod = torch.mul(context, ans)
        vec = torch.cat([context, ans, prod], dim=3) # N x P x A x 6d
        sim = self.weights(vec).squeeze(-1) # N x P x A; squeeze only the last dim so N=1 batches keep their shape

        # C2Q - which query (i.e., answer) words matter most to each context (i.e., passage) word
        a = torch.softmax(sim, dim=2) # N x P x A
        U_tilda = torch.bmm(a, U) # N x P x 2d

        # Q2C - which context (i.e., passage) words are most relevant to the answer;
        # per the paper, softmax the column-wise max of the similarity matrix
        b = torch.softmax(torch.max(sim, dim=2)[0], dim=1).unsqueeze(1) # N x 1 x P
        H_tilda = torch.bmm(b, H).expand(-1, H.shape[1], -1) # N x P x 2d

        # Merge to form G
        G = torch.cat([H, U_tilda, torch.mul(H, U_tilda), torch.mul(H, H_tilda)], dim=2) # N x P x 8d
        return G

class AttentionFlowLSTMEncoder(nn.Module):

    def __init__(self, embedding_dim, hidden_size, num_layers, vocab):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab = vocab
        self.context_layer = LSTMContextLayer(self.embedding_dim, self.hidden_size, self.num_layers)
        self.attention_flow_layer = AttentionFlowLayer(self.embedding_dim, self.hidden_size)

        self.encoder = nn.LSTM(8*self.hidden_size, self.hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*self.hidden_size, self.embedding_dim)

    def forward(self, passage, answer):
        """
        passage: embedded passage N x P x E
        answer: embedded answer N x A x E

        encoding: encoded representation of passage-answer N x 1 x E
        """

        H, U = self.context_layer(passage, answer) # N x P x 2H, N x A x 2H
        G = self.attention_flow_layer(H, U) # N x P x 8H
        encoding, _ = self.encoder(G) # N x P x 2H
        encoding = torch.mean(encoding, dim=1, keepdim=True) # N x 1 x 2H
        encoding = self.fc(encoding) # N x 1 x E
        return encoding

class LSTMDecoder(nn.Module):
    """Decoder to produce sequential output as question"""

    def __init__(self, embedding_dim, hidden_size, num_layers, vocab, question_length):
        """
        embedding_dim: dimension of word embedding
        hidden_size: hidden state dimension for decoder
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab = vocab
        self.question_length = question_length

        self.decoder = nn.LSTM(self.embedding_dim, self.hidden_size, self.num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_size, len(self.vocab))
        self.softmax = nn.Softmax(dim=2)

    def forward(self, encoded_inputs, embedded_targets=None, embedder=None, temperature=1.0):
        """
        Generates sequential output
        
        encoded_inputs: N x 1 x E
        embedded_targets: N x Q x E -- supply targets for teacher forcing;
            when omitted, the question is sampled token-by-token via `embedder`
        embedder: embedding module used to feed each sampled token back in
        temperature: softmax temperature for sampling (unused when teacher forcing)
        """
        
        if embedded_targets is not None:        
            assert embedded_targets.shape[1] == self.question_length
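            # Teacher forcing: the encoding serves as the first decoder input,
            # followed by the gold question embeddings shifted right by one.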
            decoder_inputs = torch.cat([encoded_inputs, embedded_targets[:, :-1, :]], dim=1)
            sequence, _ = self.decoder(decoder_inputs) # N x Q x hidden_size
            sequence = self.fc(sequence) # N x Q x vocab_size
            return sequence
        else:
            with torch.no_grad():
                assert embedder is not None
                sequence = None
                hidden_state = None
                inputs = encoded_inputs
                for i in range(self.question_length):
                    out, hidden_state = self.decoder(inputs, hidden_state) # N x 1 x H
                    out = self.fc(out) # N x 1 x vocab_size
                    probs = self.softmax(out.div(temperature))
                    word = torch.multinomial(probs.squeeze(1), 1) # N x 1; squeeze(1) keeps the batch dim even when N == 1

                    if i == 0:
                        sequence = word
                    else:
                        sequence = torch.cat([sequence, word], dim=1)

                    inputs = embedder(word) # N x 1 x E (multinomial already returns long indices)
                    
                return sequence # N x Q
    
class BiDAF_LSTMNet(nn.Module):

    def __init__(self, embedding_dim, hidden_size, num_layers, vocab, question_length, temperature=1.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab = vocab
        self.question_length = question_length
        self.temperature = temperature
        
        self.embedder = nn.Embedding(num_embeddings=len(self.vocab), embedding_dim=embedding_dim, padding_idx=0)
        self.encoder = AttentionFlowLSTMEncoder(self.embedding_dim, self.hidden_size, self.num_layers, self.vocab)
        self.decoder = LSTMDecoder(self.embedding_dim, self.hidden_size, self.num_layers, self.vocab, self.question_length)
        self.softmax = nn.Softmax(dim=2)

    def embed(self, words):
        """words is N x L"""
        return self.embedder(words) # N X L X embedding

    def forward(self, passage, answer, question):
        """
        passage: N x P
        answer: N x A
        question: N x Q
        """
        P = passage.shape[1] 
        A = answer.shape[1]
        Q = question.shape[1]
        
        assert P == constants.MAX_PASSAGE_LEN + 2
        assert A == constants.MAX_ANSWER_LEN + 2
        assert Q == constants.MAX_QUESTION_LEN + 2
        
        pass_ans_qembed = self.embed(torch.cat([passage, answer, question], dim=1))
        passage, answer, q_embed = pass_ans_qembed[:, :P, :], pass_ans_qembed[:, P:P+A, :], pass_ans_qembed[:, P+A:, :]

        encoded = self.encoder(passage, answer)
        sequence = self.decoder(encoded, q_embed) # teacher forced

        return sequence # N x Q x vocab_size

    def predict(self, passage, answer):
        """
        Generates question word-by-word

        Should only be used in evaluation mode
        """
        with torch.no_grad():
            pass_ans = self.embed(torch.cat([passage, answer], dim=1))
            passage, answer = pass_ans[:, :passage.shape[1], :], pass_ans[:, passage.shape[1]:, :]
            encoded = self.encoder(passage, answer) # N x 1 x E
            sequence = self.decoder(encoded, embedded_targets=None, embedder=self.embedder, temperature=self.temperature) # N x Q
            return sequence
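
if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the original module. The
    # toy vocabulary and dimensions below are hypothetical; only the MAX_*_LEN
    # values from `constants` are assumed to exist (forward() asserts on them).
    vocab = [f"word{i}" for i in range(100)]  # hypothetical 100-token vocab
    net = BiDAF_LSTMNet(embedding_dim=32, hidden_size=16, num_layers=1,
                        vocab=vocab, question_length=constants.MAX_QUESTION_LEN + 2)
    net.eval()

    N = 2
    passage = torch.randint(1, len(vocab), (N, constants.MAX_PASSAGE_LEN + 2))
    answer = torch.randint(1, len(vocab), (N, constants.MAX_ANSWER_LEN + 2))
    question = torch.randint(1, len(vocab), (N, constants.MAX_QUESTION_LEN + 2))

    logits = net(passage, answer, question)   # N x Q x vocab_size (teacher forced)
    generated = net.predict(passage, answer)  # N x Q sampled token ids
    print(logits.shape, generated.shape)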