# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import torch

from thirdparty.taming.losses.lpips import LPIPS


class VQ2VAE(torch.nn.Module):
    """Two-level vector-quantized autoencoder.

    A bottom encoder (`encoder1`) and a top encoder (`encoder2`) produce two
    latent grids that are quantized by separate codebooks (`vq_layer1` for the
    bottom level, `vq_layer2` for the top level). The decoded top latent is
    fed back into the bottom branch, and the decoder reconstructs the input
    from the concatenation of both quantized latents.
    """

    def __init__(self, cfg, encoder1, encoder2, decoder, dec_t, vq_layer1, vq_layer2,
                 use_disc=False, seq_dim=3, hypothese_dim=1, decoder_weight=1.0,
                 decoder_loss='winner', recon_loss='mse', hypothese_bsz=32,
                 residual=True, debug=False, eval_cfg=None, **kwargs):
        super().__init__()
        assert vq_layer1.name in ['vq'], \
            "Incompatible codebook type: {}".format(vq_layer1.name)
        self.debug = debug
        self.cfg = cfg
        self.eval_cfg = eval_cfg
        self.enc_b = encoder1
        self.enc_t = encoder2
        self.dec_t = dec_t
        self.decoder = decoder
        self.vq_layer_t = vq_layer2
        self.vq_layer_b = vq_layer1
        # Latent grid sizes: seq_dim[0] is the bottom grid, seq_dim[-1] the top grid.
        self.seq_dim = seq_dim
        self.hypothese_dim = hypothese_dim
        self.emb_dim1 = vq_layer1.emb_dim
        self.emb_dim2 = vq_layer2.emb_dim
        self.hypothese_bsz = hypothese_bsz
        self.decoder_loss = decoder_loss
        self.recon_loss = recon_loss
        self.residual = residual
        self.decoder_weight = decoder_weight
        self.use_disc = use_disc
        if self.use_disc:
            self.perceptual_loss = LPIPS().eval()
            self.perceptual_weight = 1.0

        # 1x1 convolutions projecting encoder features to the codebook dimensions.
        self.proj_t = torch.nn.Conv2d(self.enc_t.out_channel, self.emb_dim2, 1)
        self.proj_b = torch.nn.Conv2d(self.enc_b.out_channel + self.emb_dim1, self.emb_dim1, 1)

        # Transposed convolution that brings the top latent up to the bottom grid size.
        up_sample_ratio = int(self.seq_dim[0] / seq_dim[-1])
        if self.seq_dim[-1] > 1:
            up_sample_ratio += 1
        self.upsample_t = torch.nn.ConvTranspose2d(
            self.emb_dim2, self.emb_dim2, up_sample_ratio
        )

    def forward(self, x, optimizer_idx=0, split='train', return_outputs=False, **kwargs):
        log = {}
        outputs = {}
        loss = 0

        # Encode: bottom features, then top features on top of them.
        enc_b = self.enc_b(x)
        enc_t = self.enc_t(enc_b)

        # Quantize the top latent.
        q_zt = self.proj_t(enc_t)
        assert q_zt.size(-1) == self.seq_dim[-1], \
            "Check encoder2 params. Wrong encoder output dimensions: q_zt {} => {}".format(
                q_zt.shape, self.seq_dim[-1])
        q_zt, _, cb_loss1 = self.vq_layer_t(q_zt, is_indices=False)

        # Decode the top latent and condition the bottom branch on it.
        dec_t = self.dec_t(q_zt)
        enc_b = torch.cat([dec_t, enc_b], 1)

        # Quantize the bottom latent.
        q_zb = self.proj_b(enc_b)
        assert q_zb.size(-1) == self.seq_dim[0], \
            "Check encoder1 params. Wrong encoder output dimensions: q_zb {} => {}".format(
                q_zb.shape, self.seq_dim[0])
        z_c, _, cb_loss2 = self.vq_layer_b(q_zb, is_indices=False)

        # Upsample the top latent to the bottom grid and decode the concatenation.
        if self.upsample_t is not None:
            z_d = self.upsample_t(q_zt)
        else:
            z_d = q_zt

        z_dc = torch.cat([z_d, z_c], 1)
        x_hat = self.decoder(z_dc)

        cb_loss = (cb_loss1 + cb_loss2).mean()

        disc_inputs = {}
        if not self.use_disc:
            recon_loss = self.reconstruction_loss(x, x_hat)
            loss += recon_loss / self.decoder_weight
            loss += cb_loss
            log['{}/mse_loss'.format(split)] = recon_loss.clone().detach().cpu()
        else:
            # With a discriminator, the reconstruction term combines an L1 and a
            # perceptual (LPIPS) loss; the final objective is assembled outside
            # this module from `disc_inputs`.
            rec_loss = torch.abs(x.contiguous() - x_hat.contiguous())
            if self.perceptual_weight > 0:
                p_loss = self.perceptual_loss(x.contiguous(), x_hat.contiguous())
                rec_loss = rec_loss + self.perceptual_weight * p_loss
            else:
                p_loss = torch.tensor([0.0])

            nll_loss = rec_loss
            nll_loss = torch.mean(nll_loss)
            log["{}/nll_loss".format(split)] = nll_loss.detach().mean()
            log["{}/rec_loss".format(split)] = rec_loss.detach().mean()
            log["{}/p_loss".format(split)] = p_loss.detach().mean()
            disc_inputs['nll_loss'] = nll_loss
            disc_inputs['cb_loss'] = cb_loss

        log['{}/cb1_loss'.format(split)] = cb_loss1.mean().clone().detach().cpu()
        log['{}/cb2_loss'.format(split)] = cb_loss2.mean().clone().detach().cpu()

        if return_outputs:
            outputs['q_z1'] = q_zt.detach().cpu()
            outputs['q_z2'] = q_zb.detach().cpu()
            outputs['z_c'] = z_c.detach().cpu()
            outputs['xc_hat'] = x_hat.detach().cpu()

        return x_hat, disc_inputs, loss, outputs, log

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def sample_codebook(self, x):
        """Return the top and bottom codebook indices for a batch of inputs."""
        enc_b = self.enc_b(x)
        enc_t = self.enc_t(enc_b)

        q_zt = self.proj_t(enc_t)
        cb_indices_t = self.vq_layer_t.get_code_indices(q_zt)
        q_zt, _, _ = self.vq_layer_t(q_zt, is_indices=False)

        dec_t = self.dec_t(q_zt)
        enc_b = torch.cat([dec_t, enc_b], 1)

        q_zb = self.proj_b(enc_b)
        cb_indices_b = self.vq_layer_b.get_code_indices(q_zb)

        batch_size, _, x_dim, y_dim = q_zb.shape
        cb_indices_b = cb_indices_b.reshape([batch_size, x_dim, y_dim])
        batch_size, _, x_dim, y_dim = q_zt.shape
        cb_indices_t = cb_indices_t.reshape([batch_size, x_dim, y_dim])

        return cb_indices_t, cb_indices_b

    def reconstruct_from_indices(self, ind_1, ind_2, cond=None, hypothese_count=1):
        """Decode an image from top (`ind_1`) and bottom (`ind_2`) codebook indices."""
        outputs = {}
        q_zt = self.vq_layer_t(ind_1, is_indices=True)

        if self.upsample_t is not None:
            z_d = self.upsample_t(q_zt)
        else:
            z_d = q_zt

        z_c = self.vq_layer_b(ind_2, is_indices=True)
        z_c = torch.cat([z_d, z_c], 1)
        xc_hat = self.decoder(z_c)

        outputs['xc_hat'] = xc_hat.detach().cpu()
        outputs['z_c'] = z_c.detach().cpu()

        return outputs

    def reconstruction_loss(self, x, x_hat):
        if self.recon_loss == 'mse':
            return torch.mean((x - x_hat) ** 2)
        # Fail loudly instead of silently returning None for unsupported losses.
        raise NotImplementedError("Unsupported recon_loss: {}".format(self.recon_loss))

    def get_embeddings(self, latent_nb=0, **kwargs):
        if latent_nb == 1:
            l = self.vq_layer_b
        else:
            l = self.vq_layer_t
        return l.get_embed(**kwargs)
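

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training pipeline).
# The real encoders, decoder, and codebook modules live elsewhere in this
# repository; the Toy* classes below are hypothetical stand-ins that only
# implement the interface VQ2VAE relies on (`out_channel`, `emb_dim`,
# `name == 'vq'`, `get_code_indices`, and a call returning
# (quantized, indices, codebook_loss)). Shapes assume 3x16x16 inputs with a
# 4x4 bottom latent grid and a 1x1 top latent grid, i.e. seq_dim=(4, 1).
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class ToyEncoder(torch.nn.Module):
        """4x downsampling convolution exposing `out_channel` as VQ2VAE expects."""

        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.out_channel = out_ch
            self.conv = torch.nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=4)

        def forward(self, x):
            return self.conv(x)

    class ToyDecoder(torch.nn.Module):
        """Maps the concatenated latent grid back to image resolution."""

        def __init__(self, in_ch, out_ch):
            super().__init__()
            # Named `conv_out` so that VQ2VAE.get_last_layer() would also work.
            self.conv_out = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=4)

        def forward(self, z):
            return self.conv_out(z)

    class ToyVQ(torch.nn.Module):
        """Nearest-neighbour codebook implementing the continuous-input path only."""

        name = 'vq'

        def __init__(self, num_codes, emb_dim):
            super().__init__()
            self.emb_dim = emb_dim
            self.embed = torch.nn.Embedding(num_codes, emb_dim)

        def get_code_indices(self, z):
            # (B, C, H, W) -> (B*H*W, C), then index of the closest code per vector.
            flat = z.permute(0, 2, 3, 1).reshape(-1, self.emb_dim)
            return torch.cdist(flat, self.embed.weight).argmin(dim=1)

        def forward(self, z, is_indices=False):
            idx = self.get_code_indices(z)
            q = self.embed(idx).view(z.size(0), z.size(2), z.size(3), self.emb_dim)
            q = q.permute(0, 3, 1, 2).contiguous()
            # Codebook + commitment loss, then straight-through gradient copy.
            loss = torch.mean((q - z.detach()) ** 2) + 0.25 * torch.mean((q.detach() - z) ** 2)
            q = z + (q - z).detach()
            return q, idx, loss

    emb_dim = 8
    model = VQ2VAE(
        cfg=None,
        encoder1=ToyEncoder(3, 32),                           # image -> 4x4 bottom features
        encoder2=ToyEncoder(32, 32),                          # bottom features -> 1x1 top features
        decoder=ToyDecoder(2 * emb_dim, 3),                   # cat(top, bottom) latents -> image
        dec_t=torch.nn.ConvTranspose2d(emb_dim, emb_dim, 4),  # top 1x1 latent -> 4x4
        vq_layer1=ToyVQ(16, emb_dim),                         # bottom codebook
        vq_layer2=ToyVQ(16, emb_dim),                         # top codebook
        seq_dim=(4, 1),
    )

    x = torch.randn(2, 3, 16, 16)
    x_hat, disc_inputs, loss, outputs, log = model(x, return_outputs=True)
    print(x_hat.shape, float(loss), sorted(log))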