# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
from thirdparty.taming.losses.lpips import LPIPS

class VQ2VAE(torch.nn.Module):
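    """Hierarchical (two-level) VQ-VAE.

    A bottom encoder (``enc_b``) and a top encoder (``enc_t``) produce two
    latent feature maps that are quantized by separate codebooks
    (``vq_layer_b`` and ``vq_layer_t``). The quantized top latent is decoded
    by ``dec_t``, concatenated with the bottom features, projected and
    quantized again, and the final decoder reconstructs the input from the
    concatenation of both quantized latents. ``seq_dim`` is expected to be a
    sequence of latent spatial sizes ordered bottom (largest) to top
    (smallest).
    """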
    def __init__(self, cfg, encoder1, encoder2, decoder, 
            dec_t, vq_layer1, vq_layer2, 
            use_disc=False, seq_dim=3, hypothese_dim=1, 
            decoder_weight=1.0, decoder_loss='winner', recon_loss='mse',
            hypothese_bsz=32, residual=True, debug=False, eval_cfg=None, 
            **kwargs):
        super().__init__()
        assert vq_layer1.name in ['vq'], "Incompatible codebook type: {}".format(vq_layer1.name)

        self.debug = debug
        self.cfg = cfg
        self.eval_cfg = eval_cfg

        self.enc_b = encoder1
        self.enc_t = encoder2

        self.dec_t = dec_t
        self.decoder = decoder

        self.vq_layer_t = vq_layer2
        self.vq_layer_b = vq_layer1

        self.seq_dim = seq_dim
        self.hypothese_dim = hypothese_dim

        self.emb_dim1 = vq_layer1.emb_dim
        self.emb_dim2 = vq_layer2.emb_dim

        self.hypothese_bsz = hypothese_bsz
        self.decoder_loss = decoder_loss
        self.recon_loss = recon_loss

        self.residual = residual
        self.decoder_weight = decoder_weight

        self.use_disc = use_disc
        if self.use_disc:
            self.perceptual_loss = LPIPS().eval()
            self.perceptual_weight = 1.0

        self.proj_t = torch.nn.Conv2d(self.enc_t.out_channel, self.emb_dim2, 1)
        self.proj_b = torch.nn.Conv2d(self.enc_b.out_channel + self.emb_dim1, self.emb_dim1, 1)

        # Kernel size for the transposed conv that brings the top latent up to
        # the bottom latent's spatial resolution.
        up_sample_ratio = int(self.seq_dim[0] / self.seq_dim[-1])
        if self.seq_dim[-1] > 1:
            up_sample_ratio += 1

        self.upsample_t = torch.nn.ConvTranspose2d(
                        self.emb_dim2, self.emb_dim2, up_sample_ratio
                    )

    def forward(self, x, optimizer_idx=0, split='train', return_outputs=False, **kwargs):
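        """Encode, quantize (top and bottom), and reconstruct ``x``.

        Returns ``(x_hat, disc_inputs, loss, outputs, log)``: ``disc_inputs``
        carries the NLL and codebook losses when ``use_disc`` is set (the
        total loss is then left to the caller), ``outputs`` is only populated
        when ``return_outputs`` is True, and ``log`` holds detached scalars
        keyed by ``split``.
        """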
        
        log = {}
        outputs = {}
        loss = 0

        enc_b = self.enc_b(x)

        enc_t = self.enc_t(enc_b)

        q_zt = self.proj_t(enc_t)

        assert q_zt.size(-1) == self.seq_dim[-1], "Check encoder2 params. Wrong encoder output dimensions: q_zt {} => {}".format(q_zt.shape, self.seq_dim[-1])

        q_zt, _, cb_loss1 = self.vq_layer_t(q_zt, is_indices=False)

        dec_t = self.dec_t(q_zt)

        enc_b = torch.cat([dec_t, enc_b], 1)

        q_zb = self.proj_b(enc_b)
        
        assert q_zb.size(-1) == self.seq_dim[0], "Check encoder1 params. Wrong encoder output dimensions: q_zb {} => {}".format(q_zb.shape, self.seq_dim[0])

        z_c, _, cb_loss2 = self.vq_layer_b(q_zb, is_indices=False)

        if self.upsample_t is not None:
            z_d = self.upsample_t(q_zt)
        else:
            z_d = q_zt

        z_dc = torch.cat([z_d, z_c], 1)
        
        x_hat = self.decoder(z_dc)


        cb_loss = (cb_loss1 + cb_loss2).mean()
        disc_inputs = {}

        if not self.use_disc: 
            recon_loss = self.reconstruction_loss(x, x_hat)
            loss += recon_loss / self.decoder_weight
            loss += cb_loss
            log['{}/mse_loss'.format(split)] = recon_loss.clone().detach().cpu()
        else:
            rec_loss = torch.abs(x.contiguous() - x_hat.contiguous())
            if self.perceptual_weight > 0:
                p_loss = self.perceptual_loss(x.contiguous(), x_hat.contiguous())
                rec_loss = rec_loss + self.perceptual_weight * p_loss
            else:
                p_loss = torch.tensor([0.0])

            nll_loss = rec_loss
            nll_loss = torch.mean(nll_loss)

            log["{}/nll_loss".format(split)] = nll_loss.detach().mean()
            log["{}/rec_loss".format(split)] = rec_loss.detach().mean()
            log["{}/p_loss".format(split)] = p_loss.detach().mean()

            disc_inputs['nll_loss'] = nll_loss
            disc_inputs['cb_loss'] = cb_loss

        log['{}/cb1_loss'.format(split)] = cb_loss1.mean().clone().detach().cpu()
        log['{}/cb2_loss'.format(split)] = cb_loss2.mean().clone().detach().cpu()

        if return_outputs:
            outputs['q_z1'] = q_zt.detach().cpu() 
            outputs['q_z2'] = q_zb.detach().cpu() 
            outputs['z_c'] = z_c.detach().cpu()
            outputs['xc_hat'] = x_hat.detach().cpu()

        return x_hat, disc_inputs, loss, outputs, log

    def get_last_layer(self):
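        """Weights of the decoder's final conv layer, commonly used to scale
        an adversarial loss term when a discriminator is attached."""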
        return self.decoder.conv_out.weight

    def sample_codebook(self, x):
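        """Encode ``x`` and return ``(top_indices, bottom_indices)``, each
        reshaped to ``[batch, height, width]``."""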
        enc_b = self.enc_b(x)
        enc_t = self.enc_t(enc_b)
        q_zt = self.proj_t(enc_t)

        cb_indices_t = self.vq_layer_t.get_code_indices(q_zt)

        q_zt, _, _ = self.vq_layer_t(q_zt, is_indices=False)
        dec_t = self.dec_t(q_zt)
        enc_b = torch.cat([dec_t, enc_b], 1)
        q_zb = self.proj_b(enc_b)

        cb_indices_b = self.vq_layer_b.get_code_indices(q_zb)

        batch_size, _, x_dim, y_dim = q_zb.shape
        cb_indices_b = cb_indices_b.reshape([batch_size, x_dim, y_dim])

        batch_size, _, x_dim, y_dim = q_zt.shape
        cb_indices_t = cb_indices_t.reshape([batch_size, x_dim, y_dim])
        return cb_indices_t, cb_indices_b

    def reconstruct_from_indices(self, ind_1, ind_2, cond=None, hypothese_count=1):
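        """Decode a reconstruction from top (``ind_1``) and bottom (``ind_2``)
        codebook indices; ``cond`` and ``hypothese_count`` are currently
        unused."""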
        outputs = {}

        q_zt = self.vq_layer_t(ind_1, is_indices=True)

        if self.upsample_t is not None:
            z_d = self.upsample_t(q_zt)
        else:
            z_d = q_zt

        z_c = self.vq_layer_b(ind_2, is_indices=True)

        z_c = torch.cat([z_d, z_c], 1)
        xc_hat = self.decoder(z_c)

        outputs['xc_hat'] = xc_hat.detach().cpu()
        outputs['z_c'] = z_c.detach().cpu()

        return outputs
    
    def reconstruction_loss(self, x, x_hat):
        if self.recon_loss == 'mse':
            return torch.mean((x - x_hat) ** 2)
        raise NotImplementedError("Unsupported recon_loss: {}".format(self.recon_loss))
    
    def get_embeddings(self, latent_nb=0, **kwargs):
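        """Return the codebook embeddings of the bottom quantizer when
        ``latent_nb == 1``, otherwise of the top quantizer."""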
        if latent_nb == 1:
            l = self.vq_layer_b
        else:
            l = self.vq_layer_t
        return l.get_embed(**kwargs)
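

# ---------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). The stub
# classes below are hypothetical placeholders that only mirror the
# interfaces VQ2VAE appears to rely on: encoders exposing `out_channel`,
# and a VQ layer exposing `name`, `emb_dim`, `get_code_indices`, and a
# forward returning (z_q, indices, codebook_loss). Real project
# components will differ.
# ---------------------------------------------------------------
if __name__ == '__main__':

    class _StubEncoder(torch.nn.Module):
        def __init__(self, in_ch, out_ch, stride):
            super().__init__()
            self.out_channel = out_ch
            self.net = torch.nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1)

        def forward(self, x):
            return self.net(x)

    class _StubVQ(torch.nn.Module):
        name = 'vq'

        def __init__(self, emb_dim):
            super().__init__()
            self.emb_dim = emb_dim

        def forward(self, z, is_indices=False):
            # Identity "quantizer": passes features through with a zero codebook loss.
            return z, None, z.new_zeros(1)

        def get_code_indices(self, z):
            b, _, h, w = z.shape
            return torch.zeros(b * h * w, dtype=torch.long)

    seq_dim = (4, 2)                                        # bottom 4x4, top 2x2 latent grids
    enc_b = _StubEncoder(3, 16, stride=2)                   # 8x8 image -> 4x4 features
    enc_t = _StubEncoder(16, 16, stride=2)                  # 4x4 -> 2x2 features
    dec_t = torch.nn.ConvTranspose2d(8, 8, 3)               # 2x2 -> 4x4
    decoder = torch.nn.ConvTranspose2d(16, 3, 2, stride=2)  # 4x4 -> 8x8 image

    model = VQ2VAE(cfg=None, encoder1=enc_b, encoder2=enc_t, decoder=decoder,
                   dec_t=dec_t, vq_layer1=_StubVQ(8), vq_layer2=_StubVQ(8),
                   seq_dim=seq_dim)

    x = torch.randn(2, 3, 8, 8)
    x_hat, _, loss, _, log = model(x)
    print(x_hat.shape, float(loss))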