# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import os
import pickle
import numpy as np
import random
import torch
from easydict import EasyDict as edict
import yaml


def set_seeds(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def create_eval_dir(base_dir, tag, epoch=0):
    """Ensure ``<base_dir>/<epoch:02d>_<tag>`` exists and return its path.

    Only the final path component is created; ``base_dir`` itself must
    already exist. If the path already exists (e.g. created concurrently by
    another worker), it is returned without error.

    Args:
        base_dir: Parent directory.
        tag: Suffix appended after the zero-padded epoch number.
        epoch: Epoch index, formatted as at least two digits.

    Returns:
        The joined directory path.
    """
    epoch_pred_dir = os.path.join(base_dir, '{:02d}_{}'.format(epoch, tag))
    try:
        os.mkdir(epoch_pred_dir)
    except FileExistsError:
        # Already present (possibly created by a concurrent process).
        pass
    return epoch_pred_dir


def pickle_save(item, fname):
    """Serialize ``item`` to the file ``fname`` with :mod:`pickle`."""
    with open(fname, "wb") as f:
        pickle.dump(item, f)


def mkdir_safe(path):
    """Best-effort creation of directory ``path`` (single level, no parents).

    An already-existing path is silently accepted. Any other OS-level
    failure (missing parent, permission denied, ...) is reported on stdout
    instead of being silently swallowed, but does not raise, so callers that
    relied on this being non-fatal keep working.
    """
    try:
        os.mkdir(path)
    except FileExistsError:
        # Exists already, or another process won the creation race.
        pass
    except OSError as err:
        print('mkdir_safe: could not create {}: {}'.format(path, err))


def load_yaml(yaml_path):
    """Parse the YAML file at ``yaml_path`` into an ``EasyDict``.

    The file handle is closed as soon as parsing finishes.
    """
    # NOTE(review): FullLoader is acceptable for trusted, local config files;
    # switch to yaml.safe_load if these paths can ever come from untrusted input.
    with open(yaml_path, 'r') as f:
        return edict(yaml.load(f, Loader=yaml.FullLoader))


def load_cfg(cfg_name, cfg_dir, default_name='default.yaml'):
    """Load ``cfg_dir/cfg_name`` layered on top of ``cfg_dir/default_name``.

    Args:
        cfg_name: File name of the experiment-specific config.
        cfg_dir: Directory holding both config files.
        default_name: File name of the defaults config.

    Returns:
        The defaults ``EasyDict`` with the experiment config merged into it.

    Raises:
        AssertionError: if either file does not exist (note: ``assert`` is
            stripped when Python runs with ``-O``).
        KeyError, ValueError: propagated from the merge.
    """
    default_path = os.path.join(cfg_dir, default_name)
    cfg_path = os.path.join(cfg_dir, cfg_name)
    assert os.path.exists(default_path) and os.path.exists(cfg_path), \
        'Missing config file(s): {} / {}'.format(default_path, cfg_path)
    return merge_a_into_b(cfg_path, default_path)


def merge_a_into_b(cfg_path, default_path):
    """Load two YAML files and merge the one at ``cfg_path`` into the defaults.

    Returns:
        The defaults ``EasyDict``, updated in place with values from
        ``cfg_path``.
    """
    default = load_yaml(default_path)
    cfg = load_yaml(cfg_path)
    return _merge_a_into_b(cfg, default)


def _merge_a_into_b(a, b):
    """Recursively merge config ``a`` into config ``b`` in place; return ``b``.

    Every key in ``a`` must already exist in ``b`` with the same type. The
    one exception is a ``b`` entry holding an ``np.ndarray``: the incoming
    value is converted to an array with that entry's dtype. Nested
    ``EasyDict`` values are merged key-by-key rather than replaced.

    Raises:
        KeyError: ``a`` contains a key that ``b`` does not.
        ValueError: a value's type differs from the type of the default.
    """
    assert type(a) is edict and type(b) is edict
    for k, v in a.items():
        # a must specify keys that are in b
        if k not in b:
            raise KeyError('{} is not a valid config key'.format(k))
        # the types must match, too
        old_type = type(b[k])
        if old_type is not type(v):
            if isinstance(b[k], np.ndarray):
                v = np.array(v, dtype=b[k].dtype)
            else:
                raise ValueError(('Type mismatch ({} vs. {}) '
                                  'for config key: {}').format(type(b[k]), type(v), k))
        # recursively merge dicts
        if type(v) is edict:
            try:
                _merge_a_into_b(a[k], b[k])
            except Exception:
                # Annotate which nested key failed, then propagate unchanged.
                print(('Error under config key: {}'.format(k)))
                raise
        else:
            b[k] = v
    return b