# mvq/models/build.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import copy

import numpy as np
import torch

from torch.nn.parallel import DistributedDataParallel as DDP

from .mvq_model import MVQAE
from .vq2_model import VQ2VAE

from models.conv import (
  Encoder, Decoder, 
  EncoderBN, DecoderBN,
  get_coder_cfg
)

from thirdparty.taming.diffusionmodules.model import Encoder as EncoderDiffusion
from thirdparty.taming.diffusionmodules.model import Decoder as DecoderDiffusion

from models.mh_dropout import (
  MHDropoutNetRandom2D,
)

from models.vq import VectorQuantizer, VQMulti

from thirdparty.vqvae2.pixel_snail import PixelSNAIL
from thirdparty.taming.losses.vqperceptual import VQLPIPSWithDiscriminator

from .gpt_wrap import GPTWrapper

def init_model(cfg, stage):
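  """Build the models and optimizers for the given training stage.

  Always builds the VAE (plus the discriminator when cfg.model.gan.use is
  set); when stage == 'seq', the autoregressive sequence models are added.
  Each model is wrapped in DistributedDataParallel or moved to the
  configured device before the (models, opts) dicts are returned.
  """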
  use_distributed = cfg.torch_dist.use
  device_id       = cfg.device_id
  use_disc        = cfg.model.gan.use

  models, opts = {}, {}
  models['vae'], opts['vae'] = init_vae_model(cfg)

  if use_disc: 
    models['disc'], opts['disc'] = init_disc_model(cfg)

  if stage == 'seq':
    m, o = init_seq_model(cfg, device_id)
    models = {**models, **m}
    opts = {**opts, **o}

  # if cfg.rank == 0: 
  #   for m_key, model in models.items():
  #     total_params = sum(p.numel() for p in model.parameters())
  #     print("Loaded Model: {:9}  Parameters={}".format(m_key, total_params))

  for s in models.keys():
    if use_distributed and models.get(s) is not None:
      models[s] = to_distributed(
          model=models[s],
          device_id=device_id,
          )
    elif device_id is not None and device_id != 'cpu':
      if models[s] is not None:
        models[s].to(f'cuda:{device_id}')
  return models, opts

def to_distributed(model, device_id):
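  """Move the model onto its GPU and wrap it in DistributedDataParallel."""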
  model.to(f'cuda:{device_id}')
  model = DDP(model, device_ids=[device_id])
  return model

def init_vae_model(cfg):
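  """Build the VQ-VAE: encoders/decoders, VQ codebooks, the optional
  MH-dropout block, and the Adam optimizer over the VAE parameters."""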

  debug    = cfg.debug
  use_fp16 = cfg.fp16.use
  
  use_disc = cfg.model.gan.use
  emb_dim1 = cfg.model.vq.emb_dim1
  emb_dim2 = cfg.model.vq.emb_dim2

  use_mhd = cfg.model.mhd.use
  in_channels = cfg.data.in_channels
  num_latent_space = cfg.model.num_latent_space

  n_emb11 = cfg.model.vq.n_emb1
  n_emb12 = cfg.model.vq.n_emb2
  n_emb12 = n_emb11 if n_emb12 < 0 else n_emb12

  if cfg.rank == 0:
    print("Model config          : backbone={}  codebook={}  num_z={}  mh-dropout={}  discriminator={}".format(
      cfg.model.backbone, cfg.model.vq.name, num_latent_space, int(use_mhd), int(use_disc),))

  #################################
  # 1. Init encoder and decoders
  #################################


  dsf1 = cfg.model.down_sampling_factor1
  dsf2 = cfg.model.down_sampling_factor2

  enc_cfg, enc2_cfg, dec_cfg, dec2_cfg = get_coder_cfg(cfg, dsf1)

  encoder_kwargs = {**enc_cfg}
  encoder_kwargs['in_channel'] = in_channels
  encoder_kwargs['out_channel'] = enc_cfg.channel

  d_in = emb_dim1 if dec_cfg.in_channel == 0 else dec_cfg.in_channel

  decoder_kwargs = {**dec_cfg}
  decoder_kwargs['in_channel'] = d_in
  decoder_kwargs['out_channel'] = in_channels

  if cfg.model.backbone == 'vq2':
    #Build top decoder
    dec_t_kwargs = copy.deepcopy(decoder_kwargs)
    dec_t_kwargs['in_channel'] = emb_dim2
    dec_t_kwargs['out_channel'] = emb_dim1

    if dsf2 == 16:
      dec_t_kwargs['stride'] = 4
      dec_t_kwargs['kernels'] = [2,2,4]
    elif dsf2 == 4:
      dec_t_kwargs['stride'] = 2
      dec_t_kwargs['kernels'] = [2,1]
    elif dsf2 == 1:
      dec_t_kwargs['stride'] = 1
      dec_t_kwargs['kernels'] = [3,1]
    else:
      raise NotImplementedError("Downsampling factor 2 not supported: {}".format(dsf2))
    
    dec_t = get_decoder(cfg.model.coder.name, **dec_t_kwargs)

    #Bottom decoder takes a concat of top and bottom
    if not use_mhd:
      decoder_kwargs['in_channel'] = emb_dim1 + emb_dim2

  encoder = get_encoder(cfg.model.coder.name, **encoder_kwargs)
  decoder = get_decoder(cfg.model.coder.name, **decoder_kwargs)

  encoder2 = None
  if num_latent_space == 2:
    encoder2_kwargs = {
      'in_channel': enc_cfg.channel,  # channels match the first encoder's output
      'out_channel': enc2_cfg.channel,
      'channel': enc2_cfg.channel,
      'n_res_block': enc2_cfg.n_res_block,
      'n_res_channel': enc2_cfg.n_res_channel,
      'stride': enc2_cfg.stride,
      'kernels': enc2_cfg.kernels,
      'res_kernels': enc2_cfg.res_kernels,
      'act': enc2_cfg.act
      }

    if dsf2 == 16:
      encoder2_kwargs['stride'] = 4
      encoder2_kwargs['kernels'] = [4,3,3]

    elif dsf2 == 4:
      encoder2_kwargs['stride'] = 2
      encoder2_kwargs['kernels'] = [3,3]

    elif dsf2 == 1:
      encoder2_kwargs['stride'] = 1
      encoder2_kwargs['kernels'] = [3,1]
    else:
      raise NotImplementedError("Downsampling factor 2 not supported: {}".format(dsf2))


    encoder2 = get_encoder('conv_bn', **encoder2_kwargs)

  #################################
  # 2. Init VQ Codebooks
  #################################

  if cfg.model.backbone == 'vq2':
    vq_kwargs = {
    'emb_dim': cfg.model.vq.emb_dim1,
    'n_emb': cfg.model.vq.n_emb1,
    'beta': cfg.model.vq.beta,
    'tiled': cfg.model.vq.tiled,
    'ema_update': cfg.model.vq.ema_update,
    'random_restart': cfg.model.vq.random_restart
    }

    vq_layer = VectorQuantizer(cfg.model.vq.name, **vq_kwargs)

    vq_layer2 = None
    if num_latent_space == 2:
      vq2_kwargs = copy.deepcopy(vq_kwargs)
      vq2_kwargs['emb_dim'] = cfg.model.vq.emb_dim2
      vq2_kwargs['n_emb'] = cfg.model.vq.n_emb2
      vq_layer2 = VectorQuantizer(cfg.model.vq.name, **vq2_kwargs)
  else:

    codebook = VQMulti(
      n_embs=[cfg.model.vq.n_emb1, cfg.model.vq.n_emb2], 
      emb_dims=[cfg.model.vq.emb_dim1, cfg.model.vq.emb_dim2], 
      betas=[cfg.model.vq.beta] * num_latent_space, 
      levels=num_latent_space, 
      tiled=cfg.model.vq.tiled, 
      ema_update=cfg.model.vq.ema_update, 
      random_restart=cfg.model.vq.random_restart
      )
  #################################
  # 3. MH Dropout Block
  #################################

  mhd_layer = None
  if use_mhd:
    if num_latent_space == 2:
      mhd_inp_dim = cfg.model.vq.emb_dim2
    else:
      mhd_inp_dim = emb_dim1
    
    if cfg.model.mhd.bottleneck:
      hidden_dim = cfg.model.mhd.bottleneck_dim
    else:
      hidden_dim = mhd_inp_dim

    seq_dim = cfg.model.discrete_seq_len
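    # up-sampling ratio derived from the side lengths (square roots) of the
    # bottom and top discrete grids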
    up_sample_ratio = int(np.ceil(seq_dim[0]**(1/2) - seq_dim[-1]**(1/2))) + 1

    mhd_kwargs = {
      'mask_type': cfg.model.mhd.mask_type,
      'num_latent_space': num_latent_space,
      'inp_dim': mhd_inp_dim,
      'hidden_dim': hidden_dim,
      'hypothese_bsz': cfg.model.mhd.hypothese_bsz,
      'out_dim': emb_dim1,
      'dist_reduce': cfg.model.mhd.dist_reduce,
      'loss_reduce': cfg.model.mhd.dist_reduce,
      'loss_reduce_dims': cfg.model.mhd.loss_reduce_dims,
      'dist_loss': cfg.model.mhd.dist_loss,
      'dropout_rate': cfg.model.mhd.dropout_rate,
      'decoder_cfg': cfg.model.mhd.decoder,
      'residual': num_latent_space == 2,
      'up_sample_ratio': up_sample_ratio,
      'debug': cfg.debug,
      'use_mhd_mask': cfg.model.mhd.use_mhd_mask
    }

    ## Setup config for decoder
    mhd_layer = get_mhd_layer(cfg.model.mhd.name, **mhd_kwargs)

  #################################
  # 5. Init model backbone
  #################################

  model_func = get_model_func(num_latent_space, cfg.model.backbone)

  # convert flattened discrete sequence lengths into (square) spatial side lengths
  seq_dim = copy.deepcopy(cfg.model.discrete_seq_len)
  seq_dim[0] = int(np.ceil(seq_dim[0] ** (1/2)))
  seq_dim[1] = int(np.ceil(seq_dim[1] ** (1/2)))
  model_kwargs = {
    'cfg': cfg.model,
    'num_latent_space': num_latent_space,
    'decoder': decoder,
    'mhd_layer': mhd_layer,
    'use_disc': use_disc,
    'seq_dim': seq_dim, 
    'eval_cfg': cfg.eval,
  }

  model_kwargs['encoder1'] = encoder
  model_kwargs['encoder2'] = encoder2

  if cfg.model.backbone == 'vq2':
    model_kwargs['dec_t'] = dec_t
    model_kwargs['vq_layer1'] = vq_layer
    model_kwargs['vq_layer2'] = vq_layer2

  else:
    model_kwargs['codebook'] = codebook
    
  model = model_func(**model_kwargs)
  
  params = get_vae_params(cfg, model)
  lr = cfg.train.optim.learning_rate
  betas = cfg.train.optim.betas
  opt = torch.optim.Adam(params, lr=lr, betas=betas)

  return model, opt

def init_disc_model(cfg):
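  """Build the VQ-LPIPS discriminator loss module and its Adam optimizer."""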

  in_channels = cfg.data.in_channels

  #################################
  # 4. GAN
  #################################

  disc_kwargs = {
    'disc_start': cfg.model.gan.start_step,
    'disc_in_channels': in_channels,
    'codebook_weight': cfg.model.gan.codebook_weight,
    'pixelloss_weight': cfg.model.gan.pixelloss_weight,
    'disc_num_layers': cfg.model.gan.disc_num_layers,
    'use_actnorm': cfg.model.gan.use_actnorm,
    'disc_loss': cfg.model.gan.disc_loss,
    'disc_ndf': cfg.model.gan.disc_ndf,
    'disc_factor': cfg.model.gan.disc_factor,
    'disc_weight': cfg.model.gan.disc_weight,
    'perceptual_weight': cfg.model.gan.perceptual_weight,
    'disc_conditional': cfg.model.gan.disc_conditional,
  }
  disc_layer = VQLPIPSWithDiscriminator(**disc_kwargs)

  disc_params = get_all_params(disc_layer)
  lr = cfg.train.optim.learning_rate
  betas = cfg.train.optim.betas

  disc_opt = torch.optim.Adam(disc_params, lr=lr, betas=betas)
  
  return disc_layer, disc_opt


def get_vae_params(cfg, model):
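  """Collect the trainable VAE parameters for the optimizer, according to
  the number of latent spaces and the backbone."""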
  use_mhd = cfg.model.mhd.use
  num_latent_space = cfg.model.num_latent_space

  module = model.module if hasattr(model,'module') else model
  params = list(module.decoder.parameters())

  if num_latent_space == 1:
    params += list(module.encoder.parameters()) + list(module.vq_layer.parameters())

    if module.proj is not None:
      params += list(module.proj.parameters())

  elif num_latent_space == 2:
    if cfg.model.backbone == 'vq2':
      params += list(module.enc_b.parameters()) + list(module.vq_layer_b.parameters())
      params += list(module.enc_t.parameters()) + list(module.vq_layer_t.parameters())
      params += list(module.proj_t.parameters()) + list(module.proj_b.parameters())
      params += list(module.dec_t.parameters())

      if module.upsample_t is not None:
        params += list(module.upsample_t.parameters())

    else:
      params += list(module.enc_1.parameters()) + list(module.codebook.parameters())
      params += list(module.enc_2.parameters()) 
      if module.proj_1 is not None:
        params += list(module.proj_1.parameters()) + list(module.proj_2.parameters())

  if use_mhd:
      params += list(module.mhd_layer.parameters())
  
  return params

def get_all_params(model):
  return list(filter(lambda p: p.requires_grad, model.parameters()))

def init_seq_model(cfg, device_id):
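  """Build the GPT sequence models over the discrete codes: a base model
  and, when two latent spaces are used, a conditional model."""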
  n_emb11 = cfg.model.vq.n_emb1
  n_emb12 = cfg.model.vq.n_emb2
  n_emb12 = n_emb11 if n_emb12 < 0 else n_emb12
  
  if cfg.model.backbone == 'vq2':
    n_emb11, n_emb12 = n_emb12, n_emb11
  
  if n_emb11 > n_emb12:
    n_class = max(n_emb12, n_emb11)
    n_emb11, n_emb12 = n_class, n_class
    
  seq_dim = copy.deepcopy(cfg.model.discrete_seq_len)
  seq_dim[0] = int(np.ceil(seq_dim[0] ** (1/2)))
  seq_dim[1] = int(np.ceil(seq_dim[1] ** (1/2)))

  if cfg.model.backbone == 'vq2':
    p2, p1 = seq_dim
  else:
    p1, p2 = seq_dim

  seq_shape1 = [p1, p1]  # DIST / bottom (in vq2)
  seq_shape2 = [p2, p2]  # DIST / top

  p_cfg = getattr(cfg.model.seq, cfg.model.seq.name.bottom)
  s_cfg = getattr(cfg.model.seq, cfg.model.seq.name.top)

  model = GPTWrapper(
    device_id=device_id,
    shape=seq_shape1,
    n_class=n_emb11,
    block_size=p_cfg.block_size, 
    n_layer=p_cfg.n_layer, 
    n_head=p_cfg.n_head, 
    n_embd=p_cfg.n_embd, 
    )

  cond_model = GPTWrapper(
    device_id=device_id,
    shape=seq_shape2,
    n_class=n_emb12,
    block_size=s_cfg.block_size, 
    n_layer=s_cfg.n_layer, 
    n_head=s_cfg.n_head, 
    n_embd=p_cfg.n_embd,  # use the same embedding size as the base model due to conditioning
    )

  #############################################
  models = {}
  opts = {}

  params = get_all_params(model)
  lr = cfg.train.optim.learning_rate
  betas = cfg.train.optim.betas

  opt = torch.optim.Adam(params, lr=lr, betas=betas)

  models['seq_base'] = model
  opts['seq_base'] = opt

  if cfg.model.num_latent_space > 1:
    #cond_model.cuda(device_id)
    cond_params = get_all_params(cond_model)
    cond_opt = torch.optim.Adam(cond_params, lr=lr, betas=betas)

    models['seq_cond'] = cond_model
    opts['seq_cond'] = cond_opt

  return models, opts


def get_mhd_layer(name, **kwargs):
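  # `name` is currently unused: every config maps to MHDropoutNetRandom2D.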
  return MHDropoutNetRandom2D(**kwargs)

def get_encoder(model_name, **kwargs):
  if model_name == 'conv':
    return Encoder(**kwargs)
  elif model_name == 'conv_bn':
    return EncoderBN(**kwargs)
  elif model_name == 'conv_diffusion':
    return EncoderDiffusion(**kwargs)
  raise NotImplementedError("")

def get_decoder(model_name, **kwargs):
  if model_name == 'conv':
    return Decoder(**kwargs)
  elif model_name == 'conv_bn':
    return DecoderBN(**kwargs)
  elif model_name == 'conv_diffusion':
    return DecoderDiffusion(**kwargs)
  raise NotImplementedError("")

def get_model_func(num_latent_space, backbone='mvq'):
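  """Return the backbone class: VQ2VAE for 'vq2', MVQAE for 'mvq'."""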
  assert num_latent_space in [1,2]
  if backbone == 'vq2':
    return VQ2VAE
  elif backbone == 'mvq':
    return MVQAE
  raise NotImplementedError("")