import torch as t
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Normal, Bernoulli
# dimensionality of the location/query inputs x; the y values are 1-D
QUERY_DIM = 32
class Linear(nn.Module):
"""
Linear Module
"""
def __init__(self, in_dim, out_dim, bias=True, w_init='linear'):
"""
:param in_dim: dimension of input
:param out_dim: dimension of output
:param bias: boolean. if True, bias is included.
:param w_init: str. weight inits with xavier initialization.
"""
super(Linear, self).__init__()
self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=nn.init.calculate_gain(w_init))
def forward(self, x):
return self.linear_layer(x)
class LatentEncoder(nn.Module):
"""
Latent Encoder [For prior, posterior]
"""
def __init__(self, num_hidden, num_latent, input_dim=QUERY_DIM+1):
super(LatentEncoder, self).__init__()
self.input_projection_e = Linear(input_dim, num_hidden)
self.input_projection_l = Linear(input_dim, num_hidden)
# self.self_attentions_e = nn.ModuleList([Attention(num_hidden) for _ in range(2)])
# self.self_attentions_l = nn.ModuleList([Attention(num_hidden) for _ in range(2)])
self.penultimate_layer = Linear(num_hidden, num_hidden, w_init='relu')
self.mu = Linear(num_hidden, num_latent)
self.log_sigma = Linear(num_hidden, num_latent)
def forward(self, xl, xe, yl, ye):
# concat location (x) and value (y)
encoder_input_e = t.cat([xe, ye], dim=-1)
encoder_input_l = t.cat([xl, yl], dim=-1)
        # project (QUERY_DIM + 1)-dim input --> num_hidden
encoder_input_e = self.input_projection_e(encoder_input_e)
encoder_input_l = self.input_projection_l(encoder_input_l)
# self attention layer
# for attention in self.self_attentions:
# encoder_input, _ = attention(encoder_input, encoder_input, encoder_input)
        # aggregate over the context points (assumes unbatched (num_points, num_hidden) input)
        hidden_e = encoder_input_e.mean(dim=0)
        hidden_l = encoder_input_l.mean(dim=0)
        # gate each path with the shared representation hidden_g
        hidden_g = t.relu((hidden_l + hidden_e) / 2)
        hidden_l = t.tanh(hidden_l) + t.sigmoid(hidden_g) * hidden_l
        hidden_e = t.tanh(hidden_e) + t.sigmoid(hidden_g) * hidden_e
        # a shared penultimate layer is applied to both paths
        hidden_e = t.relu(self.penultimate_layer(hidden_e))
        hidden_l = t.relu(self.penultimate_layer(hidden_l))
        # get mu and log-variance from the shared representation
        mu = self.mu(hidden_g)
        log_sigma = self.log_sigma(hidden_g)
        # reparameterization trick: z = mu + std * eps, treating log_sigma
        # as a log-variance (std = exp(0.5 * log_sigma))
        std = t.exp(0.5 * log_sigma)
        eps = t.randn_like(std)
        z = mu + std * eps
        return mu, log_sigma, z, hidden_l, hidden_e
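
# A minimal shape check for LatentEncoder (a sketch, not part of the model).
# Assumes unbatched context sets: x of shape (num_points, QUERY_DIM) and
# y of shape (num_points, 1); all sizes below are illustrative.
def _demo_latent_encoder():
    enc = LatentEncoder(num_hidden=64, num_latent=16)
    xl, yl = t.randn(10, QUERY_DIM), t.randn(10, 1)
    xe, ye = t.randn(10, QUERY_DIM), t.randn(10, 1)
    mu, log_sigma, z, hidden_l, hidden_e = enc(xl, xe, yl, ye)
    assert mu.shape == (16,) and z.shape == (16,)
    assert hidden_l.shape == (64,) and hidden_e.shape == (64,)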
class DeterministicEncoder(nn.Module):
"""
Deterministic Encoder [r]
"""
    def __init__(self, num_hidden, num_latent, input_dim=QUERY_DIM + 1):
        super(DeterministicEncoder, self).__init__()
        # note: num_latent is accepted for symmetry with LatentEncoder but is unused
# self.self_attentions = nn.ModuleList([Attention(num_hidden) for _ in range(2)])
self.cross_attentions_l = nn.ModuleList([Attention(num_hidden) for _ in range(2)])
self.cross_attentions_e = nn.ModuleList([Attention(num_hidden) for _ in range(2)])
self.input_projection = Linear(input_dim, num_hidden)
self.context_projection = Linear(QUERY_DIM, num_hidden)
self.target_projection = Linear(QUERY_DIM, num_hidden)
def forward(self, context_xl, context_xe, context_yl, context_ye, target_xl, target_xe):
        # concat context location (x) and context value (y)
        encoder_input_l = t.cat([context_xl, context_yl], dim=-1)
        encoder_input_e = t.cat([context_xe, context_ye], dim=-1)
        # project (QUERY_DIM + 1)-dim input --> num_hidden
        encoder_input_l = self.input_projection(encoder_input_l)
        encoder_input_e = self.input_projection(encoder_input_e)
# self attention layer
# for attention in self.self_attentions:
# encoder_input, _ = attention(encoder_input, encoder_input, encoder_input)
# query: target_x, key: context_x, value: representation
query_l = self.target_projection(target_xl)
keys_l = self.context_projection(context_xl)
query_e = self.target_projection(target_xe)
keys_e = self.context_projection(context_xe)
# cross attention layer
for attention in self.cross_attentions_l:
query_l, _ = attention(keys_l, encoder_input_l, query_l)
for attention in self.cross_attentions_e:
query_e, _ = attention(keys_e, encoder_input_e, query_e)
return query_l, query_e
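
# Shape sketch for DeterministicEncoder: each target cross-attends over the
# context points. Assumes unbatched 2-D inputs, which nn.MultiheadAttention
# has supported since PyTorch 1.11; sizes are illustrative.
def _demo_deterministic_encoder():
    enc = DeterministicEncoder(num_hidden=64, num_latent=16)
    context_x, context_y = t.randn(10, QUERY_DIM), t.randn(10, 1)
    target_x = t.randn(5, QUERY_DIM)
    r_l, r_e = enc(context_x, context_x, context_y, context_y,
                   target_x, target_x)
    assert r_l.shape == (5, 64) and r_e.shape == (5, 64)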
class Decoder(nn.Module):
"""
Decoder for generation
"""
def __init__(self, num_hidden):
super(Decoder, self).__init__()
self.target_projection = Linear(QUERY_DIM, num_hidden)
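        # the decoder input is [r; z; projected target_x]; the num_hidden * 3
        # width assumes r, z, and the projection all have width num_hidden,
        # i.e. num_latent must equal num_hidden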
self.linears = nn.ModuleList([Linear(num_hidden * 3, num_hidden * 3, w_init='relu') for _ in range(3)])
self.final_to_mu = Linear(num_hidden * 3, 1)
self.final_to_sigma = Linear(num_hidden * 3, 1)
def forward(self, r, z, target_x, is_bernoulli=False):
        # project QUERY_DIM-dim target_x --> num_hidden
        target_x = self.target_projection(target_x)
        # concat r, z (assumed already broadcast to one row per target), and target_x
        hidden = t.cat([r, z, target_x], dim=-1)
# mlp layers
for linear in self.linears:
hidden = t.relu(linear(hidden))
# get mu and sigma
y_pred_mu = self.final_to_mu(hidden)
if is_bernoulli:
return Bernoulli(t.sigmoid(y_pred_mu))
y_pred_sigma = self.final_to_sigma(hidden)
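        # softplus keeps sigma positive; the 0.1 offset bounds it away from zero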
y_pred_sigma = 0.1 + 0.9 * F.softplus(y_pred_sigma)
return Normal(y_pred_mu, y_pred_sigma)
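
# The decoder returns a torch.distributions object (Normal or Bernoulli), so a
# negative-log-likelihood loss is one line; dist and target_y are illustrative.
def nll_loss(dist, target_y):
    """Mean negative log-likelihood of target_y under the predictive dist."""
    return -dist.log_prob(target_y).mean()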
class Attention(nn.Module):
"""
Attention Network
"""
def __init__(self, num_hidden, h=4):
"""
:param num_hidden: dimension of hidden
:param h: num of heads
"""
super(Attention, self).__init__()
self.multihead_attention = nn.MultiheadAttention(num_hidden, h, batch_first=True)
def forward(self, key, value, query):
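        # note the (key, value, query) argument order here, which differs from
        # nn.MultiheadAttention's (query, key, value)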
result, attns = self.multihead_attention(query, key, value)
return result, attns
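
# End-to-end smoke test wiring the modules together. This is a sketch under
# two assumptions not fixed by the code above: num_latent == num_hidden
# (required by the Decoder's num_hidden * 3 input width), and the two
# deterministic paths r_l, r_e being merged by averaging.
if __name__ == "__main__":
    num_hidden = num_latent = 64
    context_x, context_y = t.randn(10, QUERY_DIM), t.randn(10, 1)
    target_x, target_y = t.randn(5, QUERY_DIM), t.randn(5, 1)

    latent_enc = LatentEncoder(num_hidden, num_latent)
    mu, log_sigma, z, _, _ = latent_enc(context_x, context_x,
                                        context_y, context_y)

    det_enc = DeterministicEncoder(num_hidden, num_latent)
    r_l, r_e = det_enc(context_x, context_x, context_y, context_y,
                       target_x, target_x)
    r = (r_l + r_e) / 2                                   # assumed merge of the two paths

    decoder = Decoder(num_hidden)
    z_rep = z.unsqueeze(0).expand(target_x.size(0), -1)   # one z row per target
    dist = decoder(r, z_rep, target_x)
    print(dist.mean.shape)                                # torch.Size([5, 1])
    print(nll_loss(dist, target_y).item())

    _demo_latent_encoder()
    _demo_deterministic_encoder()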