# honeyplotnet/state/base_state.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# 
# Forked from:
# https://github.com/pytorch/elastic/blob/master/examples/imagenet/main.py
# ---------------------------------------------------------------

import os
import glob
import random

import torch
import numpy as np


class State:
  """
  Container for objects that we want to checkpoint. Represents the
  current "state" of the worker. This object is mutable.

  Per-stage objects live in dicts keyed by stage name (e.g. one of
  ``snapshot_keys``): ``models``, ``tokenizers``, ``opts`` (optimizers),
  ``schs`` (schedulers) and ``scaler``. Snapshots store them under
  ``'<stage>_model'``, ``'<stage>_opt'``, ``'<stage>_schs'`` and
  ``'<stage>_scaler'``; ``'<stage>_tokenizer'`` is read on restore but is
  not written by ``capture_snapshot``.
  """

  # Prefix that DistributedDataParallel adds to every parameter name.
  _DDP_PREFIX = 'module.'

  # (snapshot-key suffix, attribute name) of the training-only objects: they
  # are captured for every stage but restored only for ``cur_stage``.
  _TRAIN_OBJS = (('opt', 'opts'), ('schs', 'schs'), ('scaler', 'scaler'))

  def __init__(self, models, tokenizers, opts, schs, rank=0, mode='train', stage='caption'):
    self.epoch = -1
    self.global_step = 0
    # NOTE(review): 'chart_text' starts at 0.0 and the others at +inf, which
    # suggests higher-is-better vs lower-is-better scores - confirm with trainer.
    self.best_score = {'chart_text': 0.0, 'continuous': float('inf'), 'seq': float('inf')}
    self.models = models
    self.tokenizers = tokenizers
    self.opts = opts
    self.schs = schs
    self.scaler = {}  # populated by the caller; presumably AMP grad scalers
    self.metrics = []
    self.snapshot_keys = ['chart_text', 'continuous', 'seq_base', 'seq_cond']
    self.mode = mode
    self.cur_stage = stage
    self.rank = rank

  def _capture_snapshot(self, stage):
    """Return the state dicts of every object registered for ``stage``.

    Objects that are absent or ``None`` are omitted from the result.
    """
    snap = {}
    model = self.models.get(stage)
    if model is not None:
      snap['{}_model'.format(stage)] = model.state_dict()

    for suffix, attr in self._TRAIN_OBJS:
      obj = getattr(self, attr).get(stage)
      if obj is not None:
        snap['{}_{}'.format(stage, suffix)] = obj.state_dict()
    return snap

  def capture_snapshot(self, stage=None):
    """Build a checkpoint dict of the global counters plus per-stage state.

    Args:
      stage: a single stage to capture; ``None`` captures all ``snapshot_keys``.

    Returns:
      dict with ``epoch``, ``global_step``, ``best_score``, ``metrics``,
      ``rng`` and the ``'<stage>_*'`` entries from ``_capture_snapshot``.
    """
    stages = self.snapshot_keys if stage is None else [stage]

    snapshot = {
      'epoch': self.epoch,
      'global_step': self.global_step,
      'best_score': self.best_score,
      'metrics': self.metrics,
      'rng': self.get_rng(),
    }
    for name in stages:
      snapshot.update(self._capture_snapshot(name))
    return snapshot

  def _apply_snapshot(self, obj, stage):
    """Load the entries of snapshot ``obj`` that belong to ``stage``.

    The model (and the tokenizer, if present) is always restored. Optimizer,
    scheduler and grad-scaler state is restored only when ``stage`` is
    ``cur_stage``, and only into objects actually registered for that stage.
    ``obj`` itself is not modified.
    """
    model = self.models.get(stage)
    model_sd = obj.get('{}_model'.format(stage))
    if model_sd is not None and model is not None:
      # Bridge checkpoints saved with/without DDP and the model's current wrapping.
      # Only a *leading* 'module.' marks a DDP checkpoint.
      ddp_ckpt = next(iter(model_sd), '').startswith(self._DDP_PREFIX)
      ddp_model = model.__class__.__name__ == 'DistributedDataParallel'
      if ddp_ckpt and not ddp_model:
        model_sd = self.from_ddp_ckpt(model_sd)
      elif not ddp_ckpt and ddp_model:
        model_sd = self.to_ddp_ckpt(model_sd)
      model.load_state_dict(model_sd)

    tokenizer = self.tokenizers.get(stage)
    tokenizer_sd = obj.get('{}_tokenizer'.format(stage))
    if tokenizer_sd is not None and tokenizer is not None:
      try:
        tokenizer.load_state_dict(tokenizer_sd)
      except Exception:
        print("Failed loading tokenizer for {} ".format(stage))

    if stage != self.cur_stage:
      return

    # Each object is checked against its *own* registry (previously the
    # scheduler and scaler were guarded by the optimizer registry).
    for suffix, attr in self._TRAIN_OBJS:
      target = getattr(self, attr).get(stage)
      state_dict = obj.get('{}_{}'.format(stage, suffix))
      if state_dict is None or target is None:
        continue
      try:
        target.load_state_dict(state_dict)
      except Exception:
        print("Failed loading {} for {} ".format(suffix, stage))

  def apply_snapshot(self, obj, stage=None):
    """Restore ``obj`` into one stage, or all ``snapshot_keys`` if ``stage`` is None."""
    stages = self.snapshot_keys if stage is None else [stage]
    for name in stages:
      self._apply_snapshot(obj, name)

  def get_rng(self):
    """Return the python, numpy, torch-CPU and (if available) CUDA RNG states."""
    rng_states = {
      "python": random.getstate(),
      "numpy": np.random.get_state(),
      "cpu": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
      rng_states["cuda"] = torch.cuda.random.get_rng_state('cuda').cpu()
    return rng_states

  def set_rng(self, rng):
    """Restore RNG states produced by ``get_rng``; ``None`` is a no-op.

    Torch RNG restore is best-effort: failures are reported, not raised. The
    CUDA state is skipped when the snapshot has none (e.g. saved on a CPU host).
    """
    if rng is None:
      return
    random.setstate(rng["python"])
    np.random.set_state(rng["numpy"])
    try:
      torch.random.set_rng_state(rng["cpu"].cpu())
      cuda_state = rng.get("cuda")
      if torch.cuda.is_available() and cuda_state is not None:
        torch.cuda.random.set_rng_state(cuda_state.cpu())
    except Exception as err:
      print("Failed to set random seed for torch: {}".format(err))

  @staticmethod
  def _ckpt_sort_key(path):
    """Sort key ordering ``<step>_snapshot.pth`` files numerically by step.

    Files whose prefix is not an integer (e.g. tagged snapshots) sort after
    the numeric ones, by name - where plain string sorting also placed tags
    that start with a letter.
    """
    name = os.path.basename(path)
    suffix = '_snapshot.pth'
    step = name[:-len(suffix)] if name.endswith(suffix) else ''
    if step.isdecimal():
      return (0, int(step), name)
    return (1, 0, name)

  @staticmethod
  def _torch_load(path, device):
    """``torch.load`` that also accepts the pickled RNG state in snapshots.

    NOTE: newer PyTorch defaults to ``weights_only=True``, which may reject
    the numpy/python objects stored by ``get_rng``; this opts out, so it
    unpickles arbitrary objects - only load checkpoints you trust.
    """
    try:
      return torch.load(path, map_location=device, weights_only=False)
    except TypeError:
      # Older PyTorch releases do not accept the ``weights_only`` keyword.
      return torch.load(path, map_location=device)

  def load(self, ckpt_dirs, device_id):
    """Restore per-stage checkpoints found in ``ckpt_dirs``.

    Args:
      ckpt_dirs: mapping ``stage -> directory``; stages without a registered
        model are skipped.
      device_id: ``'cpu'`` or a CUDA device index, used as ``map_location``.

    In each directory a ``*.bin`` file takes precedence and is loaded
    non-strictly straight into the model. Otherwise the ``*snapshot.pth``
    with the highest step is applied via ``apply_snapshot``. Counters, best
    scores, metrics and RNG state are then taken from that snapshot if its
    epoch is newer than the current one, or if the stage is ``cur_stage``.
    """
    device = 'cpu' if device_id == 'cpu' else f"cuda:{device_id}"

    for stage, ckpt_dir in ckpt_dirs.items():
      if self.models.get(stage) is None:
        continue

      ckpt_files = glob.glob(os.path.join(ckpt_dir, '*'))
      if self.rank == 0:
        print(f"Checkpoint Dir: Files=[{len(ckpt_files)}] Path={ckpt_dir}")
      if not ckpt_files:
        continue

      client_sd = None
      zero2f32_ckpts = sorted(f for f in ckpt_files if f.endswith('.bin'))
      # Numeric step order: string sorting ranks '999_...' after '1000_...'.
      torch_ckpts = sorted(
        (f for f in ckpt_files if f.endswith('snapshot.pth')), key=self._ckpt_sort_key)

      if zero2f32_ckpts:
        f = zero2f32_ckpts[0]
        print("[stage={}] Loading zero2f32 ckpt: {}".format(stage, f))
        snapshot = self._torch_load(f, device)
        self.models[stage].load_state_dict(snapshot, strict=False)
      elif torch_ckpts:
        f = torch_ckpts[-1]
        print("[stage={}] Loading torch ckpt: {}".format(stage, f))
        client_sd = self._torch_load(f, device)
        self.apply_snapshot(client_sd, stage=stage)

      if client_sd is None:
        print("client_sd is None! state has not been updated")
        continue

      epoch = client_sd.get('epoch')
      if epoch is None or not (epoch > self.epoch or stage == self.cur_stage):
        print("[stage={}] snapshot epoch {} not applied; state has not been updated".format(stage, epoch))
        continue

      try:
        self.set_rng(client_sd.get('rng'))
        self.global_step = client_sd['global_step']
        self.epoch = epoch
        self.best_score = client_sd['best_score']
        self.metrics = client_sd['metrics']
      except Exception as err:
        print("Failed to update client sd: {}".format(err))

  def save(self, ckpt_dirs, tag=None):
    """Write a snapshot of ``cur_stage`` into its directory in ``ckpt_dirs``.

    The file is named ``<tag>_snapshot.pth``, or ``<global_step>_snapshot.pth``
    when ``tag`` is None; the directory is created if missing.
    """
    for stage, ckpt_dir in ckpt_dirs.items():
      if stage != self.cur_stage:
        continue

      prefix = self.global_step if tag is None else tag
      ckpt_path = os.path.join(ckpt_dir, '{}_snapshot.pth'.format(prefix))
      os.makedirs(ckpt_dir, exist_ok=True)
      torch.save(self.capture_snapshot(stage), ckpt_path)

  def from_ddp_ckpt(self, ddp_snapshot):
    """Convert a DDP state dict for a non-DDP model by stripping the leading prefix.

    Only the leading ``'module.'`` is removed, so nested submodules that are
    themselves named ``module`` keep their names; other keys are unchanged.
    """
    n = len(self._DDP_PREFIX)
    return {
      (k[n:] if k.startswith(self._DDP_PREFIX) else k): v
      for k, v in ddp_snapshot.items()
    }

  def to_ddp_ckpt(self, snapshot):
    """Prefix every key with ``'module.'`` so a plain state dict fits a DDP model."""
    return {self._DDP_PREFIX + k: v for k, v in snapshot.items()}