# ---------------------------------------------------------------
# Copyright (c) Cyber Security Research Centre Limited 2022.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .mhd_helper import (
    get_reduce_fn, get_dist_loss, unflat_tensor, get_topk_batch,
    MLP2d, UpSampler2d
)


class MHDropoutNetRandom2D(nn.Module):
    def __init__(self, inp_dim, hidden_dim, out_dim, decoder_cfg,
                 num_latent_space=2,
                 mask_type='spatial',
                 dist_reduce='mean',
                 loss_reduce='mean',
                 loss_reduce_dims=[-3, -2, -1],
                 hypothese_dim=1,
                 hypothese_bsz=32,
                 dist_loss='mse',
                 dropout_rate=0.5,
                 topk=64,
                 residual=True,
                 keep_length=128,
                 up_sample_ratio=4,
                 use_mhd_mask=True,
                 debug=False,
                 **kwargs):
        super().__init__()
        self.debug = debug
        self.use_mask = use_mhd_mask
        self.dtype_float = torch.float32
        self.mask_type = mask_type
        self.num_latent_space = num_latent_space
        self.topk = topk
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.hypothese_dim = hypothese_dim
        self.hypothese_bsz = hypothese_bsz
        self.dropout_rate = dropout_rate
        self.keep_length = keep_length
        self.residual = residual
        self.dist_loss = get_dist_loss(dist_loss)(reduction='none')
        self.loss_reduce_dims = loss_reduce_dims
        self.loss_reduce = get_reduce_fn(loss_reduce)
        self.dist_reduce = get_reduce_fn(dist_reduce)

        self.decoder = MLP2d(inp_dim, hidden_dim, out_dim, **decoder_cfg)

        self.up_sample_d = None
        if up_sample_ratio > 1 and num_latent_space == 2:
            self.up_sample_d = UpSampler2d(self.out_dim, self.out_dim, up_sample_ratio)

    def _no_mask_forward(self, c_base, inp_latents):
        diff_vector = self.decoder(inp_latents)
        if self.up_sample_d is not None:
            diff_vector = self.up_sample_d(diff_vector)
        y_hat = c_base + diff_vector

        mhd_dict = {}
        mhd_dict['bsz'] = 1
        mhd_dict['keep_length'] = 1
        mhd_dict['diff_vector'] = diff_vector
        mhd_dict['yw_hat'] = y_hat
        mhd_dict['y_hat'] = y_hat
        return mhd_dict

    def forward(self, z_top, z_bottom, hypothese_count, latent_tgt=None, split='train'):
        # Reshape for ddb layer
        inp_latents = z_top
        bsz = self.hypothese_bsz if latent_tgt is not None else hypothese_count

        if not self.use_mask:
            return self._no_mask_forward(z_bottom, inp_latents)

        # Prepare residual connection
        if self.residual:
            z_shape = list(z_bottom.shape)
            repeat_frame = [1] * (len(z_shape) + 1)
            repeat_frame[self.hypothese_dim] = bsz
            expanded_c = z_bottom.unsqueeze(self.hypothese_dim).repeat(repeat_frame)

        num_batches = int(np.ceil(hypothese_count / bsz))
        keep_length = int(np.ceil(bsz / num_batches))
        if split == 'eval' and keep_length > 128:
            keep_length = self.keep_length

        batch_hypos = []
        diff_vec_container = []
        log_berns_container = []
        for _ in range(num_batches):
            diff_vector, log_berns = self._forward(inp_latents, hypothese_count=bsz)

            # Flatten along batch
            if self.up_sample_d is not None:
                diff_vector = torch.flatten(diff_vector, start_dim=0, end_dim=1)
                diff_vector = self.up_sample_d(diff_vector)
                diff_vector = unflat_tensor(diff_vector, bsz)

            if self.residual:
                y_hat = expanded_c + diff_vector
            else:
                y_hat = diff_vector

            if latent_tgt is not None:
                dist_loss, topk_idx = self.get_dist_loss(y_hat, latent_tgt, topk=keep_length)

                # Only keep the top x based on meta loss
                keep_idx = topk_idx[:, :keep_length]
                topk_hypos = get_topk_batch(keep_idx, y_hat)
                batch_hypos.append(topk_hypos)

                diff_vector = get_topk_batch(keep_idx, diff_vector)
                diff_vec_container.append(diff_vector)

                if log_berns is not None:
                    log_berns = get_topk_batch(keep_idx, log_berns)
                    log_berns_container.append(log_berns)

        mhd_dict = {}
        mhd_dict['bsz'] = bsz
        mhd_dict['keep_length'] = keep_length

        batch_log_berns = None
        if latent_tgt is not None:
            batch_hypos = torch.cat(batch_hypos, dim=1)[:, :bsz, :]
            diff_vector = torch.cat(diff_vec_container, dim=1)[:, :bsz, :]
            if len(log_berns_container):
                batch_log_berns = torch.cat(log_berns_container, dim=1)[:, :bsz, :]

            # Recombine into y_hat (final list)
            y_hat = batch_hypos

            # Get the winner out of the batched winners
            dist_loss, topk_idx = self.get_dist_loss(y_hat, latent_tgt, topk=keep_length)
            wta_loss, wta_idx = self.get_wta_loss(dist_loss, topk_idx, batch_log_berns=batch_log_berns)

            mhd_dict['yw_hat'] = self.get_win_hypo(y_hat, wta_idx)
            # mhd_dict['yk_hat'] = self.ddb_layer.get_topk_hypo(y_hat, topk_idx)
            mhd_dict['wta_idx'] = wta_idx
            mhd_dict['topk_idx'] = topk_idx
            mhd_dict['wta_loss'] = wta_loss.mean()

        mhd_dict['diff_vector'] = diff_vector
        mhd_dict['y_hat'] = y_hat
        return mhd_dict

    def _forward(self, x, y=None, wta_idx=None, wta_loss=None, hypothese_count=1, keep_idx=None, **kwargs):
        hypotheses = self.sample(x, hypothese_count=hypothese_count, keep_idx=keep_idx)

        if (y is None and wta_idx is None and wta_loss is None):
            return hypotheses, None

        if y is not None:
            wta_loss, topk_idx = self.get_dist_loss(hypotheses, y)
            wta_idx = topk_idx[:, 0]

        win_hypo = self.get_win_hypo(hypotheses, wta_idx)
        return win_hypo, wta_loss.mean(-1)

    def get_win_hypo(self, hypotheses, wta_idx):
        batch_list = torch.arange(hypotheses.size(0))
        winner_hypo = hypotheses[batch_list, wta_idx, :]
        return winner_hypo

    def get_topk_hypo(self, hypotheses, topk_idx):
        topk_hypos = [batch_hypo[idx, :] for batch_hypo, idx in zip(hypotheses, topk_idx)]
        topk_hypos = torch.stack(topk_hypos, dim=0)
        return topk_hypos

    def create_spatial_mask(self, x):
        m_shape = x.shape[:2]
        repeat_frame = [1, 1, x.shape[-2], x.shape[-1]]
        m_prob = torch.ones(m_shape, device=x.device, requires_grad=x.requires_grad,
                            dtype=self.dtype_float) * (1. - self.dropout_rate)
        m = torch.bernoulli(m_prob).unsqueeze(-1).unsqueeze(-1)  # [bz, chn, 1, 1]
        m = m.repeat(repeat_frame)
        return m

    def create_sample_mask(self, x):
        m_shape = [x.shape[0], x.shape[-2], x.shape[-1]]
        repeat_frame = [1, x.shape[1], 1, 1]
        m_prob = torch.ones(m_shape, device=x.device, requires_grad=x.requires_grad,
                            dtype=self.dtype_float) * (1. - self.dropout_rate)
        m = torch.bernoulli(m_prob).unsqueeze(1)  # [bz, 1, dim, dim]
        m = m.repeat(repeat_frame)
        return m

    def sample(self, x, hypothese_count, keep_idx=None):
        # Infer the repeat structure
        x_shape = list(x.shape)
        bs, x_dim, y_dim = x_shape[0], x_shape[-1], x_shape[-2]
        repeat_frame = [1] * (len(x_shape) + 1)
        repeat_frame[self.hypothese_dim] = hypothese_count

        # Repeat for single forward pass sampling
        x_repeat = x.unsqueeze(self.hypothese_dim).repeat(repeat_frame)

        # Flatten along batch and sample axis
        x_repeat = torch.flatten(x_repeat, start_dim=0, end_dim=1)

        # Apply dropout here
        if self.mask_type == 'spatial':
            mask = self.create_spatial_mask(x=x_repeat)
        elif self.mask_type == 'sample':
            mask = self.create_sample_mask(x=x_repeat)
        elif self.mask_type == 'and':
            spatial_mask = self.create_spatial_mask(x=x_repeat)
            sample_mask = self.create_sample_mask(x=x_repeat)
            mask = spatial_mask * sample_mask
        elif self.mask_type == 'or':
            spatial_mask = self.create_spatial_mask(x=x_repeat)
            sample_mask = self.create_sample_mask(x=x_repeat)
            mask = torch.clip(spatial_mask + sample_mask, max=1.0)
        else:
            raise ValueError("Invalid mask type given")

        distd_hidden = x_repeat * mask

        # [bs * hypothese_count, out_dim, x, y]
        output = self.decoder(distd_hidden)
        output = output.reshape((bs, hypothese_count, self.out_dim, x_dim, y_dim))
        return output

    def get_dist_loss(self, hypotheses, y, topk=None):
        hypothese_count = hypotheses.size(self.hypothese_dim)
        y_shape = [1] * len(hypotheses.shape)
        y_shape[self.hypothese_dim] = hypothese_count

        # Create copies of y to match sample length
        y_expanded = y.unsqueeze(self.hypothese_dim).repeat(y_shape)
        dist_loss = self.dist_loss(hypotheses, y_expanded)
        dist_loss = self.loss_reduce(dist_loss, dim=self.loss_reduce_dims)

        # Get the sample with the lowest meta loss
        if topk is None:
            topk = min(self.topk, dist_loss.size(1))
        topk_idx = torch.topk(dist_loss, dim=-1, largest=False, sorted=True, k=topk)[1]
        return dist_loss, topk_idx

    def get_wta_loss(self, dist_loss, topk_idx, **kwargs):
        wta_idx = topk_idx[:, 0]

        # Create a mask based on the sample count
        wta_mask = F.one_hot(wta_idx, dist_loss.size(-1))

        # Get Winner-Takes-All Loss by using the mask
        wta_loss = (dist_loss * wta_mask).sum(-1)
        return wta_loss, wta_idx
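
# ---------------------------------------------------------------
# Hedged usage sketch (not part of the original module). It re-traces
# the data flow implemented above with plain torch only: repeat a
# latent along a hypothesis axis, apply a Bernoulli channel ("spatial")
# mask, decode, add the residual base, and select the winner-takes-all
# hypothesis against a target. An nn.Conv2d stands in for the MLP2d
# decoder, and all tensor shapes are illustrative assumptions.
# ---------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    bs, chn, h, w = 2, 8, 4, 4                      # illustrative shapes (assumption)
    hypothese_count, dropout_rate = 16, 0.5

    decoder = nn.Conv2d(chn, chn, kernel_size=1)    # stand-in for the MLP2d decoder
    z_top = torch.randn(bs, chn, h, w)              # latent fed to the decoder
    z_bottom = torch.randn(bs, chn, h, w)           # residual base (c_base)
    latent_tgt = torch.randn(bs, chn, h, w)         # target latent

    # Repeat along a new hypothesis axis and flatten so one decoder pass
    # covers every hypothesis (mirrors sample()).
    x_repeat = z_top.unsqueeze(1).repeat(1, hypothese_count, 1, 1, 1)
    x_repeat = torch.flatten(x_repeat, start_dim=0, end_dim=1)

    # Spatial mask: one Bernoulli draw per (batch * hypothesis, channel),
    # broadcast over the spatial dims (mirrors create_spatial_mask()).
    m_prob = torch.full(x_repeat.shape[:2], 1. - dropout_rate)
    mask = torch.bernoulli(m_prob)[:, :, None, None]

    hypotheses = decoder(x_repeat * mask)
    hypotheses = hypotheses.reshape(bs, hypothese_count, chn, h, w)
    y_hat = z_bottom.unsqueeze(1) + hypotheses      # residual connection

    # Winner-takes-all: per-hypothesis MSE to the target, keep the best
    # (mirrors get_dist_loss() and get_wta_loss()).
    dist = F.mse_loss(y_hat, latent_tgt.unsqueeze(1).expand_as(y_hat),
                      reduction='none').mean(dim=(-3, -2, -1))
    wta_idx = torch.topk(dist, k=1, dim=-1, largest=False)[1][:, 0]
    wta_loss = (dist * F.one_hot(wta_idx, dist.size(-1))).sum(-1).mean()
    print("winner idx:", wta_idx.tolist(), "wta loss:", wta_loss.item())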