# aqetuner/network.py
from module import *
# from encoder.tcnn import TCNNEncoder
from pair_encoder import AttnEncoder, TreeEncoder, NODE_DIM, HIDDEN_DIM
from torch.distributions import kl_divergence, Normal

class LatentModel(nn.Module):
    """
    Latent Model (Attentive Neural Process)

    Encodes context and target pairs with a shared pair encoder and predicts
    two outputs per target row:

    * the "l" branch (``decoder_l``) is fitted on column 0 of ``y``;
    * the "e" branch (``decoder_e``) is fitted on column 1 of ``y`` and is
      decoded with ``is_bernoulli=True``.

    Rows whose column-1 value is non-zero are excluded from the "l" branch.
    NOTE(review): the names suggest column 0 is a latency and column 1 an
    error flag -- confirm against the data pipeline.
    """

    def __init__(self, num_hidden):
        """Build the sub-modules; ``num_hidden`` sizes the encoders/decoders."""
        super().__init__()
        # self.pair_encoder = TreeEncoder(16, 16)
        self.pair_encoder = AttnEncoder(NODE_DIM, 16)
        self.latent_encoder = LatentEncoder(num_hidden, num_hidden)
        self.deterministic_encoder = DeterministicEncoder(num_hidden, num_hidden)
        self.decoder_l = Decoder(num_hidden)
        self.decoder_e = Decoder(num_hidden)

    @staticmethod
    def _split_branches(x, y):
        """Split encoded rows ``x`` and labels ``y`` into "l" and "e" views.

        Returns ``(x_l, x_e, y_l, y_e)`` where the "e" view keeps every row
        and the "l" view keeps only the rows whose column-1 label equals 0.
        Both label tensors are returned with a trailing singleton dimension.
        """
        y_l = y[:, 0].unsqueeze(-1)
        y_e = y[:, 1].unsqueeze(-1)
        keep = (y_e == 0.).all(1)
        return x[keep, :], x, y_l[keep, :], y_e

    def forward(self, context_pairs, context_y, target_pairs, target_y=None):
        """Predict both outputs for ``target_pairs`` given the context set.

        Args:
            context_pairs: input accepted by ``self.pair_encoder``.
            context_y: 2-D tensor; column 0 feeds the "l" branch and column 1
                feeds the "e" branch.
            target_pairs: input accepted by ``self.pair_encoder``.
            target_y: optional 2-D tensor laid out like ``context_y``. When
                given, the latent variables come from the target set and a
                loss is computed; when ``None`` they come from the context.

        Returns:
            ``(y_pred_l, y_pred_e, loss)``. ``loss`` is ``None`` when
            ``target_y`` is ``None``. When ``target_y`` is given,
            ``y_pred_l`` only covers the target rows whose column-1 label is
            0, whereas ``y_pred_e`` always covers every target row.
        """
        context_x = self.pair_encoder(context_pairs)
        target_x = self.pair_encoder(target_pairs)
        context_xl, context_xe, context_yl, context_ye = self._split_branches(
            context_x, context_y
        )

        # The sampled latent itself is not consumed here; only the
        # per-branch vectors (vl, ve) are passed on to the decoders.
        prior_mu, prior_var, _, vl, ve = self.latent_encoder(
            context_xl, context_xe, context_yl, context_ye
        )

        if target_y is None:
            # Generation: every target row is decoded by both branches.
            target_xl, target_xe = target_x, target_x
        else:
            # Training: decode with latents inferred from the target set.
            # NOTE(review): if no target row has a zero column-1 label the
            # "l" subset is empty and the mean below is taken over zero
            # rows -- confirm callers never produce such a batch.
            target_xl, target_xe, target_yl, target_ye = self._split_branches(
                target_x, target_y
            )
            posterior_mu, posterior_var, _, vl, ve = self.latent_encoder(
                target_xl, target_xe, target_yl, target_ye
            )

        rep_l, rep_e = self.deterministic_encoder(
            context_xl, context_xe, context_yl, context_ye, target_xl, target_xe
        )

        # Tile the branch latents so there is one copy per decoded row.
        vl = vl.repeat(target_xl.size(0), 1)
        ve = ve.repeat(target_xe.size(0), 1)
        y_pred_l = self.decoder_l(rep_l, vl, target_xl)
        y_pred_e = self.decoder_e(rep_e, ve, target_xe, is_bernoulli=True)

        # For Generation
        if target_y is None:
            return y_pred_l, y_pred_e, None

        # For Training: maximize log-probability and minimize KL divergence.
        log_likelihood = (
            y_pred_l.log_prob(target_yl).mean(dim=0).sum()
            + y_pred_e.log_prob(target_ye).mean(dim=0).sum()
        )
        kl = self.kl_div(prior_mu, prior_var, posterior_mu, posterior_var)
        loss = -log_likelihood + kl
        return y_pred_l, y_pred_e, loss

    def kl_div(self, prior_mu, prior_var, posterior_mu, posterior_var):
        """Return the summed KL divergence KL(posterior || prior).

        Both distributions are treated as diagonal Gaussians. Despite their
        names, ``prior_var`` and ``posterior_var`` are exponentiated before
        use, i.e. they are handled as log-variances. The result is a scalar
        tensor: the per-element terms are summed and multiplied by 0.5.
        """
        per_element = (t.exp(posterior_var) + (posterior_mu-prior_mu) ** 2) / t.exp(prior_var) - 1. + (prior_var - posterior_var)
        return 0.5 * per_element.sum()