import torch as t
import torch.nn as nn

from module import *  # expected to provide LatentEncoder, DeterministicEncoder, Decoder
# from encoder.tcnn import TCNNEncoder
from pair_encoder import AttnEncoder, TreeEncoder, NODE_DIM, HIDDEN_DIM
from torch.distributions import kl_divergence, Normal

class LatentModel(nn.Module):
    """
    Latent Model (Attentive Neural Process).

    Combines a latent path (a global latent z plus per-channel latents
    vl/ve) with a deterministic cross-attention path, and decodes the two
    output channels of y ("l" and "e") with separate decoders.
    """

    def __init__(self, num_hidden):
        super().__init__()
        # Encode raw input pairs into fixed-size feature vectors.
        # self.pair_encoder = TreeEncoder(16, 16)
        self.pair_encoder = AttnEncoder(NODE_DIM, 16)
        # Latent path: parameterizes the distribution over the global latent.
        self.latent_encoder = LatentEncoder(num_hidden, num_hidden)
        # Deterministic path: attention over the context set.
        self.deterministic_encoder = DeterministicEncoder(num_hidden, num_hidden)
        # One decoder per output channel of y.
        self.decoder_l = Decoder(num_hidden)
        self.decoder_e = Decoder(num_hidden)

    def forward(self, context_pairs, context_y, target_pairs, target_y=None):
        # Split the two output channels: column 0 is the "l" label,
        # column 1 is the "e" label.
        context_yl = context_y[:, 0].unsqueeze(-1)
        context_ye = context_y[:, 1].unsqueeze(-1)
        context_x = self.pair_encoder(context_pairs)
        target_x = self.pair_encoder(target_pairs)
        target_xl, target_xe = target_x, target_x
        # The "l" channel only conditions on context points whose e-label
        # is 0; the "e" channel uses the full context set.
        cond = (context_ye == 0.).all(1)
        context_xl = context_x[cond, :]
        context_yl = context_yl[cond, :]
        context_xe = context_x
        num_targets = target_x.size(0)
        # Prior over the latent, inferred from the context set alone.
        prior_mu, prior_var, z, vl, ve = self.latent_encoder(context_xl, context_xe, context_yl, context_ye)
        if target_y is not None:
            # Training: infer the posterior from the target set and use its
            # sample z (and per-channel latents vl/ve) for decoding.
            target_yl = target_y[:, 0].unsqueeze(-1)
            target_ye = target_y[:, 1].unsqueeze(-1)
            cond = (target_ye == 0.).all(1)
            target_xl = target_x[cond, :]
            target_yl = target_yl[cond, :]
            target_xe = target_x
            posterior_mu, posterior_var, z, vl, ve = self.latent_encoder(target_xl, target_xe, target_yl, target_ye)
        # Broadcast the latents across target points; tensors here are 2-D,
        # [num_targets, H]. Note that z itself is not consumed by the decoders
        # below, which take the per-channel latents vl and ve.
        z = z.repeat(num_targets, 1)
        rl, re = self.deterministic_encoder(context_xl, context_xe, context_yl, context_ye, target_xl, target_xe)
        vl = vl.repeat(target_xl.size(0), 1)
        ve = ve.repeat(target_xe.size(0), 1)
        # Decode predictive distributions: the mean of y_pred_l is the point
        # prediction of the target y, and the "e" channel is Bernoulli.
        y_pred_l = self.decoder_l(rl, vl, target_xl)
        y_pred_e = self.decoder_e(re, ve, target_xe, is_bernoulli=True)
        # Training: minimize the negative ELBO.
        if target_y is not None:
            # Log-likelihood of the target labels under the predictions.
            log_likelihood = y_pred_l.log_prob(target_yl).mean(dim=0).sum()
            log_likelihood += y_pred_e.log_prob(target_ye).mean(dim=0).sum()
            # KL divergence between the target posterior and the context prior.
            kl = self.kl_div(prior_mu, prior_var, posterior_mu, posterior_var)
            # Maximize likelihood while keeping the posterior close to the prior.
            loss = -log_likelihood + kl
        # Generation: no target labels, hence no loss.
        else:
            kl = None
            loss = None

        return y_pred_l, y_pred_e, loss

    def kl_div(self, prior_mu, prior_var, posterior_mu, posterior_var):
        # Closed-form KL(posterior || prior) for diagonal Gaussians whose
        # second parameter is a log-variance (hence the t.exp calls).
        kl_div = (t.exp(posterior_var) + (posterior_mu - prior_mu) ** 2) / t.exp(prior_var) \
            - 1. + (prior_var - posterior_var)
        kl_div = 0.5 * kl_div.sum()
        return kl_div
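

# Sanity-check sketch (an addition, not part of the original model): with
# prior_var/posterior_var interpreted as log-variances, the closed form in
# kl_div should match torch.distributions' kl_divergence between the
# corresponding Normals (std = exp(0.5 * logvar)), which is why Normal and
# kl_divergence are imported above. `_check_kl_closed_form` is a hypothetical
# helper name introduced here for illustration.
def _check_kl_closed_form():
    mu_p, logvar_p = t.randn(4), t.randn(4)
    mu_q, logvar_q = t.randn(4), t.randn(4)
    # Closed form, as in LatentModel.kl_div (summed over dimensions).
    closed = 0.5 * ((t.exp(logvar_q) + (mu_q - mu_p) ** 2) / t.exp(logvar_p)
                    - 1. + (logvar_p - logvar_q)).sum()
    # Library form: KL(q || p) with standard deviations from the log-variances.
    lib = kl_divergence(Normal(mu_q, t.exp(0.5 * logvar_q)),
                        Normal(mu_p, t.exp(0.5 * logvar_p))).sum()
    assert t.allclose(closed, lib, atol=1e-5)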
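

# Usage sketch (assumptions labeled, not part of the original module): a
# minimal training step built on the fact that forward() returns the
# negative-ELBO loss when target_y is given. It assumes context_pairs and
# target_pairs are whatever AttnEncoder accepts, the y tensors are [N, 2]
# with the "l" channel in column 0 and the "e" channel in column 1, and that
# HIDDEN_DIM matches the encoders' hidden size. The name `train_step` is
# illustrative, not from the original code.
def train_step(model, optimizer, context_pairs, context_y, target_pairs, target_y):
    model.train()
    # Forward pass returns the two predictive distributions and the loss
    # (negative log-likelihood of the targets plus the KL term).
    _, _, loss = model(context_pairs, context_y, target_pairs, target_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()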