# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import copy

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from .mvq_model import MVQAE
from .vq2_model import VQ2VAE
from models.conv import (
    Encoder,
    Decoder,
    EncoderBN,
    DecoderBN,
    get_coder_cfg
)
from thirdparty.taming.diffusionmodules.model import Encoder as EncoderDiffusion
from thirdparty.taming.diffusionmodules.model import Decoder as DecoderDiffusion
from models.mh_dropout import (
    MHDropoutNetRandom2D,
)
from models.vq import VectorQuantizer, VQMulti
from thirdparty.vqvae2.pixel_snail import PixelSNAIL
from thirdparty.taming.losses.vqperceptual import VQLPIPSWithDiscriminator
from .gpt_wrap import GPTWrapper


def init_model(cfg, stage):
    use_distributed = cfg.torch_dist.use
    device_id = cfg.device_id
    use_disc = cfg.model.gan.use

    models, opts = {}, {}
    models['vae'], opts['vae'] = init_vae_model(cfg)

    if use_disc:
        models['disc'], opts['disc'] = init_disc_model(cfg)

    if stage == 'seq':
        m, o = init_seq_model(cfg, device_id)
        models = {**models, **m}
        opts = {**opts, **o}

    # if cfg.rank == 0:
    #     for m_key, model in models.items():
    #         total_params = sum(p.numel() for p in model.parameters())
    #         print("Loaded Model: {:9} Parameters={}".format(m_key, total_params))

    for s in models:
        if models[s] is None:
            continue
        if use_distributed:
            models[s] = to_distributed(
                model=models[s],
                device_id=device_id,
            )
        elif device_id is not None and device_id != 'cpu':
            models[s].to(f'cuda:{device_id}')

    return models, opts


def to_distributed(model, device_id):
    model.to(f'cuda:{device_id}')
    model = DDP(model, device_ids=[device_id])
    return model
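
# Usage sketch (illustrative only; not an entry point). Assumes an
# OmegaConf/Hydra-style `cfg` exposing the fields read in this module;
# `load_cfg` is a hypothetical helper, not defined here:
#
#     cfg = load_cfg('configs/base.yaml')
#     models, opts = init_model(cfg, stage='vae')   # or stage='seq'
#     out = models['vae'](batch)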


def init_vae_model(cfg):
    debug = cfg.debug
    use_fp16 = cfg.fp16.use
    use_disc = cfg.model.gan.use
    emb_dim1 = cfg.model.vq.emb_dim1
    emb_dim2 = cfg.model.vq.emb_dim2
    use_mhd = cfg.model.mhd.use
    in_channels = cfg.data.in_channels
    num_latent_space = cfg.model.num_latent_space

    n_emb11 = cfg.model.vq.n_emb1
    n_emb12 = cfg.model.vq.n_emb2
    n_emb12 = n_emb11 if n_emb12 < 0 else n_emb12

    if cfg.rank == 0:
        print("Model config : backbone={} codebook={} num_z={} mh-dropout={} discriminator={}".format(
            cfg.model.backbone,
            cfg.model.vq.name,
            num_latent_space,
            int(use_mhd),
            int(use_disc),
        ))

    #################################
    # 1. Init encoders and decoders
    #################################
    dsf1 = cfg.model.down_sampling_factor1
    dsf2 = cfg.model.down_sampling_factor2
    enc_cfg, enc2_cfg, dec_cfg, dec2_cfg = get_coder_cfg(cfg, dsf1)

    encoder_kwargs = {**enc_cfg}
    encoder_kwargs['in_channel'] = in_channels
    encoder_kwargs['out_channel'] = enc_cfg.channel

    d_in = emb_dim1 if dec_cfg.in_channel == 0 else dec_cfg.in_channel
    decoder_kwargs = {**dec_cfg}
    decoder_kwargs['in_channel'] = d_in
    decoder_kwargs['out_channel'] = in_channels

    if cfg.model.backbone == 'vq2':
        # Build the top decoder, which maps top codes back to the
        # bottom latent resolution.
        dec_t_kwargs = copy.deepcopy(decoder_kwargs)
        dec_t_kwargs['in_channel'] = emb_dim2
        dec_t_kwargs['out_channel'] = emb_dim1
        if dsf2 == 16:
            dec_t_kwargs['stride'] = 4
            dec_t_kwargs['kernels'] = [2, 2, 4]
        elif dsf2 == 4:
            dec_t_kwargs['stride'] = 2
            dec_t_kwargs['kernels'] = [2, 1]
        elif dsf2 == 1:
            dec_t_kwargs['stride'] = 1
            dec_t_kwargs['kernels'] = [3, 1]
        else:
            raise NotImplementedError(
                "Downsampling factor 2 not supported: {}".format(dsf2))
        dec_t = get_decoder(cfg.model.coder.name, **dec_t_kwargs)

        # The bottom decoder takes a concat of top and bottom codes.
        if not use_mhd:
            decoder_kwargs['in_channel'] = emb_dim1 + emb_dim2

    encoder = get_encoder(cfg.model.coder.name, **encoder_kwargs)
    decoder = get_decoder(cfg.model.coder.name, **decoder_kwargs)

    encoder2 = None
    if num_latent_space == 2:
        encoder2_kwargs = {
            'in_channel': enc_cfg.channel,  # second encoder consumes first-encoder features
            'out_channel': enc2_cfg.channel,
            'channel': enc2_cfg.channel,
            'n_res_block': enc2_cfg.n_res_block,
            'n_res_channel': enc2_cfg.n_res_channel,
            'stride': enc2_cfg.stride,
            'kernels': enc2_cfg.kernels,
            'res_kernels': enc2_cfg.res_kernels,
            'act': enc2_cfg.act
        }
        if dsf2 == 16:
            encoder2_kwargs['stride'] = 4
            encoder2_kwargs['kernels'] = [4, 3, 3]
        elif dsf2 == 4:
            encoder2_kwargs['stride'] = 2
            encoder2_kwargs['kernels'] = [3, 3]
        elif dsf2 == 1:
            encoder2_kwargs['stride'] = 1
            encoder2_kwargs['kernels'] = [3, 1]
        else:
            raise NotImplementedError(
                "Downsampling factor 2 not supported: {}".format(dsf2))
        encoder2 = get_encoder('conv_bn', **encoder2_kwargs)

    #################################
    # 2. Init VQ codebooks
    #################################
    if cfg.model.backbone == 'vq2':
        vq_kwargs = {
            'emb_dim': cfg.model.vq.emb_dim1,
            'n_emb': cfg.model.vq.n_emb1,
            'beta': cfg.model.vq.beta,
            'tiled': cfg.model.vq.tiled,
            'ema_update': cfg.model.vq.ema_update,
            'random_restart': cfg.model.vq.random_restart
        }
        vq_layer = VectorQuantizer(cfg.model.vq.name, **vq_kwargs)

        vq_layer2 = None
        if num_latent_space == 2:
            vq2_kwargs = copy.deepcopy(vq_kwargs)
            vq2_kwargs['emb_dim'] = cfg.model.vq.emb_dim2
            vq2_kwargs['n_emb'] = cfg.model.vq.n_emb2
            vq_layer2 = VectorQuantizer(cfg.model.vq.name, **vq2_kwargs)
    else:
        codebook = VQMulti(
            n_embs=[cfg.model.vq.n_emb1, cfg.model.vq.n_emb2],
            emb_dims=[cfg.model.vq.emb_dim1, cfg.model.vq.emb_dim2],
            betas=[cfg.model.vq.beta] * num_latent_space,
            levels=num_latent_space,
            tiled=cfg.model.vq.tiled,
            ema_update=cfg.model.vq.ema_update,
            random_restart=cfg.model.vq.random_restart
        )
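
    # Shape arithmetic example (hypothetical config values). With
    # cfg.model.discrete_seq_len = [256, 64], the square side lengths used
    # below are ceil(sqrt(256)) = 16 and ceil(sqrt(64)) = 8, and the MH
    # dropout up-sampling ratio works out to int(ceil(16 - 8)) + 1 = 9.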

    #################################
    # 3. MH dropout block
    #################################
    mhd_layer = None
    if use_mhd:
        if num_latent_space == 2:
            mhd_inp_dim = cfg.model.vq.emb_dim2
        else:
            mhd_inp_dim = emb_dim1

        if cfg.model.mhd.bottleneck:
            hidden_dim = cfg.model.mhd.bottleneck_dim
        else:
            hidden_dim = mhd_inp_dim

        # Up-sampling ratio between the two latent grids, derived from the
        # flattened discrete sequence lengths.
        seq_dim = cfg.model.discrete_seq_len
        up_sample_ratio = int(np.ceil(seq_dim[0] ** 0.5 - seq_dim[-1] ** 0.5)) + 1

        mhd_kwargs = {
            'mask_type': cfg.model.mhd.mask_type,
            'num_latent_space': num_latent_space,
            'inp_dim': mhd_inp_dim,
            'hidden_dim': hidden_dim,
            'hypothese_bsz': cfg.model.mhd.hypothese_bsz,
            'out_dim': emb_dim1,
            'dist_reduce': cfg.model.mhd.dist_reduce,
            'loss_reduce': cfg.model.mhd.dist_reduce,
            'loss_reduce_dims': cfg.model.mhd.loss_reduce_dims,
            'dist_loss': cfg.model.mhd.dist_loss,
            'dropout_rate': cfg.model.mhd.dropout_rate,
            'decoder_cfg': cfg.model.mhd.decoder,
            'residual': num_latent_space == 2,
            'up_sample_ratio': up_sample_ratio,
            'debug': cfg.debug,
            'use_mhd_mask': cfg.model.mhd.use_mhd_mask
        }
        mhd_layer = get_mhd_layer(cfg.model.mhd.name, **mhd_kwargs)

    #################################
    # 4. Init model backbone
    #################################
    model_func = get_model_func(num_latent_space, cfg.model.backbone)

    # Convert flattened sequence lengths to square spatial side lengths.
    seq_dim = copy.deepcopy(cfg.model.discrete_seq_len)
    seq_dim[0] = int(np.ceil(seq_dim[0] ** 0.5))
    seq_dim[1] = int(np.ceil(seq_dim[1] ** 0.5))

    model_kwargs = {
        'cfg': cfg.model,
        'num_latent_space': num_latent_space,
        'decoder': decoder,
        'mhd_layer': mhd_layer,
        'use_disc': use_disc,
        'seq_dim': seq_dim,
        'eval_cfg': cfg.eval,
        'encoder1': encoder,
        'encoder2': encoder2,
    }
    if cfg.model.backbone == 'vq2':
        model_kwargs['dec_t'] = dec_t
        model_kwargs['vq_layer1'] = vq_layer
        model_kwargs['vq_layer2'] = vq_layer2
    else:
        model_kwargs['codebook'] = codebook

    model = model_func(**model_kwargs)

    params = get_vae_params(cfg, model)
    lr = cfg.train.optim.learning_rate
    betas = cfg.train.optim.betas
    opt = torch.optim.Adam(params, lr=lr, betas=betas)
    return model, opt


def init_disc_model(cfg):
    in_channels = cfg.data.in_channels

    #################################
    # GAN discriminator + LPIPS loss
    #################################
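    # Prior-stage layout (descriptive note). `init_seq_model` below builds a
    # GPT prior over the base codes ('seq_base') and, when num_latent_space
    # > 1, a second GPT over the conditioning codes ('seq_cond'); for the
    # 'vq2' backbone the sequence shapes and codebook sizes are swapped so
    # that the bottom level is the one modeled by 'seq_base'.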
    disc_kwargs = {
        'disc_start': cfg.model.gan.start_step,
        'disc_in_channels': in_channels,
        'codebook_weight': cfg.model.gan.codebook_weight,
        'pixelloss_weight': cfg.model.gan.pixelloss_weight,
        'disc_num_layers': cfg.model.gan.disc_num_layers,
        'use_actnorm': cfg.model.gan.use_actnorm,
        'disc_loss': cfg.model.gan.disc_loss,
        'disc_ndf': cfg.model.gan.disc_ndf,
        'disc_factor': cfg.model.gan.disc_factor,
        'disc_weight': cfg.model.gan.disc_weight,
        'perceptual_weight': cfg.model.gan.perceptual_weight,
        'disc_conditional': cfg.model.gan.disc_conditional,
    }
    disc_layer = VQLPIPSWithDiscriminator(**disc_kwargs)

    disc_params = get_all_params(disc_layer)
    lr = cfg.train.optim.learning_rate
    betas = cfg.train.optim.betas
    disc_opt = torch.optim.Adam(disc_params, lr=lr, betas=betas)
    return disc_layer, disc_opt


def get_vae_params(cfg, model):
    use_mhd = cfg.model.mhd.use
    num_latent_space = cfg.model.num_latent_space
    module = model.module if hasattr(model, 'module') else model

    params = list(module.decoder.parameters())
    if num_latent_space == 1:
        params += list(module.encoder.parameters()) + list(module.vq_layer.parameters())
        if module.proj is not None:
            params += list(module.proj.parameters())
    elif num_latent_space == 2:
        if cfg.model.backbone == 'vq2':
            params += list(module.enc_b.parameters()) + list(module.vq_layer_b.parameters())
            params += list(module.enc_t.parameters()) + list(module.vq_layer_t.parameters())
            params += list(module.proj_t.parameters()) + list(module.proj_b.parameters())
            params += list(module.dec_t.parameters())
            if module.upsample_t is not None:
                params += list(module.upsample_t.parameters())
        else:
            params += list(module.enc_1.parameters()) + list(module.codebook.parameters())
            params += list(module.enc_2.parameters())
            if module.proj_1 is not None:
                params += list(module.proj_1.parameters()) + list(module.proj_2.parameters())

    if use_mhd:
        params += list(module.mhd_layer.parameters())
    return params


def get_all_params(model):
    return list(filter(lambda p: p.requires_grad, model.parameters()))


def init_seq_model(cfg, device_id):
    n_emb11 = cfg.model.vq.n_emb1
    n_emb12 = cfg.model.vq.n_emb2
    n_emb12 = n_emb11 if n_emb12 < 0 else n_emb12

    if cfg.model.backbone == 'vq2':
        n_emb11, n_emb12 = n_emb12, n_emb11

    # Share a single class count across both priors when the first
    # codebook is the larger one.
    if n_emb11 > n_emb12:
        n_class = max(n_emb12, n_emb11)
        n_emb11, n_emb12 = n_class, n_class

    # Convert flattened sequence lengths to square spatial side lengths.
    seq_dim = copy.deepcopy(cfg.model.discrete_seq_len)
    seq_dim[0] = int(np.ceil(seq_dim[0] ** 0.5))
    seq_dim[1] = int(np.ceil(seq_dim[1] ** 0.5))
    if cfg.model.backbone == 'vq2':
        p2, p1 = seq_dim
    else:
        p1, p2 = seq_dim

    seq_shape1 = [p1, p1]  # DIST / bottom (in vq2)
    seq_shape2 = [p2, p2]  # DIST / or top

    p_cfg = getattr(cfg.model.seq, cfg.model.seq.name.bottom)
    s_cfg = getattr(cfg.model.seq, cfg.model.seq.name.top)

    model = GPTWrapper(
        device_id=device_id,
        shape=seq_shape1,
        n_class=n_emb11,
        block_size=p_cfg.block_size,
        n_layer=p_cfg.n_layer,
        n_head=p_cfg.n_head,
        n_embd=p_cfg.n_embd,
    )
    cond_model = GPTWrapper(
        device_id=device_id,
        shape=seq_shape2,
        n_class=n_emb12,
        block_size=s_cfg.block_size,
        n_layer=s_cfg.n_layer,
        n_head=s_cfg.n_head,
        n_embd=p_cfg.n_embd,  # same embedding size as the base model, due to conditioning
    )

    #############################################
    models = {}
    opts = {}

    params = get_all_params(model)
    lr = cfg.train.optim.learning_rate
    betas = cfg.train.optim.betas
    opt = torch.optim.Adam(params, lr=lr, betas=betas)
    models['seq_base'] = model
    opts['seq_base'] = opt

    if cfg.model.num_latent_space > 1:
        cond_params = get_all_params(cond_model)
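# Factory usage example (hypothetical values; the real kwargs come from
# `get_coder_cfg` and the dsf-specific overrides above):
#
#     enc = get_encoder('conv_bn', in_channel=3, out_channel=128, channel=128,
#                       n_res_block=2, n_res_channel=64, stride=4,
#                       kernels=[4, 3, 3], res_kernels=[3, 1], act='relu')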
        cond_opt = torch.optim.Adam(cond_params, lr=lr, betas=betas)
        models['seq_cond'] = cond_model
        opts['seq_cond'] = cond_opt

    return models, opts


def get_mhd_layer(name, **kwargs):
    # Only one MH dropout variant is currently wired in; `name` is kept
    # for future variants.
    return MHDropoutNetRandom2D(**kwargs)


def get_encoder(model_name, **kwargs):
    if model_name == 'conv':
        return Encoder(**kwargs)
    elif model_name == 'conv_bn':
        return EncoderBN(**kwargs)
    elif model_name == 'conv_diffusion':
        return EncoderDiffusion(**kwargs)
    raise NotImplementedError("Unknown encoder: {}".format(model_name))


def get_decoder(model_name, **kwargs):
    if model_name == 'conv':
        return Decoder(**kwargs)
    elif model_name == 'conv_bn':
        return DecoderBN(**kwargs)
    elif model_name == 'conv_diffusion':
        return DecoderDiffusion(**kwargs)
    raise NotImplementedError("Unknown decoder: {}".format(model_name))


def get_model_func(num_latent_space, backbone='mvq'):
    assert num_latent_space in [1, 2]
    if backbone == 'vq2':
        return VQ2VAE
    elif backbone == 'mvq':
        return MVQAE
    raise NotImplementedError("Unknown backbone: {}".format(backbone))
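

# Convenience helper mirroring the commented-out parameter report in
# `init_model`; added here as a sketch and not called anywhere in this module.
def count_parameters(models):
    for m_key, model in models.items():
        total_params = sum(p.numel() for p in model.parameters())
        print("Loaded Model: {:9} Parameters={}".format(m_key, total_params))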