# honeyplotnet/utils/dist_helpers.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------


import os
import sys
import torch
import numpy as np
import torch.distributed as dist
import logging
import time
from tensorboardX import SummaryWriter
from datetime import timedelta


def launch_dist_backend(dist_cfg, timeout=1800, debug=False):
    """Initialise the default ``torch.distributed`` process group if requested.

    When ``dist_cfg.use`` is falsy nothing happens. When the distributed
    package is not available, or ``init_process_group`` raises ``ValueError``,
    ``dist_cfg.use`` is set to ``False`` and the run continues in
    non-distributed mode. Calling this again after a successful
    initialisation is a no-op.

    Args:
        dist_cfg: Config object with ``use``, ``backend`` and ``init_method``
            attributes; ``use`` may be flipped to ``False`` by this function.
        timeout: Timeout, in seconds, passed to ``init_process_group``.
        debug: Currently unused; kept for interface compatibility.
    """
    fallback_msg = ("Initialising Pytorch Distributed failed. It might not be "
                    "available - switching to non-distributed mode.")

    if not dist_cfg.use:
        return

    # Without the distributed package, init_process_group may not even exist,
    # so the ValueError fallback below would never be reached.
    if not dist.is_available():
        dist_cfg.use = False
        print(fallback_msg)
        return

    # Re-initialising the default group raises (a ValueError on recent torch),
    # which would wrongly trigger the fallback while the group is live.
    if dist.is_initialized():
        return

    try:
        dist.init_process_group(
            backend=dist_cfg.backend,
            init_method=dist_cfg.init_method,
            timeout=timedelta(seconds=timeout),
        )
    except ValueError:
        dist_cfg.use = False
        print(fallback_msg)


def average_gradients(model):
    """Average parameter gradients in place across all distributed ranks.

    Each populated gradient of a parameter with ``requires_grad`` is summed
    with ``all_reduce`` and divided by the world size. Parameters whose
    ``.grad`` is ``None`` are reported (except names containing
    ``'perceptual_loss'``) and skipped. A summary count is printed only when
    at least one such parameter was found.

    Args:
        model: Module whose gradients were filled by a backward pass.
            Requires an initialised default process group.
    """
    errors = 0
    size = float(dist.get_world_size())
    for name, param in model.named_parameters():
        if param.grad is None:
            if 'perceptual_loss' not in name:
                print("Unused Gradient >> Name: {} [{}]".format(name, param.requires_grad))
                errors += 1
            # Nothing to reduce. NOTE(review): every rank must issue the same
            # sequence of all_reduce calls, so this assumes a parameter is
            # unused on all ranks alike.
            continue

        if param.requires_grad:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size

    # Only report when something was actually missing.
    if errors > 0:
        print("Error Count: {}".format(errors))

def average_tensor(t):
    """Replace *t* in place with its mean over all ranks of the default group."""
    world_size = dist.get_world_size()
    # Sum across ranks first, then scale down by the number of participants.
    dist.all_reduce(t.data, op=dist.ReduceOp.SUM)
    t.data /= float(world_size)

class Logger(object):
    """Rank-aware logger: only rank 0 writes to stdout and ``<save>/log.txt``.

    Messages sent through :meth:`info` are prefixed with the wall-clock time
    elapsed since the logger was constructed.
    """

    def __init__(self, rank, save):
        """Configure root logging on rank 0; other ranks log nothing.

        Args:
            rank: Process rank; only rank 0 installs handlers and logs.
            save: Directory in which ``log.txt`` is created (rank 0 only).
        """
        # other libraries may set logging before arriving at this line.
        # by reloading logging, we can get rid of previous configs set by other libraries.
        from importlib import reload
        reload(logging)
        self.rank = rank
        if self.rank == 0:
            log_format = '%(asctime)s %(message)s'
            date_format = '%m/%d %I:%M:%S %p'
            logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                                format=log_format, datefmt=date_format)
            fh = logging.FileHandler(os.path.join(save, 'log.txt'))
            # Same date format as the console so both logs show identical timestamps.
            fh.setFormatter(logging.Formatter(log_format, datefmt=date_format))
            logging.getLogger().addHandler(fh)
            self.start_time = time.time()

    @staticmethod
    def _format_elapsed(seconds):
        """Return ``'(Elapsed: HH:MM:SS) '`` for *seconds*.

        Hours are not wrapped at 24, unlike ``strftime('%H', gmtime(...))``,
        so runs longer than a day still report the correct elapsed time.
        """
        hours, rem = divmod(max(0, int(seconds)), 3600)
        minutes, secs = divmod(rem, 60)
        return '(Elapsed: {:02d}:{:02d}:{:02d}) '.format(hours, minutes, secs)

    def info(self, string, *args):
        """Log *string* (with optional %-style *args*) at INFO level on rank 0.

        A ``str`` message gets the elapsed-time prefix prepended. Any other
        object is logged right after a separate elapsed-time line.
        """
        if self.rank != 0:
            return
        prefix = self._format_elapsed(time.time() - self.start_time)
        if isinstance(string, str):
            string = prefix + string
        else:
            logging.info(prefix)
        logging.info(string, *args)


class Writer(object):
    """TensorBoard wrapper that writes only from rank 0.

    On any other rank no ``SummaryWriter`` is created and every method is a
    silent no-op.
    """

    def __init__(self, rank, save):
        """Store *rank*; on rank 0 open a ``SummaryWriter`` logging to *save*."""
        self.rank = rank
        if rank == 0:
            self.writer = SummaryWriter(log_dir=save, flush_secs=20)

    def _forward(self, method, *args, **kwargs):
        """Invoke ``self.writer.<method>(*args, **kwargs)`` on rank 0 only."""
        if self.rank != 0:
            return
        getattr(self.writer, method)(*args, **kwargs)

    def add_scalar(self, *args, **kwargs):
        self._forward('add_scalar', *args, **kwargs)

    def add_figure(self, *args, **kwargs):
        self._forward('add_figure', *args, **kwargs)

    def add_image(self, *args, **kwargs):
        self._forward('add_image', *args, **kwargs)

    def add_histogram(self, *args, **kwargs):
        self._forward('add_histogram', *args, **kwargs)

    def add_histogram_if(self, write, *args, **kwargs):
        # Histogram logging is deliberately switched off (it was only used
        # for debugging); set `enabled` to True to turn it back on.
        enabled = False
        if write and enabled:
            self.add_histogram(*args, **kwargs)

    def close(self, *args, **kwargs):
        # Extra arguments are accepted for API symmetry but not forwarded.
        self._forward('close')