# mvq/models/mh_dropout/mhd_random_2d.py
# ---------------------------------------------------------------
# Copyright (c) Cyber Security Research Centre Limited 2022.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from .mhd_helper import (
    get_reduce_fn, get_dist_loss,
    unflat_tensor, get_topk_batch,
    MLP2d, UpSampler2d
)

class MHDropoutNetRandom2D(nn.Module):
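    """Multiple-hypothesis dropout network over 2D feature maps.

    Diverse hypotheses are generated by applying random Bernoulli dropout
    masks to the input latent before decoding; when a target latent is
    provided, the best hypotheses are retained via a winner-takes-all scheme.
    """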
    def __init__(self,
        inp_dim, 
        hidden_dim, 
        out_dim, 
        decoder_cfg,
        num_latent_space=2,
        mask_type='spatial',
        dist_reduce='mean', 
        loss_reduce='mean', 
        loss_reduce_dims=(-3, -2, -1), 
        hypothese_dim=1, 
        hypothese_bsz=32,
        dist_loss='mse', 
        dropout_rate=0.5, 
        topk=64, 
        residual=True,
        keep_length=128,
        up_sample_ratio=4,
        use_mhd_mask=True, 
        debug=False,
        **kwargs):

        super().__init__()
        self.debug = debug
        self.use_mask = use_mhd_mask
        self.dtype_float = torch.float32 
        self.mask_type = mask_type
        self.num_latent_space = num_latent_space

        self.topk = topk
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.hypothese_dim = hypothese_dim
        self.hypothese_bsz = hypothese_bsz
        self.dropout_rate = dropout_rate
        self.keep_length = keep_length
        self.residual = residual

        self.dist_loss = get_dist_loss(dist_loss)(reduction='none') 
        self.loss_reduce_dims = loss_reduce_dims
        self.loss_reduce = get_reduce_fn(loss_reduce)
        self.dist_reduce = get_reduce_fn(dist_reduce)

        self.decoder = MLP2d(inp_dim, hidden_dim, out_dim, **decoder_cfg)

        self.up_sample_d = None
        if up_sample_ratio > 1 and num_latent_space == 2:
            self.up_sample_d = UpSampler2d(self.out_dim, self.out_dim, up_sample_ratio)

    def _no_mask_forward(self, c_base, inp_latents):
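        """Single-hypothesis path used when masking is disabled: decode the
        latent, optionally upsample, and add the result to `c_base`."""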
        diff_vector = self.decoder(inp_latents)

        if self.up_sample_d is not None:
            diff_vector = self.up_sample_d(diff_vector)
        y_hat = c_base + diff_vector

        mhd_dict = {}
        mhd_dict['bsz']         = 1
        mhd_dict['keep_length'] = 1
        mhd_dict['diff_vector'] = diff_vector
        mhd_dict['yw_hat']      = y_hat 
        mhd_dict['y_hat']       = y_hat 

        return mhd_dict

    def forward(self, z_top, z_bottom, hypothese_count, latent_tgt=None, split='train'):
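        """Generate `hypothese_count` hypotheses from `z_top` in mini-batches.

        When `latent_tgt` is given, only the best `keep_length` hypotheses per
        mini-batch are kept and a winner-takes-all loss against the target is
        computed. Returns a dict containing the hypotheses and, when a target
        is given, the winning hypothesis, top-k indices and WTA loss.
        """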
        
        # The top latent is the input to the hypothesis decoder
        inp_latents = z_top
        bsz = self.hypothese_bsz if latent_tgt is not None else hypothese_count

        if not self.use_mask:
            return self._no_mask_forward(z_bottom, inp_latents)

        # Prepare residual connection 
        if self.residual:
            z_shape = list(z_bottom.shape)
            repeat_frame = [1] * (len(z_shape) + 1)
            repeat_frame[self.hypothese_dim] = bsz
            expanded_c = z_bottom.unsqueeze(self.hypothese_dim).repeat(repeat_frame)
            
        # Hypotheses are generated in `num_batches` rounds of `bsz` each; when a
        # target is given, only the best `keep_length` per round are kept, so the
        # final pool holds roughly `bsz` hypotheses.
        num_batches = int(np.ceil(hypothese_count / bsz))
        keep_length = int(np.ceil(bsz / num_batches))

        if split == 'eval' and keep_length > 128:
            keep_length = self.keep_length
        
        batch_hypos = []
        diff_vec_container = []
        log_berns_container = []
        for _ in range(num_batches):
            diff_vector, log_berns = self._forward(inp_latents, hypothese_count=bsz) 

            # Flatten batch and hypothesis dims for the 2D upsampler, then restore
            if self.up_sample_d is not None:
                diff_vector = torch.flatten(diff_vector, start_dim=0, end_dim=1)
                diff_vector = self.up_sample_d(diff_vector)
                diff_vector = unflat_tensor(diff_vector, bsz)

            if self.residual:
                y_hat = expanded_c + diff_vector
            else:
                y_hat = diff_vector

            if latent_tgt is not None:

                dist_loss, topk_idx = self.get_dist_loss(y_hat, latent_tgt, topk=keep_length)

                ### Only keep the top `keep_length` hypotheses based on the meta loss
                keep_idx = topk_idx[:,:keep_length]
                topk_hypos = get_topk_batch(keep_idx, y_hat)
                batch_hypos.append(topk_hypos)

                diff_vector = get_topk_batch(keep_idx, diff_vector)
                diff_vec_container.append(diff_vector)

                if log_berns is not None:
                    log_berns = get_topk_batch(keep_idx, log_berns)
                    log_berns_container.append(log_berns)
        
        mhd_dict = {}
        mhd_dict['bsz'] = bsz
        mhd_dict['keep_length'] = keep_length
        batch_log_berns = None
        if latent_tgt is not None:

            batch_hypos  = torch.cat(batch_hypos, dim=1)[:,:bsz,:]
            diff_vector  = torch.cat(diff_vec_container, dim=1)[:,:bsz,:]

            if len(log_berns_container):
                batch_log_berns = torch.cat(log_berns_container, dim=1)[:,:bsz,:]

            # Recombine into y_hat (final list)
            y_hat = batch_hypos

            #Get the winner out of the batched winners
            dist_loss, topk_idx = self.get_dist_loss(y_hat, latent_tgt, topk=keep_length)
            wta_loss, wta_idx   = self.get_wta_loss(dist_loss, topk_idx, batch_log_berns=batch_log_berns)

            mhd_dict['yw_hat']   = self.get_win_hypo(y_hat, wta_idx)
            #mhd_dict['yk_hat'] = self.ddb_layer.get_topk_hypo(y_hat, topk_idx)
            
            mhd_dict['wta_idx']  = wta_idx
            mhd_dict['topk_idx'] = topk_idx
            mhd_dict['wta_loss'] = wta_loss.mean()

        mhd_dict['diff_vector'] = diff_vector
        mhd_dict['y_hat']       = y_hat 
        return mhd_dict

    def _forward(self, x, y=None, wta_idx=None, wta_loss=None, hypothese_count=1, keep_idx=None, **kwargs):
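        """Sample hypotheses from `x`. If a target `y` (or a precomputed winner
        index and loss) is given, return the winning hypothesis and the reduced
        loss; otherwise return the raw hypotheses and None."""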
        
        hypotheses = self.sample(x, hypothese_count=hypothese_count, keep_idx=keep_idx)
        
        if (y is None and wta_idx is None and wta_loss is None):
            return hypotheses, None
        
        if y is not None:
            wta_loss, topk_idx  = self.get_dist_loss(hypotheses, y)
            wta_idx = topk_idx[:,0]
   
        win_hypo = self.get_win_hypo(hypotheses, wta_idx)
            
        return win_hypo, wta_loss.mean(-1)
    
    def get_win_hypo(self, hypotheses, wta_idx):
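        """Select, for each batch element, the hypothesis indexed by `wta_idx`."""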
        batch_list = torch.arange(hypotheses.size(0))
        winner_hypo = hypotheses[batch_list, wta_idx, :]
        return winner_hypo

    def get_topk_hypo(self, hypotheses, topk_idx):
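        """Gather the hypotheses at `topk_idx` for each batch element."""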
        topk_hypos = [batch_hypo[idx, :] for batch_hypo, idx in zip(hypotheses, topk_idx)]
        topk_hypos = torch.stack(topk_hypos, dim=0)
        return topk_hypos
    
    def create_spatial_mask(self, x):
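        """Bernoulli keep-mask with one draw per (sample, channel), broadcast
        across all spatial positions; entries are kept with probability
        1 - dropout_rate."""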
        m_shape = x.shape[:2]
        repeat_frame = [1,1, x.shape[-2], x.shape[-1]]
        m_prob = torch.ones(m_shape, device=x.device, requires_grad=x.requires_grad, dtype=self.dtype_float) * (1. - self.dropout_rate)
        m = torch.bernoulli(m_prob).unsqueeze(-1).unsqueeze(-1) #[bz, chn, 1, 1]
        m = m.repeat(repeat_frame)
        return m

    def create_sample_mask(self, x):
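        """Bernoulli keep-mask with one draw per (sample, spatial position),
        shared across all channels."""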
        m_shape = [x.shape[0], x.shape[-2], x.shape[-1]]
        repeat_frame = [1,x.shape[1], 1, 1]
        m_prob = torch.ones(m_shape, device=x.device, requires_grad=x.requires_grad, dtype=self.dtype_float) * (1. - self.dropout_rate)
        m = torch.bernoulli(m_prob).unsqueeze(1) #[bz, 1, dim, dim]
        m = m.repeat(repeat_frame)
        return m

    def sample(self, x, hypothese_count, keep_idx=None):
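        """Repeat `x` along the hypothesis dimension, apply the configured
        dropout mask(s), decode, and reshape to
        [bs, hypothese_count, out_dim, H, W]."""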
        
        #Infer the repeat structure
        x_shape = list(x.shape)
        bs, y_dim, x_dim = x_shape[0], x_shape[-2], x_shape[-1]
        repeat_frame = [1] * (len(x_shape) + 1)
        repeat_frame[self.hypothese_dim] = hypothese_count

        #Repeat for single forward pass sampling
        x_repeat = x.unsqueeze(self.hypothese_dim).repeat(repeat_frame)
        
        #Flatten along batch and sample axis
        x_repeat = torch.flatten(x_repeat, start_dim=0, end_dim=1)

        # Apply dropout here
        if self.mask_type == 'spatial':
            mask = self.create_spatial_mask(x=x_repeat)
        elif self.mask_type == 'sample':
            mask = self.create_sample_mask(x=x_repeat)
        elif self.mask_type == 'and':
            spatial_mask = self.create_spatial_mask(x=x_repeat)
            sample_mask = self.create_sample_mask(x=x_repeat)
            mask = spatial_mask * sample_mask
        elif self.mask_type == 'or':
            spatial_mask = self.create_spatial_mask(x=x_repeat)
            sample_mask = self.create_sample_mask(x=x_repeat)
            mask = torch.clip(spatial_mask + sample_mask, max=1.0)
        else:
            raise ValueError("Invalid mask type given")

        distd_hidden = x_repeat * mask
        
        # decoder output: [bs * hypothese_count, out_dim, H, W]
        output = self.decoder(distd_hidden)

        # restore the hypothesis dimension: [bs, hypothese_count, out_dim, H, W]
        output = output.reshape((bs, hypothese_count, self.out_dim, y_dim, x_dim))

        return output

    def get_dist_loss(self, hypotheses, y, topk=None):
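        """Compute the per-hypothesis distance (meta) loss against `y` and
        return it with the indices of the `topk` lowest-loss hypotheses."""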

        hypothese_count = hypotheses.size(self.hypothese_dim)

        y_shape = [1] * len(hypotheses.shape)
        y_shape[self.hypothese_dim] = hypothese_count

        #Create copies of y to match sample length
        y_expanded = y.unsqueeze(self.hypothese_dim).repeat(y_shape)
        
        dist_loss = self.dist_loss(hypotheses, y_expanded)
        dist_loss = self.loss_reduce(dist_loss, dim=self.loss_reduce_dims)

        #Get the sample with the lowest meta loss
        if topk is None:
            topk = min(self.topk, dist_loss.size(1)) 
        
        topk_idx = torch.topk(dist_loss, dim=-1, largest=False, sorted=True, k=topk)[1]
        
        return dist_loss, topk_idx
        
    def get_wta_loss(self, dist_loss, topk_idx, **kwargs):
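        """Winner-takes-all loss: keep only the loss of the best hypothesis
        (the first entry of `topk_idx`) for each batch element."""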
        wta_idx = topk_idx[:,0]
        #Create a mask based on the sample count
        wta_mask = F.one_hot(wta_idx, dist_loss.size(-1))
        
        #Get Winner-Takes-All Loss by using the mask
        wta_loss = (dist_loss * wta_mask).sum(-1)
        return wta_loss, wta_idx
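
# ---------------------------------------------------------------
# Minimal usage sketch (illustrative only). The dimensions and the
# contents of `decoder_cfg` depend on MLP2d and the surrounding
# model, so the values below are assumptions, not part of this
# module:
#
#   mhd = MHDropoutNetRandom2D(inp_dim=64, hidden_dim=128, out_dim=64,
#                              decoder_cfg={})
#   out = mhd(z_top, z_bottom, hypothese_count=32, latent_tgt=target)
#   winner, loss = out['yw_hat'], out['wta_loss']
# ---------------------------------------------------------------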