# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from .mhd_helper import (
    get_reduce_fn,
    get_dist_loss,
    MLP1d
)


class MHDropoutNetRandom1D(nn.Module):

    def __init__(self, inp_dim, hidden_dim, out_dim, decoder_cfg,
                 dist_reduce='mean', loss_reduce='mean',
                 loss_reduce_dims=[-3, -2, -1], hypothese_dim=1,
                 dist_loss='mse', gamma=0.25, dropout_rate=0.5,
                 topk=64, bottleneck=False, norm=False, **kwargs):
        super().__init__()
        self.dtype_float = torch.float32
        self.topk = topk
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.hypothese_dim = hypothese_dim
        self.dropout_rate = dropout_rate
        self.gamma = gamma

        self.dist_loss = get_dist_loss(dist_loss)(reduction='none')
        self.loss_reduce_dims = loss_reduce_dims
        self.loss_reduce = get_reduce_fn(loss_reduce)
        self.dist_reduce = get_reduce_fn(dist_reduce)

        self.bottleneck = bottleneck
        self.norm_1 = nn.GroupNorm(1, inp_dim) if norm and inp_dim > 1 else None

        # Optional bottleneck projection reduces the space of possible dropout masks
        if self.bottleneck:
            self.proj = MLP1d(inp_dim, hidden_dim, hidden_dim, act='relu')
            dec_inp_dim = hidden_dim
        else:
            dec_inp_dim = inp_dim

        self.decoder = MLP1d(dec_inp_dim, hidden_dim, out_dim, **decoder_cfg)

    def forward(self, x, y=None, wta_idx=None, wta_loss=None,
                hypothese_count=1, **kwargs):
        hypotheses = self.sample(x, hypothese_count=hypothese_count)

        # Inference mode: no target and no precomputed winner, return all hypotheses
        if y is None and wta_idx is None and wta_loss is None:
            return hypotheses

        # Training mode: pick the winning hypothesis against the target
        if y is not None:
            wta_loss, topk_idx = self.get_dist_loss(hypotheses, y)
            wta_idx = topk_idx[:, 0]

        win_hypo = self.get_win_hypo(hypotheses, wta_idx)
        return win_hypo, wta_loss.mean(-1)

    def get_win_hypo(self, hypos, wta_idx):
        batch_list = torch.arange(hypos.size(0), device=hypos.device)
        winner_hypo = hypos[batch_list, wta_idx, :]
        return winner_hypo

    def get_topk_hypo(self, hypos, topk_idx):
        topk_hypos = [batch_hypo[idx, :] for batch_hypo, idx in zip(hypos, topk_idx)]
        topk_hypos = torch.stack(topk_hypos, dim=0)
        return topk_hypos

    def create_mask(self, x):
        # One Bernoulli keep/drop decision per (sample, channel), broadcast over the sequence
        m_shape = x.shape[:2]
        m_prob = torch.ones(m_shape, dtype=self.dtype_float, device=x.device,
                            requires_grad=x.requires_grad) * (1. - self.dropout_rate)
        m = torch.bernoulli(m_prob).unsqueeze(-1)
        return m

    def sample(self, x, hypothese_count):
        x = torch.transpose(x, -2, -1)

        # Infer the repeat structure
        x_shape = list(x.shape)
        bsz, seq_len = x_shape[0], x_shape[-1]
        repeat_frame = [1] * (len(x_shape) + 1)
        repeat_frame[self.hypothese_dim] = hypothese_count
        output_shape = [bsz, hypothese_count, seq_len, self.out_dim]

        if self.norm_1 is not None:
            x = self.norm_1(x)

        # Repeat for single forward pass sampling
        x_repeat = x.unsqueeze(self.hypothese_dim).repeat(repeat_frame)

        # Flatten along batch and sample axis
        x_repeat = torch.flatten(x_repeat, start_dim=0, end_dim=1)

        # Reverse channels
        # x_repeat = torch.transpose(x_repeat, -2, -1)

        # Optional: create a bottleneck to reduce the space of all possible masks
        if self.bottleneck:
            hidden = self.proj(x_repeat)
        else:
            hidden = x_repeat

        # Apply dropout here
        distd_hidden = hidden * self.create_mask(hidden)

        # [bsz * hypo_count, out_dim, seq_len]
        output = self.decoder(distd_hidden)
        output = torch.transpose(output, -2, -1)
        output = output.reshape(output_shape)
        return output

    def get_dist_loss(self, hypos, y, topk=None):
        hypo_count = hypos.size(self.hypothese_dim)
        y_shape = [1] * len(hypos.shape)
        y_shape[self.hypothese_dim] = hypo_count

        # Create copies of y to match the sample count
        y_expanded = y.unsqueeze(self.hypothese_dim).repeat(y_shape)

        dist_loss = self.dist_loss(hypos, y_expanded)
        dist_loss = self.loss_reduce(dist_loss, dim=self.loss_reduce_dims)

        # Get the samples with the lowest loss
        if topk is None:
            topk = min(self.topk, dist_loss.size(1))
        topk_idx = torch.topk(dist_loss, dim=-1, largest=False, sorted=True, k=topk)[1]
        return dist_loss, topk_idx

    def get_wta_loss(self, dist_loss, topk_idx):
        wta_idx = topk_idx[:, 0]

        # Create a mask based on the sample count
        wta_mask = F.one_hot(wta_idx, dist_loss.size(-1))

        # Get the Winner-Takes-All loss by applying the mask
        wta_loss = (dist_loss * wta_mask).sum(-1)
        return wta_loss, wta_idx
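

# ---------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# It shows the expected input layout [batch, seq_len, inp_dim] and the two
# call modes: hypothesis sampling (no target) and winner-takes-all training
# (target provided). All dimensions, the decoder_cfg key, and the
# loss_reduce_dims override below are assumptions inferred from how MLP1d
# and get_dist_loss are used above, not verified values. The relative import
# at the top of this file means this block only runs when the module is
# executed inside its package (e.g. `python -m <package>.<this_module>`).
# ---------------------------------------------------------------
if __name__ == "__main__":
    net = MHDropoutNetRandom1D(
        inp_dim=16, hidden_dim=64, out_dim=16,
        decoder_cfg={'act': 'relu'},     # assumed MLP1d kwarg, mirrors the bottleneck call
        loss_reduce_dims=[-2, -1],       # assumed: keeps dist_loss as [batch, hypothese_count]
        topk=8,
    )

    x = torch.randn(4, 32, 16)   # [batch, seq_len, inp_dim]
    y = torch.randn(4, 32, 16)   # target with the same per-hypothesis layout

    # Sampling mode: returns all hypotheses, [batch, hypothese_count, seq_len, out_dim]
    hypotheses = net(x, hypothese_count=8)

    # Training mode: returns the winning hypothesis and a per-batch loss term
    win_hypo, loss = net(x, y=y, hypothese_count=8)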