# mvq/utils/gen_helper.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import os
import pickle
import numpy as np
import random
import torch
from easydict import EasyDict as edict
import yaml

def set_seeds(seed):
  """Seed every random number generator the project relies on.

  The same ``seed`` is passed, in order, to Python's ``random`` module,
  NumPy's global RNG, torch's CPU generator and torch's generators on all
  CUDA devices.

  Args:
    seed: integer seed handed unchanged to each generator.
  """
  seeders = (
      random.seed,
      np.random.seed,
      torch.manual_seed,
      torch.cuda.manual_seed_all,
  )
  for seeder in seeders:
    seeder(seed)

def create_eval_dir(base_dir, tag, epoch=0):
  """Ensure ``<base_dir>/<epoch zero-padded to 2 digits>_<tag>`` exists.

  Only the leaf directory is created (``os.mkdir``), so ``base_dir`` must
  already exist. If another process creates the directory between the
  existence check and ``os.mkdir``, the resulting ``FileExistsError`` is
  ignored.

  Args:
    base_dir: parent directory for the evaluation output.
    tag: label appended after the epoch prefix.
    epoch: integer formatted with ``{:02d}``; defaults to 0.

  Returns:
    The path of the (now existing) directory.
  """
  target = os.path.join(base_dir, '{:02d}_{}'.format(epoch, tag))
  if os.path.exists(target):
    return target
  try:
    os.mkdir(target)
  except FileExistsError:
    pass  # lost a creation race; the directory is there either way
  return target

def pickle_save(item, fname):
  """Serialize ``item`` with :mod:`pickle` into the file ``fname``.

  The file is opened in binary write mode, so an existing file is
  overwritten. The default pickle protocol is used.

  Args:
    item: any picklable object.
    fname: destination file path.
  """
  with open(fname, mode="wb") as sink:
    pickle.dump(item, sink)

def mkdir_safe(path):
  """Best-effort creation of the single directory ``path``.

  Does nothing if ``path`` already exists. Filesystem errors raised by
  ``os.mkdir`` are deliberately ignored, for example a missing parent
  directory, denied permission, or a concurrent creator winning the race.
  Callers therefore cannot assume the directory exists afterwards.

  Args:
    path: directory to create (parents are not created).
  """
  if os.path.exists(path):
    return
  try:
    os.mkdir(path)
  except OSError:
    # Best-effort by design. Narrowed from a bare ``except`` so that
    # KeyboardInterrupt/SystemExit and non-filesystem errors still propagate.
    pass

def load_yaml(yaml_path):
  """Parse the YAML file at ``yaml_path`` and wrap the result in an EasyDict.

  Parsing uses ``yaml.FullLoader``. PyYAML recommends ``SafeLoader`` for
  untrusted input, so only point this at config files you control.

  Args:
    yaml_path: path to the YAML file.

  Returns:
    ``edict`` built from the parsed document. This presumably expects a
    mapping at the top level; TODO confirm for non-dict documents.
  """
  # The context manager closes the handle deterministically; the previous
  # inline ``open(...)`` left it open until garbage collection.
  with open(yaml_path, 'r') as stream:
    return edict(yaml.load(stream, Loader=yaml.FullLoader))

def load_cfg(cfg_name, cfg_dir, default_name='default.yaml'):
  """Load ``cfg_dir/cfg_name`` layered over ``cfg_dir/default_name``.

  Args:
    cfg_name: file name of the experiment config inside ``cfg_dir``.
    cfg_dir: directory holding both YAML files.
    default_name: file name of the defaults config; ``'default.yaml'``.

  Returns:
    Whatever ``merge_a_into_b(cfg_path, default_path)`` returns, i.e. the
    defaults with the experiment config merged on top.

  Raises:
    FileNotFoundError: if the defaults file or the experiment file is
      missing. This replaces a bare ``assert``, which is stripped when
      Python runs with ``-O``.
  """
  default_path = os.path.join(cfg_dir, default_name)
  cfg_path = os.path.join(cfg_dir, cfg_name)
  # Check the defaults first, matching the original short-circuit order.
  for required in (default_path, cfg_path):
    if not os.path.exists(required):
      raise FileNotFoundError('Config file not found: {}'.format(required))
  return merge_a_into_b(cfg_path, default_path)

def merge_a_into_b(cfg_path, default_path):
  """Load two YAML files and overlay the first onto the second.

  Despite the name, both arguments are file paths. The defaults file is
  loaded first, then the override file, and the loaded configs are handed
  to ``_merge_a_into_b``.

  Args:
    cfg_path: YAML file whose keys override the defaults ("a").
    default_path: YAML file defining the full set of valid keys ("b").

  Returns:
    The loaded defaults config, updated in place with the overrides.
  """
  base = load_yaml(default_path)
  overrides = load_yaml(cfg_path)
  merged = _merge_a_into_b(overrides, base)
  return merged
  
def _merge_a_into_b(a, b):
  """Recursively overlay config ``a`` onto config ``b`` in place.

  Every key in ``a`` must already exist in ``b``. Its value must have exactly
  the same type as ``b``'s value; the check uses ``type(...) is``, so an int
  does not match a float. The one exception is when ``b`` holds a NumPy
  array: ``a``'s value is then converted with ``np.array`` using that
  array's dtype. Nested EasyDict sections are merged key by key, and every
  other value replaces the one in ``b``.

  Args:
    a: EasyDict of overrides.
    b: EasyDict of defaults; mutated in place.

  Returns:
    ``b`` after merging.

  Raises:
    KeyError: ``a`` has a key that ``b`` lacks.
    ValueError: type mismatch on a key whose default is not an ndarray.
    AssertionError: ``a`` or ``b`` is not exactly an EasyDict (checked with
      ``assert``, so skipped under ``python -O``).
  """
  assert type(a) is edict and type(b) is edict

  for key, new_val in a.items():
    # Overrides may only touch keys that the defaults declare.
    if key not in b:
      raise KeyError('{} is not a valid config key'.format(key))

    cur_val = b[key]
    if type(new_val) is not type(cur_val):
      if not isinstance(cur_val, np.ndarray):
        raise ValueError(('Type mismatch ({} vs. {}) '
                          'for config key: {}').format(type(cur_val),
                                                       type(new_val), key))
      new_val = np.array(new_val, dtype=cur_val.dtype)

    if type(new_val) is not edict:
      b[key] = new_val
      continue

    # Nested section: merge into the existing sub-config rather than
    # replacing it, so defaults absent from the override survive. On
    # failure, report which key the error came from, then re-raise.
    try:
      _merge_a_into_b(new_val, cur_val)
    except BaseException:  # same coverage as a bare ``except``
      print('Error under config key: {}'.format(key))
      raise

  return b