# mvq/models/manage.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Forked from:
# https://github.com/pytorch/elastic/blob/master/examples/imagenet/main.py
# ---------------------------------------------------------------

import os
import io
import numpy
import torch
import torch.distributed as dist
import shutil
from contextlib import contextmanager

@contextmanager
def tmp_process_group(backend):
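    """Yield a temporary process group on the given backend and destroy it on exit."""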
    cpu_pg = dist.new_group(backend=backend)
    try:
        yield cpu_pg
    finally:
        dist.destroy_process_group(cpu_pg)

def load_checkpoint(state, ckpt_dirs, device_id, rank, distributed, mode='train'):
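    """Restore `state` from the newest available checkpoint.

    Arguments as used below (`ckpt_dirs` is assumed to map checkpoint names to paths):
        state: training state exposing load(), capture_snapshot(), apply_snapshot() and .epoch
        ckpt_dirs: dict of checkpoint paths; loaded if any of them exists on this node
        device_id, rank: used for logging and for loading onto the right device
        distributed: when True, the most up-to-date checkpoint is broadcast to all ranks
        mode: passed through to state.load() (e.g. 'train')
    """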
    if any(os.path.exists(c) for c in ckpt_dirs.values()):
        state.load(ckpt_dirs, device_id, mode=mode)
    else:
        if rank == 0:
            print("N[{}/{}] Checkpoint does not exist.".format(device_id, rank))

    if not distributed:
        return state

    # logic below is unnecessary when the checkpoint is visible on all nodes!
    # create a temporary cpu pg to broadcast most up-to-date checkpoint
    with tmp_process_group(backend="gloo") as pg:
        rank = dist.get_rank(group=pg)

        # get rank that has the largest state.epoch
        epochs = torch.zeros(dist.get_world_size(), dtype=torch.int32)
        epochs[rank] = state.epoch
        dist.all_reduce(epochs, op=dist.ReduceOp.SUM, group=pg)
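        # each rank filled only its own slot (zeros elsewhere), so the SUM
        # all-reduce gathers every rank's epoch into a single tensor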
        t_max_epoch, t_max_rank = torch.max(epochs, dim=0)
        max_epoch = t_max_epoch.item()
        max_rank = t_max_rank.item()

        # max_epoch == -1 means no rank has a checkpoint; return the base state
        if max_epoch == -1:
            print("N[{}/{}] No Ckpt Found".format(device_id, rank))

            return state

        # broadcast the state from max_rank (which has the most up-to-date state)
        # pickle the snapshot, convert it into a byte-blob tensor
        # then broadcast it, unpickle it and apply the snapshot
        print("N[{}/{}] Restore Rank: {}, Epoch: {}".format(
            device_id, rank, max_rank, max_epoch))

        with io.BytesIO() as f:
            torch.save(state.capture_snapshot(), f)
            raw_blob = numpy.frombuffer(f.getvalue(), dtype=numpy.uint8).copy()

        blob_len = torch.tensor(len(raw_blob))
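        # broadcast the length first so receiving ranks know how big a buffer to allocate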
        dist.broadcast(blob_len, src=max_rank, group=pg)
        #print("N[{}/{}] Broadcast Size {}".format(device_id, rank, blob_len))

        if rank != max_rank:
            blob = torch.zeros(blob_len.item(), dtype=torch.uint8)
        else:
            blob = torch.as_tensor(raw_blob, dtype=torch.uint8)

        dist.broadcast(blob, src=max_rank, group=pg)
        #print("N[{}/{}] Broadcast Complete".format(device_id, rank))

        if rank != max_rank:
            with io.BytesIO(blob.numpy()) as f:
                snapshot = torch.load(f)
            state.apply_snapshot(snapshot)

        # wait till everyone has loaded the checkpoint
        dist.barrier(group=pg)

    #print("N[{}/{}] Ckpt Restore Complete".format(device_id, rank))
    return state


def save_checkpoint(state, is_best=False, checkpoint_file=None, save_dir=None, epoch=None, stage=None):
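    """Atomically write a snapshot of `state`, optionally copying it as the best model.

    Either `checkpoint_file` (an explicit path) or `save_dir` together with `epoch`
    (saved as snapshot_<epoch>.pth) must be provided.
    """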

    if checkpoint_file is None and save_dir is None:
        raise ValueError("Need to provide either checkpoint file or save dir")
    
    if checkpoint_file is not None:
        filename = checkpoint_file
    else:
        if epoch is None:
            raise ValueError("Need to provide an epoch along with save_dir")
        filename = os.path.join(save_dir, 'snapshot_{:03d}.pth'.format(epoch))

    # save to tmp, then commit by moving the file in case the job
    # gets interrupted while writing the checkpoint
    tmp_filename = filename + ".tmp"
    torch.save(state.capture_snapshot(), tmp_filename)
    os.rename(tmp_filename, filename)
    #print(f"=> saved checkpoint for epoch {state.epoch} at {filename}")
    if is_best:
        if save_dir is None:
            save_dir = os.path.dirname(checkpoint_file)
        
        best_name = "model_best.pth.tar" if stage is None else "{}_model_best.pth.tar".format(stage)
        best = os.path.join(save_dir, best_name)
        #print(f"=> best model found at epoch {state.epoch} saving to {best}")
        shutil.copyfile(filename, best)
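

# ---------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training
# pipeline). `_DummyState` only mimics the interface these helpers
# expect from `state`; the real State class lives elsewhere in this
# repo and additionally implements load() for load_checkpoint().
# ---------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    class _DummyState:
        def __init__(self):
            self.epoch = 0
            self.weights = torch.zeros(4)

        def capture_snapshot(self):
            # plain dict of CPU tensors so it can be pickled / broadcast
            return {"epoch": self.epoch, "weights": self.weights}

        def apply_snapshot(self, snapshot):
            self.epoch = snapshot["epoch"]
            self.weights = snapshot["weights"]

    with tempfile.TemporaryDirectory() as tmp_dir:
        state = _DummyState()
        save_checkpoint(state, is_best=True, save_dir=tmp_dir, epoch=state.epoch)
        print(sorted(os.listdir(tmp_dir)))  # -> ['model_best.pth.tar', 'snapshot_000.pth']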