# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import os
import click
import copy
import warnings
warnings.filterwarnings("ignore")

import torch

from state import State
from dataset import init_dataloader
from runner import get_runners
from models import init_model, load_checkpoint
from precompute_fid import precompute_fid_scores
from utils import (
    load_cfg,
    set_seeds,
    launch_dist_backend
)

STAGE_KEYS = {
    'vae': ['vae', 'disc'],
    'seq': ['seq_base', 'seq_cond']
}


@click.command()
@click.option('--config_file', '-c', default='default.yaml')
@click.option('--mode', '-m', default='train')
@click.option('--stage', '-s', default='vae')
@click.option('--dataset', '-d', default=None)
@click.option('--seed', '-se', default=None)
@click.option('--work', '-w', default='home')
@click.option('--debug', '-bug', default=0)
@click.option('--local_rank', '-lr', default=None)
@click.option('--dist', '-ds', default=None)
def main(config_file, mode, stage, work, dataset, seed, debug, local_rank, dist):
    assert mode in ['train', 'eval']
    assert stage in STAGE_KEYS.keys()

    if local_rank is not None:
        os.environ["LOCAL_RANK"] = str(local_rank)

    ###########################################
    # Load configurations
    cur_dir = os.path.dirname(os.path.realpath(__file__))
    cfg_dir = os.path.join(cur_dir, 'config')
    cfg = load_cfg(config_file, cfg_dir)
    cfg.work_env = work
    cfg.cur_dir = cur_dir
    cfg.cur_stage = stage

    # This allows specification of different work environments
    cfg._exp_dir = getattr(cfg.exp_dir, work)
    cfg.data_path = getattr(cfg.data.path, work)

    ##########################################
    # Check that the experiment and data directories exist
    if cfg._exp_dir is None or cfg._exp_dir == '':
        cfg._exp_dir = os.path.join(cur_dir, 'exp')
        os.makedirs(cfg._exp_dir, exist_ok=True)
        print(f"Experiment directory not specified in config. Default: {cfg._exp_dir}")

    if cfg.data_path is None or cfg.data_path == '':
        cfg.data_path = os.path.join(cur_dir, 'data')
        os.makedirs(cfg.data_path, exist_ok=True)
        print(f"Data directory not specified in config. Default: {cfg.data_path}")
Default: {cfg.data_path}") ########################################## # Replace config with command line options (if any) if dist is not None: cfg.torch_dist.use = True if dist == '1' else False if seed is not None: cfg.seed = int(seed) # EXPERIMENT HYPERPARAMETERS down_sampling_factor1 = cfg.model.down_sampling_factor1 down_sampling_factor2 = cfg.model.down_sampling_factor2 #Calculate tokens per instance seq_len1 = (cfg.data.in_shape[0] // down_sampling_factor1) ** 2 seq_len2 = seq_len1 // down_sampling_factor2 cfg.model.discrete_seq_len = [seq_len1, seq_len2] dsl1, dsl2 = cfg.model.discrete_seq_len if cfg.model.mhd.hypothese_count < cfg.model.mhd.hypothese_bsz: cfg.model.mhd.hypothese_bsz = cfg.model.mhd.hypothese_count set_seeds(cfg.seed) if debug or cfg.debug: print("[Local rank={}] Debug mode activated.".format(local_rank)) os.environ['CUDA_LAUNCH_BLOCKING'] = '1' os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL' os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1' cfg.timeout = 600 cfg.debug = True cfg.num_workers = 0 if cfg.gpu.use: launch_dist_backend(cfg.torch_dist, debug=cfg.debug, timeout=cfg.timeout) if dataset is not None: cfg.data.name = dataset if not cfg.train.resume.is_resume: cfg.exp_name = '_'.join([ cfg.data.name, config_file.replace('.yaml',''), str(cfg.seed)] ) else: cfg.exp_name = cfg.train.resume.exp_name ########################################### # Create seperate ckpt and sample directories for each seq model (GPT/Pixel) cfg.stage_keys = copy.deepcopy(STAGE_KEYS) cfg.stage_key_dir = copy.deepcopy(STAGE_KEYS) cfg.seq_model_key = '-'.join([cfg.model.seq.name.bottom, cfg.model.seq.name.top]) cfg.stage_key_dir['seq'][0] = cfg.seq_model_key + '-' + cfg.stage_key_dir['seq'][0] cfg.stage_key_dir['seq'][1] = cfg.seq_model_key + '-' + cfg.stage_key_dir['seq'][1] ########################################### # Setup directories cfg.fid_dir = os.path.join(cfg.data_path, 'fid-stats') os.makedirs(cfg.fid_dir, exist_ok=True) cfg.save_dir = os.path.join(cfg._exp_dir, cfg.exp_name) os.makedirs(cfg.save_dir, exist_ok=True) cfg.ckpt_dir = os.path.join(cfg.save_dir, 'checkpoints') os.makedirs(cfg.ckpt_dir, exist_ok=True) cfg.ckpt_dirs = {} cfg.sample_dir = os.path.join(cfg.save_dir, 'samples') os.makedirs(cfg.sample_dir, exist_ok=True) cfg.sample_dirs = {} for s in [t for v in list(cfg.stage_key_dir.values()) for t in v]: _s = s.replace(cfg.seq_model_key + '-', '') stage_ckpt_dir = os.path.join(cfg.ckpt_dir, s) os.makedirs(stage_ckpt_dir, exist_ok=True) cfg.ckpt_dirs[_s] = stage_ckpt_dir if 'vae' in s or 'base' in s: stage_sample_dir = os.path.join(cfg.sample_dir, s) os.makedirs(stage_sample_dir, exist_ok=True) cfg.sample_dirs[_s] = stage_sample_dir ########################################### # Setup GPU assignments to config file n_gpus = torch.cuda.device_count() dist_avail = torch.distributed.is_available() if not dist_avail: raise SystemError("Torch Distributed Package Unavailable") if n_gpus == 0 or not cfg.gpu.use: cfg.gpu.use = False cfg.torch_dist.use = False cfg.rank = 'cpu' cfg.device_id = 'cpu' elif cfg.torch_dist.use: cfg.rank = int(os.environ["RANK"]) cfg.device_id = int(os.environ["LOCAL_RANK"]) cfg.device_ids = [cfg.device_id] if cfg.torch_dist.gpus_per_model <= 2: cfg.device_ids.append(cfg.device_id + 1) elif cfg.torch_dist.gpus_per_model > 2: raise NotImplementedError() torch.cuda.set_device(cfg.device_id) else: cfg.rank = 0 cfg.device_id = 0 torch.cuda.set_device(cfg.device_id) #Print experiment details to user if cfg.rank in ['cpu', 0]: print("Torch Distributed : 
{}".format(cfg.torch_dist.use)) if cfg.rank in ['cpu', 0]: print(f"Experiment Dir : {cfg.save_dir}") print(f"Dataset : {cfg.data.name}") print(f"Config : {config_file}") print("--downSampleFactor (F): {},{}".format(cfg.model.down_sampling_factor1, cfg.model.down_sampling_factor2)) print("--tokensPerInstance : {},{}".format(dsl1, dsl2)) print("--codebookSize (K) : {},{}".format(cfg.model.vq.n_emb1, cfg.model.vq.n_emb2)) print(f"--EMA : {cfg.model.vq.ema_update}") print(f"--randomRestarts : {cfg.model.vq.random_restart}") print(f"--MH-Dropout : {cfg.model.mhd.use_mhd_mask}") if cfg.model.mhd.use: print(f"--hypoCount (J) : {cfg.model.mhd.hypothese_count}") print(f"--hypoBatchSz : {cfg.model.mhd.hypothese_bsz}") cfg.batch_size = getattr(cfg.batch_size_per_gpu, cfg.cur_stage) ### Check for FID file. If not exist then precompute fid_file = os.path.join(cfg.fid_dir, cfg.data.name + '-train.npz') if not os.path.exists(fid_file): if cfg.rank in ['cpu', 0]: print(f"FID file does not exist: {fid_file}\nStarting precompute of FID statistics.") precompute_fid_scores( data_dir=cfg.data_path, dataset=cfg.data.name, batch_size=cfg.batch_size_per_gpu.vae, device=f'cuda:{cfg.device_id}' if cfg.rank != 'cpu' else 'cpu' ) if cfg.torch_dist.use: dist.barrier() #Initialise model, optimisers and dataset models, opts = init_model(cfg, stage) train_loader, val_loader = init_dataloader( cfg.data.name, cfg.data_path, cfg.batch_size, cfg.num_workers, cfg.gpu.use, cfg.torch_dist.use ) state = State(models, opts=opts, rank=cfg.rank, stage=stage, stage_keys=cfg.stage_keys, mode=mode) state = load_checkpoint( state, cfg.ckpt_dirs, cfg.device_id, cfg.rank, cfg.torch_dist.use, mode=mode) runner = get_runners(cfg, stage) runner.global_step = state.global_step runner.metrics = state.metrics eval_freq = cfg.train.intervals.eval gen_freq = cfg.train.intervals.gen hypothese_count = cfg.model.mhd.hypothese_count start_epoch = state.epoch + 1 if mode in ['eval', 'val']: total_epochs = start_epoch + 1 elif stage == 'vae': total_epochs = cfg.train.epochs.vae elif stage == 'seq': if start_epoch < cfg.train.epochs.vae: start_epoch = cfg.train.epochs.vae total_epochs = cfg.train.epochs.seq + cfg.train.epochs.vae else: raise ValueError(f"Invalid stage {stage}") runner.logger.info(f"=> stage: {stage}, mode: {mode}, start_epoch: {start_epoch}, total_epochs: {total_epochs}, best_score: {state.best_score}") # Start training / evaluation for epoch in range(start_epoch, total_epochs): is_best = False runner.epoch = epoch runner.best_score = state.best_score state.epoch = runner.epoch state.global_step = runner.global_step state.metrics = runner.metrics if cfg.torch_dist.use: train_loader.batch_sampler.sampler.set_epoch(epoch) if mode == 'train': runner.train(train_loader, models, opts, hypothese_count=hypothese_count) if mode == 'eval' or (epoch % eval_freq == 0 and epoch > cfg.train.epochs.warmup): gen_samples = mode == 'eval' or (epoch % gen_freq == 0) best_score = runner.eval( val_loader, models, hypothese_count=hypothese_count, epoch=epoch, gen_samples=gen_samples) if stage not in state.best_score: state.best_score[stage] = float('inf') is_best = best_score < state.best_score[stage] state.best_score[stage] = min(best_score, state.best_score[stage]) if is_best: runner.logger.info("Score update | Score: {:.4f} Best Score: {:.4f} | is_best: {}".format(best_score, state.best_score[stage], is_best)) if mode == 'train' and (cfg.rank == 0) and epoch > 0: #Save checkpoints if is_best: tag = 'best_{}'.format(round(epoch, -2)) 
                state.save(ckpt_dirs=cfg.ckpt_dirs, tag=tag, save_latest=False)
            tag = 'last_{}'.format(round(epoch, -2))
            state.save(ckpt_dirs=cfg.ckpt_dirs, tag=tag, save_latest=True)


if __name__ == "__main__":
    main()
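# ---------------------------------------------------------------
# Example invocations (a sketch for illustration only; the entry-point
# filename `main.py`, the GPU count, and the placeholder values are
# assumptions, not part of the original repository):
#
#   # Single-process VAE-stage training with the default config
#   python main.py -c default.yaml -m train -s vae
#
#   # Seq-stage training with torch.distributed enabled; torchrun
#   # supplies the RANK / LOCAL_RANK environment variables read above
#   torchrun --nproc_per_node=4 main.py -c default.yaml -m train -s seq --dist 1
#
#   # Evaluation of the seq stage with a specific dataset and seed
#   python main.py -c default.yaml -m eval -s seq -d <dataset-name> -se 42
# ---------------------------------------------------------------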