# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import os
import yaml
import click
import time

from state import State
from dataset import init_dataloader 
from runner import get_runners 
from models import init_model, load_checkpoint 
from fid import init_fid_model

from utils import (
  load_cfg,
  set_seeds,
  launch_dist_backend,
  start_debug_mode,
  setup_gpu_cfg, 
)

STAGES = ['continuous', 'seq']
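
# Example invocation (hypothetical flags; assumes config/default.yaml exists):
#   python main.py -c default.yaml -m train -s continuous -w home -d 1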

@click.command()
@click.option('--config_file', '-c', default='default.yaml', help='Configuration file in the config folder')
@click.option('--mode', '-m', default='train', help='Runner mode. Select from ["train","eval","generate"]')
@click.option('--stage', '-s', default='continuous', help='Training stage. Select from ["continuous","seq"]')
@click.option('--work', '-w', default='home', help='Work environment; selects the exp_dir and data paths from the config')
@click.option('--distributed', '-d', default=None, help='Override cfg.torch_dist.use: "1" enables the PyTorch distributed backend, anything else disables it')
@click.option('--debug', '-bug', default=0, help='Activates debug mode when nonzero')
@click.option('--local_rank', '-lr', default=None, help='Local process rank for distributed training (exported as LOCAL_RANK)')
def main(config_file, mode, stage, work, debug, distributed, local_rank):

  assert stage in STAGES
  assert mode in ['train','eval','generate'], mode
  
  if local_rank is not None:
    # torch.distributed launchers read the local rank from the LOCAL_RANK
    # environment variable; exporting it here lets the CLI flag stand in.
    os.environ["LOCAL_RANK"] = str(local_rank)

  ###########################################
  # Load Configurations 
  cur_dir = os.path.dirname(os.path.realpath(__file__))
  cfg_dir = os.path.join(cur_dir, 'config')

  cfg = load_cfg(config_file, cfg_dir)
  set_seeds(cfg.seed)
  cfg.work_env = work
  cfg.cur_dir = cur_dir
  cfg.cur_stage = stage
  cfg.batch_size = getattr(cfg.batch_size_per_gpu, stage)

  # This allows specification of different work environments
  cfg._exp_dir = getattr(cfg.exp_dir, work)
  cfg.data_path = getattr(cfg.data.path, work)
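
  # A minimal sketch of the YAML layout these lookups assume (field names
  # inferred from the getattr calls above; the real config may differ):
  #   batch_size_per_gpu: {continuous: ..., seq: ...}
  #   exp_dir:            {home: ..., ...}
  #   data:
  #     path:             {home: ..., ...}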

  # Automatically disable FID evaluation for opt_mode==1 and enable it for opt_mode==2
  opt_mode = int(cfg.model.seq.opt_mode)
  if mode == 'train' and stage == 'seq':
    if opt_mode == 1:
      cfg.eval.fid = False
    elif opt_mode == 2:
      cfg.eval.fid = True
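
  # opt_mode summary (inferred from the task selection and scoring logic below):
  #   0 -> all tasks, lower-is-better validation score
  #   1 -> text tasks only (categorical/series_name/axis/caption), higher-is-better score, FID off
  #   2 -> 'data' task only, FID on, lower-is-better score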

  ##########################################
  # Check experiment and data directory exist
  if cfg._exp_dir in [None,'']:
    cfg._exp_dir = os.path.join(cur_dir, 'exp')
    os.makedirs(cfg._exp_dir, exist_ok=True)
    print(f"Experiment directory not specified in config. Default: {cfg._exp_dir}")
  
  if cfg.data_path in [None,'']:
    cfg.data_path = os.path.join(cur_dir, 'data')
    os.makedirs(cfg.data_path, exist_ok=True)
    print(f"Data directory not specified in config.       Default: {cfg.data_path}")
  
  # Ensure the current stage appears in the config's active model list.
  if stage not in cfg.model.active:
    cfg.model.active += [stage]

  ##########################################
  # Override config with command-line options (if any)

  if distributed is not None:
    cfg.torch_dist.use = (distributed == '1')

  if debug or cfg.debug:
    cfg = start_debug_mode(cfg)

  if cfg.use_gpu:
    launch_dist_backend(cfg.torch_dist, debug=cfg.debug, timeout=cfg.timeout)

  if cfg.exp_name is None:
    cfg.exp_name = '_'.join([config_file.replace('.yaml',''), str(cfg.seed)])
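    # e.g. config_file='default.yaml' with seed 42 yields exp_name 'default_42'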
  
  ###########################################
  # Setup directories
  cfg.save_dir = os.path.join(cfg._exp_dir, cfg.exp_name)
  os.makedirs(cfg.save_dir, exist_ok=True)

  # Save a timestamped copy of the config into the experiment directory
  cfg_fn = os.path.join(cfg.save_dir, f'config_{int(time.time())}.yaml')
  with open(cfg_fn, 'w') as file:
    yaml.dump(cfg, file)

  cfg.ckpt_dirs = {}
  cfg.sample_dirs = {}

  # Create data/experiment subdirectories and record their paths on cfg
  for dir_name, cfg_attr, cfg_base in [
      ('fid_stats', 'fid_dir', 'data_path'),
      ('cache', 'cache_dir', 'data_path'),
      ('tensorboard', 'tb_dir', '_exp_dir'),
      ('checkpoints', 'ckpt_dir', 'save_dir'),
      ('samples', 'sample_dir', 'save_dir'),
      ]:
    new_path = os.path.join(getattr(cfg, cfg_base), dir_name)
    os.makedirs(new_path, exist_ok=True)
    setattr(cfg, cfg_attr, new_path)
  os.environ['TRANSFORMERS_CACHE'] = cfg.cache_dir

  for s in STAGES:
    stage_ckpt_dir = os.path.join(cfg.ckpt_dir, s)
    os.makedirs(stage_ckpt_dir, exist_ok=True)
    cfg.ckpt_dirs[s] = stage_ckpt_dir

    stage_sample_dir = os.path.join(cfg.sample_dir, s)
    os.makedirs(stage_sample_dir, exist_ok=True)
    cfg.sample_dirs[s] = stage_sample_dir

  gen_sample_dir = os.path.join(cfg.sample_dir, 'generate')
  os.makedirs(gen_sample_dir, exist_ok=True)

  cfg.sample_dirs['generate'] = {'base': gen_sample_dir}
  
  for d in ['json','mpl']:
    new_dir = os.path.join(gen_sample_dir, d)
    os.makedirs(new_dir, exist_ok=True)
    cfg.sample_dirs['generate'][d] = new_dir
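
  # Resulting directory layout (sketch):
  #   <data_path>/{fid_stats,cache}
  #   <exp_dir>/tensorboard
  #   <save_dir>/checkpoints/{continuous,seq}
  #   <save_dir>/samples/{continuous,seq,generate/{json,mpl}}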

  ###########################################
  cfg = setup_gpu_cfg(cfg)

  # Initialise models, tokenizers, optimisers and schedulers, then the dataloaders
  models, tokenizers, opts, schs = init_model(
    cfg, mode, stage, cfg.device_id)

  train_loader, val_loader = init_dataloader(cfg, mode, stage, models, tokenizers)

  state = State(models, tokenizers, opts, schs, rank=cfg.rank, mode=mode, stage=stage)
  
  state = load_checkpoint(
    state, 
    cfg.ckpt_dirs, 
    cfg.device_id, cfg.rank, 
    cfg.torch_dist.use)
  
  # Initialise the pre-trained FID model and its reference statistics (if enabled)
  fid_stats = None
  if cfg.eval.fid:
    models['fid'], fid_stats = init_fid_model(cfg, load_path=cfg.fid_dir, device_id=cfg.device_id)
  
  runner = get_runners(cfg, stage, mode)
  runner.global_step = state.global_step
  runner.metrics = state.metrics
  runner.fid_stats = fid_stats

  eval_freq = cfg.train.intervals.eval
  start_epoch = state.epoch + 1

  if mode in ['eval', 'generate']:
    total_epochs = start_epoch + 1  # run the loop body exactly once
  elif stage == 'continuous':
    total_epochs = int(getattr(cfg.train.epochs, stage))
  elif stage == 'seq':
    # The seq stage resumes once the continuous epoch budget is exhausted.
    if start_epoch < cfg.train.epochs.continuous:
      start_epoch = cfg.train.epochs.continuous
    total_epochs = cfg.train.epochs.seq + cfg.train.epochs.continuous
  else:  # unreachable: the stage assert above restricts stage to STAGES
    total_epochs = cfg.train.epochs.total
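  # e.g. with epochs.continuous=100 and epochs.seq=50, the seq stage starts
  # no earlier than epoch 100 and runs through epoch 149.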
  
  if mode == 'train' and stage == 'seq':
    if opt_mode == 0:
      pass
    elif opt_mode == 1:
      tasks = ['categorical','series_name','axis','caption']
      train_loader.dataset.set_tasks(tasks)
      val_loader.dataset.set_tasks(tasks)
    elif opt_mode == 2:
      train_loader.dataset.set_tasks('data')
      val_loader.dataset.set_tasks('data')
    else:
      raise ValueError(f'Unknown opt_mode: {opt_mode}')

  score_stage = f'{stage}_{opt_mode}' if stage == 'seq' else stage
  if score_stage not in state.best_score and stage == 'seq':
    # opt_mode 1 tracks a higher-is-better score, so it starts at 0.0;
    # other seq opt_modes track a lower-is-better score, so they start at +inf.
    state.best_score[score_stage] = 0.0 if score_stage == 'seq_1' else float('inf')

  if cfg.rank in ['cpu', 0]:
    runner.logger.info("GPU                   : {}".format(cfg.use_gpu))
    runner.logger.info("Torch Distributed     : {}".format(cfg.torch_dist.use))
    runner.logger.info("Stage                 : {}".format(stage))
    runner.logger.info("Mode                  : {}".format(mode))
    runner.logger.info("Experiment Dir        : {}".format(cfg.save_dir))
    runner.logger.info("Dataset               : {}".format(cfg.data.name))
    runner.logger.info("FID Test              : {}".format(cfg.eval.fid))
    runner.logger.info("Config                : {}".format(config_file))
    runner.logger.info("--gan                 : {}".format(cfg.model.continuous_data.disc.use))
    runner.logger.info("--mhd                 : {}".format(cfg.model.continuous_data.mhd.use))
    runner.logger.info(f"start_epoch: {start_epoch}, total_epochs: {total_epochs}, best_score[{score_stage}]: {state.best_score.get(score_stage)}")
    runner.logger.info("Active components >>")
    runner.logger.info("model      : {}".format([name for name, m in models.items() if m is not None]))
    runner.logger.info("opt        : {}".format([name for name, m in opts.items() if m is not None]))
    runner.logger.info("tokenizers : {}".format([name for name, m in tokenizers.items() if m is not None]))

    if hasattr(train_loader.dataset, 'tasks'):
      runner.logger.info("tasks      : {}".format(train_loader.dataset.tasks))

  for epoch in range(start_epoch, total_epochs):
    is_best = False
    best_score = None
    runner.epoch = epoch
    state.epoch = runner.epoch
    state.global_step = runner.global_step
    state.metrics = runner.metrics

    if cfg.torch_dist.use:
      train_loader.batch_sampler.sampler.set_epoch(epoch)

    if mode == 'generate':
      runner.generate(val_loader, models, tokenizers)
    elif mode == 'train':
      runner.train(train_loader, models, tokenizers, opts, schs)
      
    if mode == 'eval' or (epoch % eval_freq == 0 and epoch > cfg.train.epochs.warmup):
      results = runner.eval(
        val_loader, models, tokenizers, 
        metric_key_prefix='eval', epoch=epoch, 
        create_sample=(mode == 'eval'))

      best_score = results['score']
      if stage == 'seq' and opt_mode == 1:
        # opt_mode 1 tracks a higher-is-better score
        is_best = best_score > state.best_score[score_stage]
        state.best_score[score_stage] = max(best_score, state.best_score[score_stage])
      else:
        # all other stages/modes track a lower-is-better score
        is_best = best_score < state.best_score[score_stage]
        state.best_score[score_stage] = min(best_score, state.best_score[score_stage])

      if best_score is not None:
        runner.logger.info("Score update | best: {:.4f} last: {:.4f} [{}]".format(state.best_score[score_stage], best_score, is_best))

    if mode == 'train' and cfg.rank in ['cpu', 0]:
      # round(epoch, -1) buckets the tag to the nearest multiple of 10, so
      # checkpoints within a 10-epoch window overwrite one another.
      if is_best:
        tag = 'best_{}'.format(round(epoch, -1))
        state.save(ckpt_dirs=cfg.ckpt_dirs, tag=tag, save_latest=False)

      tag = 'last_{}'.format(round(epoch, -1))
      state.save(ckpt_dirs=cfg.ckpt_dirs, tag=tag, save_latest=True)
  

if __name__ == "__main__":
  main()