import torch
import torch.nn as nn
import torchvision
import constants


class base_LSTM(nn.Module):
    """Bi-LSTM encoder / LSTM decoder that generates question tokens from a passage and an answer.

    Passage and answer token ids are concatenated, embedded and encoded by a
    bidirectional LSTM. A 1x1 convolution projects every encoder step back to the
    embedding size, and the steps are average-pooled into a single context vector.
    That vector is the first input step of a unidirectional LSTM decoder, whose
    outputs are mapped to vocabulary logits.
    """

    def __init__(self, hidden_size, embedding_size, num_layers, vocab, model_temp):
        """Build the model.

        Args:
            hidden_size: hidden size of the decoder LSTM and of each encoder direction.
            embedding_size: word-embedding dimension; also the decoder's input size.
            num_layers: number of stacked layers in both the encoder and the decoder.
            vocab: vocabulary object; only ``len(vocab)`` is used here.
            model_temp: softmax temperature applied to logits when sampling in ``predict``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.num_layers = num_layers
        self.vocab_size = len(vocab)
        self.model_temp = model_temp
        # The maximum lengths live in the project ``constants`` module (only the module
        # is imported, so they must be qualified). NOTE(review): the +2 presumably
        # reserves room for start/end tokens — confirm against the preprocessing code.
        self.passage_length = constants.MAX_PASSAGE_LEN + 2
        self.answer_length = constants.MAX_ANSWER_LEN + 2
        self.question_length = constants.MAX_QUESTION_LEN + 2
        self.encoder = nn.LSTM(
            input_size=self.embedding_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # 1x1 convolution: projects each bidirectional encoder step (2 * hidden_size
        # features) back to the embedding size.
        self.ffn = nn.Conv2d(
            in_channels=2 * self.hidden_size,
            out_channels=self.embedding_size,
            kernel_size=1,
        )
        self.word_embedding = nn.Embedding(
            num_embeddings=self.vocab_size, embedding_dim=self.embedding_size
        )
        # Kernel width equals the padded passage+answer length, so pooling collapses
        # the whole (padded) sequence into one vector.
        self.pool = nn.AvgPool2d((1, self.passage_length + self.answer_length))
        self.decoder = nn.LSTM(
            input_size=self.embedding_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=self.hidden_size, out_features=self.vocab_size)

    def _encode(self, passage, answer):
        """Encode passage + answer token ids into a context tensor of shape (N, 1, embedding_size).

        Assumes ``passage`` and ``answer`` are batch-first integer token-id tensors
        whose lengths sum to ``passage_length + answer_length`` (the pool kernel width).
        NOTE(review): confirm the caller pads to those lengths.
        """
        linked_input = torch.cat((passage, answer), dim=1)        # (N, L) token ids
        linked_embedded = self.word_embedding(linked_input)       # (N, L, E)
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step outputs are used.
        encoded, _ = self.encoder(linked_embedded)                # (N, L, 2H)
        # Conv2d expects (N, C, H, W): features go to channels, time to width.
        encoded = encoded.permute(0, 2, 1).unsqueeze(2)           # (N, 2H, 1, L)
        projected = self.ffn(encoded)                             # (N, E, 1, L)
        pooled = self.pool(projected)                             # (N, E, 1, 1) when L == pool width
        return pooled.flatten(start_dim=2).transpose(1, 2)        # (N, 1, E)

    def forward(self, passage, answer, question):
        """Compute vocabulary logits for every question position.

        Args:
            passage: passage token ids, batch-first.
            answer: answer token ids, batch-first.
            question: ground-truth question token ids, batch-first, shape (N, Lq).

        Returns:
            Logits of shape (N, Lq, vocab_size). Step 0 is conditioned only on the
            context vector; step t > 0 is additionally fed question token t - 1.
        """
        context = self._encode(passage, answer)                   # (N, 1, E)
        # Ground-truth inputs (teacher forcing): every question token except the last.
        question_inp = self.word_embedding(question[:, :-1])      # (N, Lq-1, E)
        inp = torch.cat((context, question_inp), dim=1)           # (N, Lq, E)
        out, _ = self.decoder(inp)                                # (N, Lq, H)
        return self.fc(out)                                       # (N, Lq, vocab_size)

    def predict(self, passage, answer, question_length):
        """Sample ``question_length`` tokens autoregressively.

        Each step samples one token per batch row from the temperature-scaled
        softmax (``model_temp``) and feeds its embedding back into the decoder.

        Returns:
            Integer tensor of shape (N, question_length) with the sampled token ids,
            or ``None`` when ``question_length`` is 0.
        """
        inp = self._encode(passage, answer)                       # (N, 1, E)
        hidden_state = None  # nn.LSTM treats hx=None as a zero initial state
        words = []
        for _ in range(question_length):
            out, hidden_state = self.decoder(inp, hidden_state)   # (N, 1, H)
            logits = self.fc(out)                                 # (N, 1, V)
            # Squeeze only the time dim: a bare squeeze() would also drop the batch
            # dim when N == 1 and break the concatenation below.
            probs = torch.softmax(logits.div(self.model_temp), dim=2).squeeze(1)  # (N, V)
            word = torch.multinomial(probs, 1)                    # (N, 1)
            words.append(word)
            inp = self.word_embedding(word.long())                # (N, 1, E)
        # A single cat at the end avoids re-copying the growing prediction every step.
        return torch.cat(words, dim=1) if words else None

    def __call__(self, passage, answer, question):
        """Run the model; routes through nn.Module.__call__ so registered hooks still fire."""
        return super().__call__(passage, answer, question)