# mvq/state/base_state.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# 
# Forked from:
# https://github.com/pytorch/elastic/blob/master/examples/imagenet/main.py
# ---------------------------------------------------------------


import os
import glob
import random
import traceback

import torch
import numpy as np

class State:
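  """Holds all mutable training state (models, optimizers, schedulers,
  grad scalers, RNG states, bookkeeping) and handles checkpoint capture,
  save, load, and DDP 'module.' prefix conversion."""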

  def __init__(self, models, stage_keys, opts=None, schs=None, rank=0, mode='train', stage='vae'):
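    """
    Args:
      models: dict mapping model key -> nn.Module.
      stage_keys: dict mapping training stage -> list of model keys it uses.
      opts: optional dict mapping model key -> optimizer.
      schs: optional dict mapping model key -> LR scheduler.
      rank: distributed rank (rank 0 does the logging).
      mode: 'train' or 'eval'.
      stage: the stage currently being trained ('vae' or 'seq').
    """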

    self.epoch = -1
    self.global_step = 0
    self.rank = rank
    self.best_score = {'vae': float('inf'), 'seq': float('inf')}

    self.models = models
    self.opts = opts if opts is not None else {}
    self.schs = schs if schs is not None else {}
    self.scaler = {}
    self.tokenizers = {}
    self.metrics = []
    self.stage_keys = stage_keys
    self.all_stage_keys = [t for v in stage_keys.values() for t in v]
    self.mode = mode
    self.cur_stage = stage


  def _capture_snapshot(self, stage):
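    """Collect state_dicts for one stage under the keys '<stage>_model',
    '<stage>_opt', '<stage>_schs', and '<stage>_scaler'; components the
    stage does not have are skipped."""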

    snap = {}
    if self.models.get(stage) is not None:
      snap['{}_model'.format(stage)] = self.models[stage].state_dict()

    if self.opts.get(stage) is not None:
      snap['{}_opt'.format(stage)] = self.opts[stage].state_dict()

    if self.schs.get(stage) is not None:
      snap['{}_schs'.format(stage)] = self.schs[stage].state_dict()
    
    if self.scaler.get(stage) is not None:
      snap['{}_scaler'.format(stage)] = self.scaler[stage].state_dict()
    
    return snap

  def capture_snapshot(self, stage=None):
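    """Assemble a checkpoint dict: run bookkeeping (epoch, global_step,
    best_score, metrics, RNG states) merged with the per-stage state_dicts
    for `stage`, or for every stage when `stage` is None."""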
    if stage is None:
      stages = self.all_stage_keys
    else:
      stages = [stage]

    snapshot = {
      'epoch': self.epoch,
      'global_step': self.global_step,
      'best_score': self.best_score, 
      'metrics': self.metrics,
      'rng': self.get_rng()
      }

    for stage in stages:
      snap = self._capture_snapshot(stage)
      snapshot = {**snapshot, **snap}

    return snapshot

  def _apply_snapshot(self, obj, stage):
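    """Load one stage's entries from a checkpoint dict, converting the DDP
    'module.' key prefix if the checkpoint and the live model disagree on
    it. The model (and tokenizer, if present) is always restored; the
    optimizer, scheduler, and grad scaler are only restored for the stage
    currently being trained."""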

    obj_name = '{}_model'.format(stage)
    if obj.get(obj_name) is not None and self.models.get(stage) is not None:
      ddp_ckpt = next(iter(obj[obj_name])).startswith('module.')
      ddp_model = self.models[stage].__class__.__name__ == 'DistributedDataParallel'
      
      if ddp_ckpt and not ddp_model:
        obj[obj_name] = self.from_ddp_ckpt(obj[obj_name])
      elif not ddp_ckpt and ddp_model:
        obj[obj_name] = self.to_ddp_ckpt(obj[obj_name])

      self.models[stage].load_state_dict(obj[obj_name])

    obj_name = '{}_tokenizer'.format(stage)
    if obj.get(obj_name) is not None and self.tokenizers.get(stage) is not None:
      try:
        self.tokenizers[stage].load_state_dict(obj[obj_name])
      except Exception:
        print("Failed loading tokenizer for {}".format(stage))

    if stage == self.cur_stage:
      obj_name = '{}_opt'.format(stage)
      if obj.get(obj_name) is not None and self.opts.get(stage) is not None:
        try:
          self.opts[stage].load_state_dict(obj[obj_name])
        except Exception:
          print("Failed loading opt for {}".format(stage))

      obj_name = '{}_schs'.format(stage)
      if obj.get(obj_name) is not None and self.schs.get(stage) is not None:
        try:
          self.schs[stage].load_state_dict(obj[obj_name])
        except Exception:
          print("Failed loading schs for {}".format(stage))

      obj_name = '{}_scaler'.format(stage)
      if obj.get(obj_name) is not None and self.scaler.get(stage) is not None:
        try:
          self.scaler[stage].load_state_dict(obj[obj_name])
        except Exception:
          print("Failed loading scaler for {}".format(stage))


  def apply_snapshot(self, obj, stage=None):
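    """Restore `stage` from a checkpoint dict, or every stage when
    `stage` is None."""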
    if stage is None:
      stages = self.all_stage_keys
    else:
      stages = [stage]

    for stage in stages:
      self._apply_snapshot(obj, stage)

  def get_rng(self):
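    """Capture the Python, NumPy, and torch CPU RNG states (plus the
    current CUDA device's state when available) so a resumed run is
    reproducible."""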
    rng_states = {
      "python": random.getstate(),
      "numpy": np.random.get_state(),
      "cpu": torch.random.get_rng_state(),
    }

    if torch.cuda.is_available():
      rng_states["cuda"] = torch.cuda.random.get_rng_state('cuda').cpu()
    
    return rng_states
  
  def set_rng(self, rng):
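    """Restore RNG states captured by get_rng. Restoring the torch CPU
    state is currently disabled (see the commented-out line below)."""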
    if rng is not None:
      random.setstate(rng["python"])
      np.random.set_state(rng["numpy"])
      try:
        #torch.random.set_rng_state(rng["cpu"])
        if torch.cuda.is_available() and "cuda" in rng:
          torch.cuda.random.set_rng_state(rng["cuda"].cpu())
      except Exception:
        print("Failed to restore torch RNG state")
        print(traceback.format_exc())



  def load(self, ckpt_dirs, device_id, mode):
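    """Restore each model key from its checkpoint directory. Keys frozen in
    the current stage (and everything in eval mode) prefer a 'best_*'
    snapshot, falling back to 'last_*'; keys being trained resume from
    'last_*'. Epoch, step, best_score, metrics, and RNG states are adopted
    from checkpoints that are newer or belong to the stage being trained."""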
    
    for model_key, ckpt_dir in ckpt_dirs.items():
      client_sd = None
      ckpt_files = glob.glob(os.path.join(ckpt_dir, '*'))

      if len(ckpt_files) > 0:
        torch_ckpts = []
        if model_key not in self.stage_keys[self.cur_stage] or mode == 'eval':
          torch_ckpts = sorted([f for f in ckpt_files if f.endswith('snapshot.pth')
                                and os.path.basename(f).startswith('best')])

        if len(torch_ckpts) == 0:
          torch_ckpts = sorted([f for f in ckpt_files if f.endswith('snapshot.pth')
                                and os.path.basename(f).startswith('last')])

        if len(torch_ckpts) > 0:
          f = torch_ckpts[-1]

          if self.rank == 0: print("[stage={}] Loading torch ckpt: {}".format(model_key, f))
          client_sd = torch.load(f, map_location=f"cuda:{device_id}")
          self.apply_snapshot(client_sd, stage=model_key)


        if client_sd is not None and (client_sd['epoch'] > self.epoch or model_key == self.cur_stage):
          self.set_rng(client_sd.get('rng'))
          self.global_step = client_sd['global_step']
          self.epoch = client_sd['epoch']
          self.best_score = client_sd['best_score']
          self.metrics = client_sd['metrics']
  
  def save(self, ckpt_dirs, tag=None, save_latest=True):
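    """Write a snapshot for each model key used by the current stage, named
    '<tag>_snapshot.pth' when a tag is given and '<global_step>_snapshot.pth'
    otherwise. (The save_latest flag is currently unused.)"""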

    for model_key, ckpt_dir in ckpt_dirs.items():
      if model_key in self.stage_keys[self.cur_stage]: 
        
        if tag is not None:
          ckpt_fn = '{}_snapshot.pth'.format(tag)
        else:
          ckpt_fn = '{}_snapshot.pth'.format(self.global_step)

        ckpt_path = os.path.join(ckpt_dir, ckpt_fn)
        torch.save(self.capture_snapshot(model_key), ckpt_path)

  def from_ddp_ckpt(self, ddp_snapshot):
    # Convert a DDP checkpoint for a plain (non-DDP) model by stripping
    # the leading 'module.' prefix from each key.
    snapshot = {}
    for k, v in ddp_snapshot.items():
      new_k = k[len('module.'):] if k.startswith('module.') else k
      snapshot[new_k] = v

    return snapshot
  
  def to_ddp_ckpt(self, snapshot):
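    # Inverse of from_ddp_ckpt: add the 'module.' prefix that a
    # DistributedDataParallel-wrapped model expects.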
    ddp_snapshot = {}
    for k, v in snapshot.items():
      ddp_snapshot['module.' + k] = v
    return ddp_snapshot
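

# ---------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training
# pipeline): round-trips a save/load with a toy model. The linear
# model, Adam optimizer, and temporary checkpoint directory are
# assumptions made for this example only.
# ---------------------------------------------------------------
if __name__ == '__main__':
  import tempfile

  models = {'vae': torch.nn.Linear(4, 4)}
  opts = {'vae': torch.optim.Adam(models['vae'].parameters())}
  stage_keys = {'vae': ['vae'], 'seq': ['seq']}

  state = State(models, stage_keys, opts=opts, stage='vae')

  with tempfile.TemporaryDirectory() as tmp:
    state.save({'vae': tmp}, tag='last')  # writes <tmp>/last_snapshot.pth

    # load() maps tensors onto cuda:<device_id>, so it needs a GPU. On
    # PyTorch >= 2.6 the torch.load call inside load() may also need
    # weights_only=False, since snapshots carry pickled RNG states.
    if torch.cuda.is_available():
      state.load({'vae': tmp}, device_id=0, mode='train')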