# ValueNet4SPARQL: src/model/pointer_net.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# coding=utf8

import torch
import torch.nn as nn
import torch.nn.utils
from torch.nn import Parameter


class AuxiliaryPointerNet(nn.Module):
    """Pointer network that scores source tokens with two encodings.

    Each source token is scored against the decoder query twice: once with its
    main encoding and once with an auxiliary (context) encoding. The auxiliary
    score is scaled by ``sigma`` and added to the main score.
    """

    # Class-level fallback so instances unpickled from before ``sigma`` became
    # a constructor argument still resolve ``self.sigma`` to the old constant.
    sigma = 0.1

    def __init__(self, query_vec_size, src_encoding_size, attention_type='affine', sigma=0.1):
        """
        :param query_vec_size: size of the last dimension of ``query_vec``.
        :param src_encoding_size: size of the last dimension of both source encodings.
        :param attention_type: ``'affine'`` projects both encodings to
            ``query_vec_size`` with bias-free linear layers before the dot
            product; ``'dot_prod'`` uses them as-is, which requires
            ``src_encoding_size == query_vec_size``.
        :param sigma: weight applied to the auxiliary (context) scores.
        """
        super(AuxiliaryPointerNet, self).__init__()

        assert attention_type in ('affine', 'dot_prod')
        if attention_type == 'affine':
            self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
            self.auxiliary_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
        self.attention_type = attention_type
        self.sigma = sigma

    def forward(self, src_encodings, src_context_encodings, src_token_mask, query_vec):
        """
        :param src_encodings: Variable(batch_size, src_sent_len, src_encoding_size)
        :param src_context_encodings: Variable(batch_size, src_sent_len, src_encoding_size)
        :param src_token_mask: Variable(batch_size, src_sent_len) or None; positions
            where the mask is nonzero/True are excluded and scored -inf.
        :param query_vec: Variable(tgt_action_num, batch_size, query_vec_size)
        :return: Variable(tgt_action_num, batch_size, src_sent_len); the leading
            dimension is squeezed away when tgt_action_num == 1.
        """
        if self.attention_type == 'affine':
            encodings = self.src_encoding_linear(src_encodings)
            context_encodings = self.auxiliary_encoding_linear(src_context_encodings)
        else:
            # No copy needed: nothing below modifies these tensors in place.
            encodings = src_encodings
            context_encodings = src_context_encodings

        # (batch_size, 1, src_sent_len, query_vec_size)
        encodings = encodings.unsqueeze(1)
        context_encodings = context_encodings.unsqueeze(1)

        # (batch_size, tgt_action_num, query_vec_size, 1)
        q = query_vec.permute(1, 0, 2).unsqueeze(3)

        # (batch_size, tgt_action_num, src_sent_len) -> (tgt_action_num, batch_size, src_sent_len)
        weights = torch.matmul(encodings, q).squeeze(3).permute(1, 0, 2)
        context_weights = torch.matmul(context_encodings, q).squeeze(3).permute(1, 0, 2)

        if src_token_mask is not None:
            # (tgt_action_num, batch_size, src_sent_len)
            mask = src_token_mask.unsqueeze(0).expand_as(weights).bool()
            # Out-of-place masked_fill (rather than an in-place write through
            # ``.data``) keeps autograd aware of the mask, so masked positions
            # receive zero gradient.
            weights = weights.masked_fill(mask, -float('inf'))
            context_weights = context_weights.masked_fill(mask, -float('inf'))

        return weights.squeeze(0) + self.sigma * context_weights.squeeze(0)


class PointerNet(nn.Module):
    """Pointer network that scores each source token against a decoder query by dot product."""

    def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'):
        """
        :param query_vec_size: size of the last dimension of ``query_vec``.
        :param src_encoding_size: size of the last dimension of ``src_encodings``.
        :param attention_type: ``'affine'`` projects the source encodings to
            ``query_vec_size`` with a bias-free linear layer before the dot
            product; ``'dot_prod'`` uses them as-is, which requires
            ``src_encoding_size == query_vec_size``.
        """
        super(PointerNet, self).__init__()

        assert attention_type in ('affine', 'dot_prod')
        if attention_type == 'affine':
            self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)

        self.attention_type = attention_type

        # NOTE(review): none of the modules/parameters below are used by
        # forward(). They are kept, in this order, so state_dict keys and the
        # random-initialisation order stay unchanged -- presumably existing
        # checkpoints contain them; confirm before removing.
        self.input_linear = nn.Linear(query_vec_size, query_vec_size)
        self.type_linear = nn.Linear(32, query_vec_size)
        self.V = Parameter(torch.empty(query_vec_size, dtype=torch.float32), requires_grad=True)
        self.tanh = nn.Tanh()
        self.context_linear = nn.Conv1d(src_encoding_size, query_vec_size, 1, 1)
        self.coverage_linear = nn.Conv1d(1, query_vec_size, 1, 1)

        # V is allocated uninitialised above; give it values in [-1, 1].
        nn.init.uniform_(self.V, -1, 1)

    def forward(self, src_encodings, src_token_mask, query_vec):
        """
        :param src_encodings: Variable(batch_size, src_sent_len, src_encoding_size)
        :param src_token_mask: Variable(batch_size, src_sent_len) or None; positions
            where the mask is nonzero/True are excluded and scored -inf.
        :param query_vec: Variable(tgt_action_num, batch_size, query_vec_size)
        :return: Variable(tgt_action_num, batch_size, src_sent_len); the leading
            dimension is squeezed away when tgt_action_num == 1.
        """
        if self.attention_type == 'affine':
            src_encodings = self.src_encoding_linear(src_encodings)
        # (batch_size, 1, src_sent_len, query_vec_size)
        src_encodings = src_encodings.unsqueeze(1)

        # (batch_size, tgt_action_num, query_vec_size, 1)
        q = query_vec.permute(1, 0, 2).unsqueeze(3)

        # (batch_size, tgt_action_num, src_sent_len) -> (tgt_action_num, batch_size, src_sent_len)
        weights = torch.matmul(src_encodings, q).squeeze(3).permute(1, 0, 2)

        if src_token_mask is not None:
            mask = src_token_mask.unsqueeze(0).expand_as(weights).bool()
            # Out-of-place masked_fill (rather than an in-place write through
            # ``.data``) keeps autograd aware of the mask, so masked positions
            # receive zero gradient.
            weights = weights.masked_fill(mask, -float('inf'))

        return weights.squeeze(0)