# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Forked from:
# https://github.com/pytorch/elastic/blob/master/examples/imagenet/main.py
# ---------------------------------------------------------------
import os
import glob
import random
import traceback

import numpy as np
import torch


class State:
    """Bundles models, optimizers, schedulers, and grad scalers per training
    stage, and handles checkpoint capture/restore across stages."""

    def __init__(self, models, stage_keys, opts=None, schs=None, rank=0,
                 mode='train', stage='vae'):
        self.epoch = -1
        self.global_step = 0
        self.rank = rank
        self.best_score = {'vae': float('inf'), 'seq': float('inf')}
        self.models = models
        self.opts = opts if opts is not None else {}
        self.schs = schs if schs is not None else {}
        self.scaler = {}
        # Optional per-stage tokenizers; populated externally. Initialized
        # here so _apply_snapshot can reference it without AttributeError.
        self.tokenizers = {}
        self.metrics = []
        self.stage_keys = stage_keys
        self.all_stage_keys = [t for v in stage_keys.values() for t in v]
        self.mode = mode
        self.cur_stage = stage

    def _capture_snapshot(self, stage):
        # Collect whichever state dicts exist for this stage; missing
        # components (e.g. no scheduler) are simply skipped.
        snap = {}
        if self.models.get(stage) is not None:
            snap['{}_model'.format(stage)] = self.models[stage].state_dict()
        if self.opts.get(stage) is not None:
            snap['{}_opt'.format(stage)] = self.opts[stage].state_dict()
        if self.schs.get(stage) is not None:
            snap['{}_schs'.format(stage)] = self.schs[stage].state_dict()
        if self.scaler.get(stage) is not None:
            snap['{}_scaler'.format(stage)] = self.scaler[stage].state_dict()
        return snap

    def capture_snapshot(self, stage=None):
        stages = self.all_stage_keys if stage is None else [stage]
        snapshot = {
            'epoch': self.epoch,
            'global_step': self.global_step,
            'best_score': self.best_score,
            'metrics': self.metrics,
            'rng': self.get_rng(),
        }
        for stage in stages:
            snapshot.update(self._capture_snapshot(stage))
        return snapshot
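
    # Illustrative snapshot layout (keys follow the '{stage}_{part}' naming
    # convention used by _capture_snapshot; the values shown are placeholders):
    #
    #   {'epoch': 10, 'global_step': 5000,
    #    'best_score': {'vae': 0.12, 'seq': 1.34}, 'metrics': [...],
    #    'rng': {...},
    #    'vae_model': OrderedDict(...), 'vae_opt': {...},
    #    'seq_model': OrderedDict(...), ...}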

    def _apply_snapshot(self, obj, stage):
        obj_name = '{}_model'.format(stage)
        if obj.get(obj_name) is not None and self.models.get(stage) is not None:
            # DDP checkpoints prefix parameter names with 'module.'; convert
            # the state dict so it matches how the live model is wrapped.
            ddp_ckpt = next(iter(obj[obj_name])).startswith('module.')
            ddp_model = self.models[stage].__class__.__name__ == 'DistributedDataParallel'
            if ddp_ckpt and not ddp_model:
                obj[obj_name] = self.from_ddp_ckpt(obj[obj_name])
            elif not ddp_ckpt and ddp_model:
                obj[obj_name] = self.to_ddp_ckpt(obj[obj_name])
            self.models[stage].load_state_dict(obj[obj_name])

        obj_name = '{}_tokenizer'.format(stage)
        if obj.get(obj_name) is not None and self.tokenizers.get(stage) is not None:
            try:
                self.tokenizers[stage].load_state_dict(obj[obj_name])
            except Exception:
                print("Failed loading tokenizer for {}".format(stage))

        # Optimizer, scheduler, and grad-scaler state are only restored for
        # the stage currently being trained.
        if stage == self.cur_stage:
            obj_name = '{}_opt'.format(stage)
            if obj.get(obj_name) is not None and self.opts.get(stage) is not None:
                try:
                    self.opts[stage].load_state_dict(obj[obj_name])
                except Exception:
                    print("Failed loading opt for {}".format(stage))

            obj_name = '{}_schs'.format(stage)
            if obj.get(obj_name) is not None and self.schs.get(stage) is not None:
                try:
                    self.schs[stage].load_state_dict(obj[obj_name])
                except Exception:
                    print("Failed loading schs for {}".format(stage))

            obj_name = '{}_scaler'.format(stage)
            if obj.get(obj_name) is not None and self.scaler.get(stage) is not None:
                try:
                    self.scaler[stage].load_state_dict(obj[obj_name])
                except Exception:
                    print("Failed loading scaler for {}".format(stage))

    def apply_snapshot(self, obj, stage=None):
        stages = self.all_stage_keys if stage is None else [stage]
        for stage in stages:
            self._apply_snapshot(obj, stage)

    def get_rng(self):
        rng_states = {
            'python': random.getstate(),
            'numpy': np.random.get_state(),
            'cpu': torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            rng_states['cuda'] = torch.cuda.random.get_rng_state('cuda').cpu()
        return rng_states

    def set_rng(self, rng):
        if rng is None:
            return
        random.setstate(rng['python'])
        np.random.set_state(rng['numpy'])
        try:
            # torch.random.set_rng_state(rng['cpu'])
            if torch.cuda.is_available():
                torch.cuda.random.set_rng_state(rng['cuda'].cpu())
        except Exception:
            print("Failed to set random seed for torch")
            print(traceback.format_exc())

    def load(self, ckpt_dirs, device_id, mode):
        for model_key, ckpt_dir in ckpt_dirs.items():
            client_sd = None
            ckpt_files = glob.glob(os.path.join(ckpt_dir, '*'))
            if len(ckpt_files) > 0:
                torch_ckpts = []
                # For frozen stages (and in eval mode) prefer the best
                # checkpoint; fall back to the most recent 'last' checkpoint.
                if model_key not in self.stage_keys[self.cur_stage] or mode == 'eval':
                    torch_ckpts = sorted(
                        [f for f in ckpt_files
                         if f.endswith('snapshot.pth') and os.path.basename(f).startswith('best')])
                if len(torch_ckpts) == 0:
                    torch_ckpts = sorted(
                        [f for f in ckpt_files
                         if f.endswith('snapshot.pth') and os.path.basename(f).startswith('last')])
                if len(torch_ckpts) > 0:
                    f = torch_ckpts[-1]
                    if self.rank == 0:
                        print("[stage={}] Loading torch ckpt: {}".format(model_key, f))
                    client_sd = torch.load(f, map_location=f"cuda:{device_id}")
                    self.apply_snapshot(client_sd, stage=model_key)

            # Only adopt bookkeeping (epoch, step, scores) from the checkpoint
            # of the current stage or from a newer checkpoint.
            if client_sd is not None and (client_sd['epoch'] > self.epoch or model_key == self.cur_stage):
                self.set_rng(client_sd.get('rng'))
                self.global_step = client_sd['global_step']
                self.epoch = client_sd['epoch']
                self.best_score = client_sd['best_score']
                self.metrics = client_sd['metrics']
            elif client_sd is None and self.rank == 0:
                print("[stage={}] client_sd is None! state has not been updated".format(model_key))

    def save(self, ckpt_dirs, tag=None, save_latest=True):
        # Each stage participating in the current training stage gets its own
        # snapshot file, named by tag (e.g. 'best', 'last') or by global step.
        for model_key, ckpt_dir in ckpt_dirs.items():
            if model_key in self.stage_keys[self.cur_stage]:
                if tag is not None:
                    ckpt_fn = '{}_snapshot.pth'.format(tag)
                else:
                    ckpt_fn = '{}_snapshot.pth'.format(self.global_step)
                ckpt_path = os.path.join(ckpt_dir, ckpt_fn)
                torch.save(self.capture_snapshot(model_key), ckpt_path)
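
    # DistributedDataParallel registers the wrapped network under the
    # attribute 'module', so its state-dict keys carry a 'module.' prefix
    # (e.g. 'encoder.weight' -> 'module.encoder.weight'). The two helpers
    # below translate state dicts between the wrapped and unwrapped naming.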

    def from_ddp_ckpt(self, ddp_snapshot):
        # Strip the leading 'module.' from every key so a checkpoint saved
        # from a DDP-wrapped model loads into a plain model. Only the prefix
        # is removed, so keys like 'submodule.x' are left intact.
        snapshot = {}
        for k, v in ddp_snapshot.items():
            new_k = k[len('module.'):] if k.startswith('module.') else k
            snapshot[new_k] = v
        return snapshot

    def to_ddp_ckpt(self, snapshot):
        # Prepend 'module.' to every key so a plain checkpoint loads into a
        # DDP-wrapped model.
        ddp_snapshot = {}
        for k, v in snapshot.items():
            ddp_snapshot['module.' + k] = v
        return ddp_snapshot
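

# ---------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module):
# exercises snapshot capture/restore with toy models. The 'vae'/'seq'
# stage names mirror the defaults above; the Linear models and the
# optimizer choice are placeholders.
# ---------------------------------------------------------------
if __name__ == '__main__':
    models = {'vae': torch.nn.Linear(4, 4), 'seq': torch.nn.Linear(4, 4)}
    opts = {'vae': torch.optim.Adam(models['vae'].parameters(), lr=1e-4)}
    stage_keys = {'vae': ['vae'], 'seq': ['vae', 'seq']}

    state = State(models, stage_keys, opts=opts, rank=0, stage='vae')

    snap = state.capture_snapshot('vae')     # flat dict: 'epoch', 'vae_model', 'vae_opt', ...
    state.apply_snapshot(snap, stage='vae')  # round-trips model + optimizer state
    print(sorted(snap.keys()))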