# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import torch

from thirdparty.taming.losses.lpips import LPIPS


class MVQAE(torch.nn.Module):
    """VQ autoencoder with an optional multi-hypothesis (MHD) layer and
    one or two latent spaces.

    `seq_dim` is expected to be a sequence of per-level spatial sizes
    (e.g. ``[32, 16]``), one entry per latent space.
    """

    def __init__(self, cfg, encoder1, decoder, codebook, num_latent_space=1,
                 encoder2=None, mhd_layer=None, use_disc=False, seq_dim=3,
                 hypothese_dim=1, eval_cfg=False, **kwargs):
        super().__init__(**kwargs)
        self.cfg = cfg
        self.eval_cfg = eval_cfg
        self.enc_1 = encoder1
        self.enc_2 = encoder2
        self.decoder = decoder
        self.codebook = codebook
        self.seq_dim = seq_dim
        self.hypothese_dim = hypothese_dim
        self.num_latent_space = num_latent_space
        self.use_mhd = mhd_layer is not None
        self.mhd_layer = mhd_layer
        self.use_disc = use_disc
        if self.use_disc:
            self.perceptual_loss = LPIPS().eval()
            self.perceptual_weight = 1.0
        # 1x1 convolutions project encoder features onto the codebook
        # embedding dimensions.
        self.proj_1, self.proj_2 = None, None
        out_ch1 = encoder1.out_channel if hasattr(encoder1, 'out_channel') else encoder1.ch
        self.proj_1 = torch.nn.Conv2d(out_ch1, self.codebook.emb_dims[0], 1)
        if self.num_latent_space == 2:
            out_ch2 = encoder2.out_channel if hasattr(encoder2, 'out_channel') else encoder2.out_ch
            self.proj_2 = torch.nn.Conv2d(out_ch2, self.codebook.emb_dims[1], 1)

    def unflat_tensor(self, x, hypothese_count):
        """Regroup a flat (N * hypothese_count, C, H, W) batch into
        (N, hypothese_count, C, H, W)."""
        _, channel_dim, x_dim, y_dim = x.shape
        return x.reshape((-1, hypothese_count, channel_dim, x_dim, y_dim))

    def encode(self, x):
        encoded = []
        enc_1 = self.enc_1(x)
        q_zc = self.proj_1(enc_1)
        assert q_zc.size(-1) == self.seq_dim[0], (
            "Check encoder1 params. Wrong encoder output dimensions: "
            "vqc_inp {} => {}".format(q_zc.size(-1), self.seq_dim[0]))
        encoded.append(q_zc)
        if self.num_latent_space == 2:
            q_zd = self.proj_2(self.enc_2(enc_1))
            assert q_zd.size(-1) == self.seq_dim[-1], (
                "Check encoder2 params. Wrong encoder output dimensions: "
                "vqd_inp {} => {}".format(q_zd.shape, self.seq_dim[-1]))
            encoded.append(q_zd)
        return encoded

    def quantize(self, encoded):
        zs, xs_quantised, commit_losses, metrics = self.codebook(encoded)
        return zs, xs_quantised, commit_losses, metrics
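    # ---------------------------------------------------------------
    # Shape walk-through (illustrative; the actual channel counts come
    # from cfg, the sizes here are assumed): for x of shape
    # (B, C, H, W), enc_1 yields spatial features that proj_1 maps to
    # (B, emb_dims[0], s0, s0) with s0 == seq_dim[0]. With two latent
    # spaces, enc_2 consumes enc_1's features and proj_2 yields
    # (B, emb_dims[1], s1, s1) with s1 == seq_dim[-1]. The codebook
    # then returns, per level, the quantised latents, commitment
    # losses, and usage metrics.
    # ---------------------------------------------------------------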
    def forward(self, x, optimizer_idx=0, hypothese_count=16, split='train',
                return_outputs=False, **kwargs):
        # Note: this path assumes two latent spaces and an MHD layer
        # (it indexes zs[1] and commit_losses[1]).
        log, outputs = {}, {}
        loss = 0
        encoded = self.encode(x)
        zs, xs_quantised, commit_losses, cb_metrics = self.quantize(encoded)
        mhd_dict = self.mhd_layer(
            z_top=zs[1], z_bottom=zs[0], hypothese_count=hypothese_count,
            latent_tgt=encoded[0], split=split)
        yw_hat = mhd_dict.get('yw_hat')
        diffw_vector = yw_hat - zs[0]
        if return_outputs:
            outputs['y_hat'] = mhd_dict['y_hat'].detach().cpu()
            outputs['z_bottom'] = zs[0].detach().cpu()
            outputs['diff_vector'] = mhd_dict['diff_vector'].detach().cpu()
            outputs['diffw_vector'] = diffw_vector.detach().cpu()
        wta_loss = mhd_dict.get('wta_loss')
        if optimizer_idx == 0 and wta_loss is not None:
            loss += wta_loss
            log['{}/wta_loss'.format(split)] = wta_loss.clone().detach().cpu()
        x_hat = self.decoder(yw_hat)
        cb_loss = (commit_losses[0] + commit_losses[1]).mean()
        disc_inputs = {}
        if not self.use_disc:
            recon_loss = self.reconstruction_loss(x, x_hat)
            loss += recon_loss
            loss += cb_loss
            log['{}/mse_loss'.format(split)] = recon_loss.clone().detach().cpu()
        else:
            # L1 reconstruction plus an optional perceptual term; the
            # combined NLL is handed to the discriminator loss outside.
            rec_loss = torch.abs(x.contiguous() - x_hat.contiguous())
            if self.perceptual_weight > 0:
                p_loss = self.perceptual_loss(x.contiguous(), x_hat.contiguous())
                rec_loss = rec_loss + self.perceptual_weight * p_loss
            else:
                p_loss = torch.tensor([0.0])
            nll_loss = torch.mean(rec_loss)
            log["{}/nll_loss".format(split)] = nll_loss.detach().mean()
            log["{}/rec_loss".format(split)] = rec_loss.detach().mean()
            log["{}/p_loss".format(split)] = p_loss.detach().mean()
            disc_inputs['nll_loss'] = nll_loss
            disc_inputs['cb_loss'] = cb_loss
            disc_inputs['wta_loss'] = wta_loss
        log[f'{split}/cb1_loss'] = commit_losses[0].mean().clone().detach().cpu()
        if len(cb_metrics):
            log[f'{split}/1fit'] = cb_metrics[0]['fit'].detach().cpu()
            log[f'{split}/1prenorm'] = cb_metrics[0]['prenorm'].detach().cpu()
            log[f'{split}/1entr'] = cb_metrics[0]['entropy'].detach().cpu()
            log[f'{split}/1used_curr'] = cb_metrics[0]['used_curr'].detach().cpu()
            log[f'{split}/1usage'] = cb_metrics[0]['usage'].detach().cpu()
            log[f'{split}/1dk'] = cb_metrics[0]['dk'].detach().cpu()
        if self.num_latent_space == 2:
            log[f'{split}/cb2_loss'] = commit_losses[1].mean().clone().detach().cpu()
            if len(cb_metrics) > 1:
                log[f'{split}/2fit'] = cb_metrics[1]['fit'].detach().cpu()
                log[f'{split}/2prenorm'] = cb_metrics[1]['prenorm'].detach().cpu()
                log[f'{split}/2entr'] = cb_metrics[1]['entropy'].detach().cpu()
                log[f'{split}/2used_curr'] = cb_metrics[1]['used_curr'].detach().cpu()
                log[f'{split}/2usage'] = cb_metrics[1]['usage'].detach().cpu()
                log[f'{split}/2dk'] = cb_metrics[1]['dk'].detach().cpu()
        return x_hat, disc_inputs, loss, outputs, log

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def sample_codebook(self, x):
        """Map a batch of inputs to per-level codebook index grids."""
        encoded = self.encode(x)
        assert encoded[0].size(-1) == self.seq_dim[0], (
            "Check encoder1 params. Wrong encoder output dimensions: "
            "vq1_inp {} => {}".format(encoded[0].shape, self.seq_dim[0]))
        assert encoded[1].size(-1) == self.seq_dim[-1], (
            "Check encoder2 params. Wrong encoder output dimensions: "
            "vq2_inp {} => {}".format(encoded[1].shape, self.seq_dim[-1]))
        codebook_indices1, _ = self.codebook.level_blocks[0].get_code_indices(encoded[0])
        codebook_indices2, _ = self.codebook.level_blocks[1].get_code_indices(encoded[1])
        bsz = encoded[0].size(0)
        seq_len1 = int(self.seq_dim[0])
        seq_len2 = int(self.seq_dim[1])
        codebook_indices1 = codebook_indices1.reshape([bsz, seq_len1, seq_len1])
        codebook_indices2 = codebook_indices2.reshape([bsz, seq_len2, seq_len2])
        return codebook_indices1, codebook_indices2
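    # ---------------------------------------------------------------
    # Hedged note: `sample_codebook` and `reconstruct_from_indices`
    # form a round trip. With assumed seq_dim = [32, 16], a batch of
    # B inputs yields index grids of shape (B, 32, 32) and (B, 16, 16),
    # which `reconstruct_from_indices` decodes back to image space
    # (and, when an MHD layer is present, to multiple hypotheses).
    # ---------------------------------------------------------------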
    def reconstruct_from_indices(self, ind_1, ind_2, hypothese_count=1):
        """Decode index grids back to image space, optionally sampling
        multiple hypotheses through the MHD layer."""
        outputs = {}
        zs = self.codebook.decode([ind_1, ind_2])
        c_base, c_i = zs
        if self.mhd_layer is not None:
            ddb_dict = self.mhd_layer(
                c_base, hypothese_count, dist_latents=c_i, split='eval')
            y_hat = ddb_dict['y_hat']
            diff_vector = ddb_dict['diff_vector']
            bsz = ddb_dict['bsz']
            # Batch-decode all hypotheses at once, then regroup.
            y_hat_flat = torch.flatten(y_hat, start_dim=0, end_dim=1)
            xd_hat = self.decoder(y_hat_flat).detach().cpu()
            xd_hat = self.unflat_tensor(xd_hat, bsz)
            # Rank hypotheses by distance to the center latent.
            _, topk_idx = self.mhd_layer.get_dist_loss(y_hat, c_base, topk=hypothese_count)
            xk_hat = self.mhd_layer.get_topk_hypo(xd_hat, topk_idx)
            outputs['xd_hat'] = xd_hat.detach().cpu()
            outputs['xk_hat'] = xk_hat.detach().cpu()
            outputs['diff_vector'] = diff_vector.detach().cpu()
        outputs['xc_hat'] = self.decoder(c_base).detach().cpu()
        outputs['c_base'] = c_base.detach().cpu()
        return outputs

    def reconstruction_loss(self, x, x_hat):
        return torch.mean((x - x_hat) ** 2)

    def get_embeddings(self, latent_nb=0, **kwargs):
        level = 0 if latent_nb == 0 else 1
        return self.codebook.level_blocks[level].embeddings.get_embed(**kwargs)
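
# ---------------------------------------------------------------
# Hedged usage sketch (not part of the original file). The component
# modules (encoder1, encoder2, decoder, codebook, mhd_layer) are
# assumed to be built elsewhere from cfg; the names and seq_dim
# values below are placeholders, not definitions from this repo:
#
#   model = MVQAE(cfg, encoder1, decoder, codebook,
#                 num_latent_space=2, encoder2=encoder2,
#                 mhd_layer=mhd_layer, seq_dim=[32, 16])
#   x_hat, disc_inputs, loss, outputs, log = model(x, hypothese_count=16)
#
# The runnable check below only exercises the self-contained
# regrouping done by `unflat_tensor`: a flat stack of N * G frames
# is reshaped into (N, G, C, H, W) so each sample's G hypotheses
# stay together.
# ---------------------------------------------------------------
if __name__ == "__main__":
    flat = torch.randn(2 * 16, 8, 32, 32)      # e.g. 2 samples x 16 hypotheses
    grouped = flat.reshape(-1, 16, 8, 32, 32)  # same reshape as unflat_tensor
    assert grouped.shape == (2, 16, 8, 32, 32)
    print(grouped.shape)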