# --------------------------------------------------------------- # Copyright (c) __________________________ 2023. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # --------------------------------------------------------------- import os import json import sys import torch import contextlib from packaging import version from numpy import ceil from collections.abc import Mapping import torch import torch.distributed as dist from transformers import get_scheduler from utils import ( Logger, Writer, ResultTracker, ) from transformers.trainer_pt_utils import ( distributed_concat ) if version.parse(torch.__version__) >= version.parse("1.6"): _is_torch_generator_available = True _is_native_amp_available = True from torch.cuda.amp import autocast class BaseRunner(object): def __init__(self, cfg): self.epoch = 0 self.global_step = 1 self.work_env = cfg.work_env self.cfg = cfg self.debug = cfg.debug self.device_id = self.local_rank() self.device = f'cuda:{self.local_rank()}' if cfg.device_id != 'cpu' else 'cpu' self.use_fid = cfg.eval.fid self.fid_stats = None self.use_torch_dist = cfg.torch_dist.use self.display = cfg.train.intervals.display self.bsz = self.cfg.batch_size self.logger = Logger(self.rank(), cfg.save_dir) self.writer = Writer(self.rank(), cfg.save_dir) self.logger.info("Runner Initialized - Rank=[{}/{}]".format(self.local_rank(), self.rank())) self.metrics = [] self.metric_names = ['scale', 'continuous', 'categorical','series_name', 'cb1', 'cb2', 'wta','ct','row','col'] self.print_names = ['scale', 'continuous', 'categorical','series_name', 'cb1', 'cb2', 'wta','ct','row','col'] self.tracker = ResultTracker(['epoch', 'iter'], print_names=self.print_names) self.lr_scheduler = None self.scaler = None self.gradient_accum_steps = cfg.train.gradient_accum_steps self.max_grad_norm = cfg.train.max_grad_norm self.use_amp = False self.do_grad_scaling = False if self.cfg.fp16.use: self.use_amp = True self.amp_dtype = torch.float16 if self.cfg.fp16.use else torch.bfloat16 self.do_grad_scaling = True self.scaler = torch.cuda.amp.GradScaler() def local_rank(self): r = os.environ.get("LOCAL_RANK") r = 0 if r is None else int(r) return r def rank(self): r = os.environ.get("RANK") r = 0 if r is None else int(r) return r def update_writer(self, split, interval='epoch'): for n in sorted(list(set(self.tracker.metric_names))): if self.tracker.get_loss(interval, n): self.writer.add_scalar(n, self.tracker.get_loss(interval, n), self.epoch) def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): """ Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument. Args: num_training_steps (int): The number of training steps to do. """ if self.lr_scheduler is None: self.lr_scheduler = get_scheduler( 'linear', #self.cfg.train.scheduler.type, optimizer=optimizer, num_warmup_steps=int(ceil(num_training_steps * self.cfg.train.scheduler.warmup_ratio)), num_training_steps=num_training_steps, ) return self.lr_scheduler def autocast_smart_context_manager(self): """ A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired arguments, depending on the situation. """ if self.use_amp: if version.parse(torch.__version__) >= version.parse("1.10"): ctx_manager = autocast(dtype=self.amp_dtype) else: ctx_manager = autocast() else: ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() return ctx_manager def _prepare_input(self, data): """ Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. """ if isinstance(data, Mapping): return type(data)({k: self._prepare_input(v) for k, v in data.items()}) elif isinstance(data, (tuple, list)): return type(data)(self._prepare_input(v) for v in data) elif isinstance(data, torch.Tensor): kwargs = dict(device=self.device) return data.to(**kwargs) return data def _prepare_inputs(self, inputs): """ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and handling potential state. """ inputs = self._prepare_input(inputs) if len(inputs) == 0: raise ValueError( "The batch received was empty, your model won't be able to train on it. Double-check that your " f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." ) return inputs def _nested_gather(self, tensors, name=None): """ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before concatenating them to `gathered` """ if tensors is None: return elif self.use_torch_dist: tensors = distributed_concat(tensors) return tensors # Copied from Accelerate. def _pad_across_processes(self, tensor, pad_index=-100): """ Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they can safely be gathered. """ if isinstance(tensor, (list, tuple)): return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor) elif isinstance(tensor, dict): return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()}) elif not isinstance(tensor, torch.Tensor): raise TypeError( f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." ) if len(tensor.shape) < 2: return tensor # Gather all sizes size = torch.tensor(tensor.shape, device=tensor.device)[None] sizes = self._nested_gather(size).cpu() max_size = max(s[1] for s in sizes) if tensor.shape[1] == max_size: return tensor # Then pad to the maximum size old_size = tensor.shape new_size = list(old_size) new_size[1] = max_size new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index new_tensor[:, : old_size[1]] = tensor return new_tensor def to_json(self, samples, prefix, step=0, epoch=0): # Save list of dicts assert isinstance(samples, list), "must be list" if len(samples): assert isinstance(samples[0], dict), "must be list of dicts" epoch_dir = os.path.join(self.cfg.sample_dirs[self.stage], "{}".format(epoch)) if not os.path.exists(epoch_dir): os.makedirs(epoch_dir, exist_ok=True) for idx, sample in enumerate(samples): filename = os.path.join(epoch_dir, "{}-{}-{}.json".format(prefix, step, idx)) with open(filename, "w") as f: json.dump(sample, f) def to_vega_json(self, samples, prefix, step=0, epoch=0): # Save list of dicts assert isinstance(samples, list), "must be list" if len(samples): assert isinstance(samples[0], dict), "must be list of dicts" epoch_dir = os.path.join(self.cfg.sample_dirs[self.stage], "{}".format(epoch)) if not os.path.exists(epoch_dir): os.makedirs(epoch_dir, exist_ok=True) for idx, sample in enumerate(samples): chart_type = sample['chart_type'] assert chart_type in ['point','categorical','boxplot'], chart_type if chart_type == 'categorical': json_file = self.build_categorical_json(sample) elif chart_type == 'point': json_file = self.build_point_json(sample) else: continue filename = os.path.join(epoch_dir, f"{step}-{idx}-{chart_type}.json") with open(filename, 'w') as f: json.dump(json_file, f)