# mvq/models/mvq_model.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
from thirdparty.taming.losses.lpips import LPIPS

class MVQAE(torch.nn.Module):
    def __init__(self, cfg, 
            encoder1, 
            decoder, 
            codebook, 
            num_latent_space=1, 
            encoder2=None, 
            mhd_layer=None, 
            use_disc=False, 
            seq_dim=3, 
            hypothese_dim=1,  
            eval_cfg=False, **kwargs):
        super().__init__(**kwargs)
        self.cfg = cfg
        self.eval_cfg = eval_cfg

        self.enc_1 = encoder1
        self.enc_2 = encoder2

        self.decoder = decoder
        self.codebook = codebook

        self.seq_dim = seq_dim
        self.hypothese_dim = hypothese_dim

        self.num_latent_space = num_latent_space
        
        self.use_mhd = mhd_layer is not None
        self.mhd_layer = mhd_layer

        self.use_disc = use_disc
        if self.use_disc:
            self.perceptual_loss = LPIPS().eval()
            self.perceptual_weight = 1.0

        self.proj_2 = None
        out_ch1 = encoder1.out_channel if hasattr(encoder1, 'out_channel') else encoder1.ch
        self.proj_1 = torch.nn.Conv2d(out_ch1, self.codebook.emb_dims[0], 1)
        if self.num_latent_space == 2:
            out_ch2 = encoder2.out_channel if hasattr(encoder2, 'out_channel') else encoder2.out_ch
            self.proj_2 = torch.nn.Conv2d(out_ch2, self.codebook.emb_dims[1], 1)

    def unflat_tensor(self, x, hypothese_count):
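        """Reshape a flat (B*K, C, H, W) tensor back to (B, K, C, H, W),
        where K is the hypothesis count (the inverse of the flatten in
        reconstruct_from_indices)."""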
        _, channel_dim, x_dim, y_dim = x.shape
        return x.reshape((-1, hypothese_count, channel_dim, x_dim, y_dim))

    def encode(self, x):
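        """Project x into latent space(s) ready for quantisation: always the
        bottom-level latent, plus the top-level latent when num_latent_space == 2."""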

        encoded = []

        enc_1 = self.enc_1(x)
        q_zc = self.proj_1(enc_1)

        assert q_zc.size(-1) == self.seq_dim[0], "Check encoder1 params. Wrong encoder output dimensions: vqc_inp {} => {}".format(q_zc.shape, self.seq_dim[0])
        encoded.append(q_zc)

        if self.num_latent_space == 2:
            q_zd = self.proj_2(self.enc_2(enc_1))
            assert q_zd.size(-1) == self.seq_dim[-1], "Check encoder2 params. Wrong encoder output dimensions: vqd_inp {} => {}".format(q_zd.shape, self.seq_dim[-1])
            encoded.append(q_zd)
        
        return encoded

    def quantize(self, encoded):
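        """Quantise each latent with its codebook level; returns
        (zs, xs_quantised, commit_losses, metrics)."""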
        zs, xs_quantised, commit_losses, metrics = self.codebook(encoded)

        return zs, xs_quantised, commit_losses, metrics

    def forward(self, x, optimizer_idx=0, hypothese_count=16, split='train', return_outputs=False, **kwargs):
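        """Encode, quantise, run the multi-hypothesis (MHD) layer, and decode.

        Returns (x_hat, disc_inputs, loss, outputs, log). Note that this path
        assumes two latent spaces and a configured mhd_layer.
        """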

        log, outputs = {}, {}
        loss = 0

        encoded = self.encode(x)

        zs, xs_quantised, commit_losses, cb_metrics = self.quantize(encoded)

        mhd_dict = self.mhd_layer(
            z_top=zs[1], 
            z_bottom=zs[0], 
            hypothese_count=hypothese_count, 
            latent_tgt=encoded[0],
            split=split)
        
        yw_hat = mhd_dict.get('yw_hat')
        diffw_vector = yw_hat - zs[0]

        if return_outputs: 
            outputs['y_hat']        = mhd_dict['y_hat'].detach().cpu()
            outputs['z_bottom']     = zs[0].detach().cpu()

            outputs['diff_vector']  = mhd_dict['diff_vector'].detach().cpu()
            outputs['diffw_vector'] = diffw_vector.detach().cpu()

        wta_loss = mhd_dict.get('wta_loss')
        if optimizer_idx == 0 and wta_loss is not None:
            loss += wta_loss 
            log['{}/wta_loss'.format(split)] = wta_loss.clone().detach().cpu()

        x_hat = self.decoder(yw_hat)

        cb_loss = (commit_losses[0] + commit_losses[1]).mean()

        disc_inputs = {}
        if not self.use_disc: 
            recon_loss = self.reconstruction_loss(x, x_hat)
            loss += recon_loss
            loss += cb_loss
            log['{}/mse_loss'.format(split)] = recon_loss.clone().detach().cpu()
        else:
            rec_loss = torch.abs(x.contiguous() - x_hat.contiguous())
            if self.perceptual_weight > 0:
                p_loss = self.perceptual_loss(x.contiguous(), x_hat.contiguous())
                rec_loss = rec_loss + self.perceptual_weight * p_loss
            else:
                p_loss = torch.tensor([0.0])

            nll_loss = torch.mean(rec_loss)

            log["{}/nll_loss".format(split)] = nll_loss.detach().mean()
            log["{}/rec_loss".format(split)] = rec_loss.detach().mean()
            log["{}/p_loss".format(split)] = p_loss.detach().mean()

            disc_inputs['nll_loss'] = nll_loss
            disc_inputs['cb_loss'] = cb_loss
            disc_inputs['wta_loss'] = wta_loss

        for i in range(self.num_latent_space):
            log[f'{split}/cb{i + 1}_loss'] = commit_losses[i].mean().clone().detach().cpu()
            if len(cb_metrics):
                for name, key in (('fit', 'fit'), ('prenorm', 'prenorm'),
                                  ('entr', 'entropy'), ('used_curr', 'used_curr'),
                                  ('usage', 'usage'), ('dk', 'dk')):
                    log[f'{split}/{i + 1}{name}'] = cb_metrics[i][key].detach().cpu()

        return x_hat, disc_inputs, loss, outputs, log

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def sample_codebook(self, x):
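        """Encode x and return the square (B, H, W) codebook index maps for both levels."""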
        
        encoded = self.encode(x)

        assert encoded[0].size(-1) == self.seq_dim[0], "Check encoder1 params. Wrong encoder output dimensions: vq1_inp {} => {}".format(encoded[0].shape, self.seq_dim[0])
        assert encoded[1].size(-1) == self.seq_dim[-1], "Check encoder2 params. Wrong encoder output dimensions: vq2_inp {} => {}".format(encoded[1].shape, self.seq_dim[-1])

        codebook_indices1, _ = self.codebook.level_blocks[0].get_code_indices(encoded[0])
        codebook_indices2, _ = self.codebook.level_blocks[1].get_code_indices(encoded[1])

        bsz = encoded[0].size(0)

        seq_len1 = int(self.seq_dim[0])
        seq_len2 = int(self.seq_dim[1])

        codebook_indices1 = codebook_indices1.reshape([
            bsz, seq_len1, seq_len1])
        codebook_indices2 = codebook_indices2.reshape([
            bsz, seq_len2, seq_len2])

        return codebook_indices1, codebook_indices2

    def reconstruct_from_indices(self, ind_1, ind_2, hypothese_count=1):
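        """Decode codebook indices back to images; when an MHD layer is present,
        also decode hypothese_count hypotheses per sample and rank them by
        distance to the base latent."""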
        outputs = {}

        zs = self.codebook.decode([ind_1, ind_2])
        c_base, c_i = zs

        if self.mhd_layer is not None:

            ddb_dict = self.mhd_layer(
                c_base,  hypothese_count, 
                dist_latents=c_i, split='eval')

            y_hat       = ddb_dict['y_hat']
            diff_vector = ddb_dict['diff_vector']
            bsz         = ddb_dict['bsz']

            # Batch-decode all hypotheses at once
            y_hat_flat  = torch.flatten(y_hat, start_dim=0, end_dim=1)
            xd_hat = self.decoder(y_hat_flat).detach().cpu()
            xd_hat = self.unflat_tensor(xd_hat, bsz)

            # Arrange hypotheses by distance to the center
            _, topk_idx = self.mhd_layer.get_dist_loss(y_hat, c_base, topk=hypothese_count)
            xk_hat = self.mhd_layer.get_topk_hypo(xd_hat, topk_idx)

            outputs['xd_hat']      = xd_hat.detach().cpu()
            outputs['xk_hat']      = xk_hat.detach().cpu() 
            outputs['diff_vector'] = diff_vector.detach().cpu()

        outputs['xc_hat'] = self.decoder(c_base).detach().cpu()
        outputs['c_base'] = c_base.detach().cpu()

        return outputs
    
    def reconstruction_loss(self, x, x_hat):
        return torch.mean((x - x_hat) ** 2)
    
    def get_embeddings(self, latent_nb=0, **kwargs):
        block_idx = 0 if latent_nb == 0 else 1
        embeddings = self.codebook.level_blocks[block_idx].embeddings
        return embeddings.get_embed(**kwargs)
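

# -----------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# The stub codebook, MHD layer, encoders, and decoder below are
# hypothetical stand-ins shaped purely to satisfy the interfaces MVQAE
# touches: an `out_channel` attribute, `emb_dims`, a callable codebook
# returning (zs, xs_quantised, commit_losses, metrics), and an
# mhd_layer returning the keys read in forward(). The real modules
# live elsewhere in this repo.
# -----------------------------------------------------------------
if __name__ == '__main__':

    class StubCodebook(torch.nn.Module):
        emb_dims = (8, 8)

        def forward(self, encoded):
            zs = list(encoded)                            # pass-through "quantisation"
            commit = [torch.zeros(()) for _ in encoded]   # dummy commitment losses
            return zs, zs, commit, []                     # no codebook metrics

    class StubMHD(torch.nn.Module):
        def forward(self, z_top, z_bottom, hypothese_count, latent_tgt, split):
            return {
                'yw_hat': z_bottom,                       # winning hypothesis
                'y_hat': z_bottom.unsqueeze(1),           # (B, K=1, C, H, W)
                'diff_vector': torch.zeros_like(z_bottom),
                'wta_loss': torch.zeros(()),
            }

    enc_1 = torch.nn.Conv2d(3, 32, 4, stride=2, padding=1)    # 32x32 -> 16x16
    enc_1.out_channel = 32
    enc_2 = torch.nn.Conv2d(32, 16, 4, stride=2, padding=1)   # 16x16 -> 8x8
    enc_2.out_channel = 16
    dec = torch.nn.ConvTranspose2d(8, 3, 4, stride=2, padding=1)

    model = MVQAE(
        cfg=None, encoder1=enc_1, decoder=dec, codebook=StubCodebook(),
        num_latent_space=2, encoder2=enc_2, mhd_layer=StubMHD(),
        seq_dim=(16, 8))

    x = torch.randn(2, 3, 32, 32)
    x_hat, _, loss, outputs, log = model(x, hypothese_count=1, return_outputs=True)
    print(x_hat.shape, loss.item(), sorted(log))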