# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import os
import click
import copy

import warnings
warnings.filterwarnings("ignore")

import torch

from state import State
from dataset import init_dataloader 
from runner import get_runners 
from models import init_model, load_checkpoint 

from precompute_fid import precompute_fid_scores

from utils import (
  load_cfg,
  set_seeds,
  launch_dist_backend
)

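# Model components handled in each training stage: 'vae' covers the
# autoencoder and its discriminator, 'seq' covers the base (unconditional)
# and conditional sequence models.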
STAGE_KEYS = {
  'vae': ['vae','disc'],
  'seq': ['seq_base','seq_cond']
}

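# Example invocations (flag values are illustrative):
#   python main.py -c default.yaml -m train -s vae
#   torchrun --nproc_per_node=4 main.py -c default.yaml -s seq -ds 1
# The torchrun example assumes the launcher exports RANK/LOCAL_RANK, which
# the distributed branch below reads from the environment.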
@click.command()
@click.option('--config_file','-c', default='default.yaml')
@click.option('--mode','-m', default='train')
@click.option('--stage','-s', default='vae')
@click.option('--dataset','-d', default=None)
@click.option('--seed','-se', default=None)
@click.option('--work','-w', default='home')
@click.option('--debug', '-bug', default=0)
@click.option('--local_rank', '-lr', default=None)
@click.option('--dist', '-ds', default=None)
def main(config_file, mode, stage, work, dataset, seed, debug, local_rank, dist):
  assert mode in ['train', 'eval'], f"Unknown mode: {mode}"
  assert stage in STAGE_KEYS, f"Unknown stage: {stage}"

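  # Allow LOCAL_RANK to be passed as a CLI flag; launchers such as torchrun
  # normally export it themselves, in which case this override is unused.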
  if local_rank is not None:
    os.environ["LOCAL_RANK"] = str(local_rank)

  ###########################################
  # Load Configurations 
  cur_dir = os.path.dirname(os.path.realpath(__file__))
  cfg_dir = os.path.join(cur_dir, 'config')

  cfg = load_cfg(config_file, cfg_dir)
  cfg.work_env = work
  cfg.cur_dir = cur_dir
  cfg.cur_stage = stage

  # This allows specification of different work environments
  cfg._exp_dir = getattr(cfg.exp_dir, work)
  cfg.data_path = getattr(cfg.data.path, work)
  
  ##########################################
  # Check experiment and data directory exist
  if cfg._exp_dir is None or cfg._exp_dir == '':
    cfg._exp_dir = os.path.join(cur_dir, 'exp')
    os.makedirs(cfg._exp_dir, exist_ok=True)
    print(f"Experiment directory not specified in config. Default: {cfg._exp_dir}")
  
  if cfg.data_path is None or cfg.data_path == '':
    cfg.data_path = os.path.join(cur_dir, 'data')
    os.makedirs(cfg.data_path, exist_ok=True)
    print(f"Data directory not specified in config.       Default: {cfg.data_path}")
  
  ##########################################
  # Replace config with command line options (if any)

  if dist is not None:
    cfg.torch_dist.use = (dist == '1')

  if seed is not None:
    cfg.seed = int(seed)

  # EXPERIMENT HYPERPARAMETERS
  down_sampling_factor1 = cfg.model.down_sampling_factor1
  down_sampling_factor2 = cfg.model.down_sampling_factor2
  
  # Calculate tokens per instance
  seq_len1 = (cfg.data.in_shape[0] // down_sampling_factor1) ** 2
  seq_len2 = seq_len1 // down_sampling_factor2
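  # Worked example (illustrative numbers): with in_shape[0] == 256,
  # factor1 == 8 and factor2 == 4, seq_len1 = (256 // 8) ** 2 = 1024
  # tokens and seq_len2 = 1024 // 4 = 256 tokens.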

  cfg.model.discrete_seq_len = [seq_len1, seq_len2]

  dsl1, dsl2 = cfg.model.discrete_seq_len

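  # The hypothesis batch size cannot exceed the total hypothesis count,
  # so clamp it down when the config over-specifies it.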
  if cfg.model.mhd.hypothese_count < cfg.model.mhd.hypothese_bsz:
    cfg.model.mhd.hypothese_bsz = cfg.model.mhd.hypothese_count
  
  set_seeds(cfg.seed)

  if debug or cfg.debug:
    print("[Local rank={}] Debug mode activated.".format(local_rank))
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'
    os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1'
    cfg.timeout = 600
    cfg.debug = True
    cfg.num_workers = 0

  if cfg.gpu.use:
    launch_dist_backend(cfg.torch_dist, debug=cfg.debug, timeout=cfg.timeout)

  if dataset is not None:
    cfg.data.name = dataset

  if not cfg.train.resume.is_resume:
    cfg.exp_name = '_'.join([
      cfg.data.name,
      config_file.replace('.yaml', ''),
      str(cfg.seed)
    ])
  else:
    cfg.exp_name = cfg.train.resume.exp_name
  

  ###########################################
  # Create separate ckpt and sample directories for each seq model (GPT/Pixel)
  cfg.stage_keys = copy.deepcopy(STAGE_KEYS)
  cfg.stage_key_dir = copy.deepcopy(STAGE_KEYS)
  cfg.seq_model_key = '-'.join([cfg.model.seq.name.bottom, cfg.model.seq.name.top])

  cfg.stage_key_dir['seq'][0] = cfg.seq_model_key + '-' + cfg.stage_key_dir['seq'][0] 
  cfg.stage_key_dir['seq'][1] = cfg.seq_model_key + '-' + cfg.stage_key_dir['seq'][1] 
  
  ###########################################
  # Setup directories

  cfg.fid_dir = os.path.join(cfg.data_path, 'fid-stats')
  os.makedirs(cfg.fid_dir, exist_ok=True)

  cfg.save_dir = os.path.join(cfg._exp_dir, cfg.exp_name)
  os.makedirs(cfg.save_dir, exist_ok=True)

  cfg.ckpt_dir = os.path.join(cfg.save_dir, 'checkpoints')
  os.makedirs(cfg.ckpt_dir, exist_ok=True)
  cfg.ckpt_dirs = {}

  cfg.sample_dir = os.path.join(cfg.save_dir, 'samples')
  os.makedirs(cfg.sample_dir, exist_ok=True)
  cfg.sample_dirs = {}

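  # Fan out per-stage subdirectories. Every stage key gets a checkpoint
  # directory; sample directories are only created for the models that
  # generate samples (the VAE and the unconditional 'seq_base' model).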
  for s in [t for v in cfg.stage_key_dir.values() for t in v]:
    _s = s.replace(cfg.seq_model_key + '-', '')
    stage_ckpt_dir = os.path.join(cfg.ckpt_dir, s)
    os.makedirs(stage_ckpt_dir, exist_ok=True)
    cfg.ckpt_dirs[_s] = stage_ckpt_dir

    if 'vae' in s or 'base' in s:
      stage_sample_dir = os.path.join(cfg.sample_dir, s)
      os.makedirs(stage_sample_dir, exist_ok=True)
      cfg.sample_dirs[_s] = stage_sample_dir

  ###########################################

  # Setup GPU assignments to config file
  n_gpus = torch.cuda.device_count()
  dist_avail = torch.distributed.is_available()
  if not dist_avail:
    raise RuntimeError("Torch Distributed Package Unavailable")

  if n_gpus == 0 or not cfg.gpu.use:
    cfg.gpu.use = False
    cfg.torch_dist.use = False
    
    cfg.rank = 'cpu'
    cfg.device_id = 'cpu'

  elif cfg.torch_dist.use:
    cfg.rank = int(os.environ["RANK"])
    cfg.device_id = int(os.environ["LOCAL_RANK"])
    cfg.device_ids = [cfg.device_id]

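    # With two GPUs per model replica, reserve the neighbouring device as
    # well; more than two per model is not supported.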
    if cfg.torch_dist.gpus_per_model == 2:
      cfg.device_ids.append(cfg.device_id + 1)
    elif cfg.torch_dist.gpus_per_model > 2:
      raise NotImplementedError("Only 1 or 2 GPUs per model are supported")

    torch.cuda.set_device(cfg.device_id)
  else:
    cfg.rank = 0
    cfg.device_id = 0
    torch.cuda.set_device(cfg.device_id)

  # Print experiment details to user
  if cfg.rank in ['cpu', 0]:
    print("Torch Distributed     : {}".format(cfg.torch_dist.use))
    print(f"Experiment Dir        : {cfg.save_dir}")
    print(f"Dataset               : {cfg.data.name}")
    print(f"Config                : {config_file}")
    print("--downSampleFactor (F): {},{}".format(cfg.model.down_sampling_factor1, cfg.model.down_sampling_factor2))
    print("--tokensPerInstance   : {},{}".format(dsl1, dsl2))
    print("--codebookSize (K)    : {},{}".format(cfg.model.vq.n_emb1, cfg.model.vq.n_emb2))
    print(f"--EMA                 : {cfg.model.vq.ema_update}")
    print(f"--randomRestarts      : {cfg.model.vq.random_restart}")
    print(f"--MH-Dropout          : {cfg.model.mhd.use_mhd_mask}")

    if cfg.model.mhd.use:
      print(f"--hypoCount    (J)    : {cfg.model.mhd.hypothese_count}")
      print(f"--hypoBatchSz         : {cfg.model.mhd.hypothese_bsz}")


  cfg.batch_size = getattr(cfg.batch_size_per_gpu, cfg.cur_stage)

  ### Check for FID file. If not exist then precompute
  fid_file = os.path.join(cfg.fid_dir, cfg.data.name + '-train.npz')

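  # Only rank 0 (or a CPU run) computes the statistics; under
  # torch.distributed the barrier below holds the other ranks until the
  # .npz file has been written.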
  if not os.path.exists(fid_file):
    if cfg.rank in ['cpu', 0]:
      print(f"FID file does not exist: {fid_file}\nStarting precompute of FID statistics.")
      precompute_fid_scores(
        data_dir=cfg.data_path, 
        dataset=cfg.data.name, 
        batch_size=cfg.batch_size_per_gpu.vae,
        device=f'cuda:{cfg.device_id}' if cfg.rank != 'cpu' else 'cpu'
        )

    if cfg.torch_dist.use:
      # `dist` is the CLI flag in this scope, so call the package explicitly
      torch.distributed.barrier()

  
  # Initialise model, optimisers and dataset
  models, opts = init_model(cfg, stage)

  train_loader, val_loader = init_dataloader(
    cfg.data.name, cfg.data_path, cfg.batch_size,
    cfg.num_workers, cfg.gpu.use, cfg.torch_dist.use
  )

  state = State(models, opts=opts, rank=cfg.rank, stage=stage, stage_keys=cfg.stage_keys, mode=mode)
  state = load_checkpoint(
    state, 
    cfg.ckpt_dirs, 
    cfg.device_id, 
    cfg.rank, 
    cfg.torch_dist.use, 
    mode=mode)

  runner = get_runners(cfg, stage)
  runner.global_step = state.global_step
  runner.metrics = state.metrics

  eval_freq = cfg.train.intervals.eval
  gen_freq = cfg.train.intervals.gen
  hypothese_count = cfg.model.mhd.hypothese_count

  start_epoch = state.epoch + 1

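  # Epoch numbering is continuous across stages: the seq stage starts
  # counting where the vae stage ended, so its budget is added on top of
  # cfg.train.epochs.vae.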
  if mode == 'eval':
    total_epochs = start_epoch + 1
  elif stage == 'vae':
    total_epochs = cfg.train.epochs.vae 
  elif stage == 'seq':
    if start_epoch < cfg.train.epochs.vae:
      start_epoch = cfg.train.epochs.vae
    total_epochs = cfg.train.epochs.seq + cfg.train.epochs.vae
  else:
    raise ValueError(f"Invalid stage {stage}")

  runner.logger.info(f"=> stage: {stage}, mode: {mode}, start_epoch: {start_epoch}, total_epochs: {total_epochs}, best_score: {state.best_score}")

  # Start training / evaluation
  for epoch in range(start_epoch, total_epochs):
    is_best = False
    runner.epoch = epoch
    runner.best_score = state.best_score
    state.epoch = runner.epoch
    state.global_step = runner.global_step
    state.metrics = runner.metrics

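    # DistributedSampler reuses the same shuffle unless told the epoch;
    # set_epoch re-seeds it so each epoch sees a different ordering.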
    if cfg.torch_dist.use:
      train_loader.batch_sampler.sampler.set_epoch(epoch)

    if mode == 'train':
      runner.train(train_loader, models, opts, hypothese_count=hypothese_count)

    if mode == 'eval' or (epoch % eval_freq == 0 and epoch > cfg.train.epochs.warmup):
      gen_samples = mode == 'eval' or (epoch % gen_freq == 0)

      best_score = runner.eval(
        val_loader, models, hypothese_count=hypothese_count, epoch=epoch, gen_samples=gen_samples)

      if stage not in state.best_score:
        state.best_score[stage] = float('inf')

      is_best = best_score < state.best_score[stage]
      state.best_score[stage] = min(best_score, state.best_score[stage])

      if is_best:
        runner.logger.info("Score update | Score: {:.4f}  Best Score: {:.4f} | is_best: {}".format(best_score, state.best_score[stage], is_best))

    if mode == 'train' and (cfg.rank == 0) and epoch > 0:
      # Save checkpoints
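      # round(epoch, -2) buckets tags to the nearest hundred epochs, so
      # checkpoints within the same window overwrite each other.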
      if is_best:
        tag = 'best_{}'.format(round(epoch, -2))
        state.save(ckpt_dirs=cfg.ckpt_dirs, tag=tag, save_latest=False)
      
      tag = 'last_{}'.format(round(epoch, -2))
      state.save(ckpt_dirs=cfg.ckpt_dirs, tag=tag, save_latest=True)

if __name__ == "__main__":
  main()