import torch
import numpy as np


class early_stopping(object):
    """Track a validation metric and signal when training should stop.

    Call :meth:`step` once per epoch with the current metric value. It
    returns ``True`` once the metric has failed to improve for ``patience``
    consecutive calls, or as soon as the metric is NaN.

    Parameters:
        mode: ``'min'`` if a lower metric is better (e.g. a loss), ``'max'``
            if a higher metric is better (e.g. an accuracy).
        min_delta: margin by which the metric has to beat the best value
            seen so far in order to count as an improvement.
        patience: number of consecutive non-improving calls to tolerate
            before stopping. ``0`` disables early stopping entirely.
        percentage: if ``False``, ``min_delta`` is an absolute margin; if
            ``True``, it is a percentage of the magnitude of the best value.

    Raises:
        ValueError: if ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Early stopping disabled: every value counts as an improvement
            # and step() never asks to stop. The instance attribute `step`
            # deliberately shadows the method of the same name.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one epoch's metric; return True if training should stop."""
        # Check NaN first: a NaN must never be stored as `best`, because
        # every later comparison against it would evaluate to False.
        if np.isnan(metrics):
            return True

        if self.best is None:
            # First observation only establishes the baseline.
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Install ``self.is_better(a, best)`` for the given configuration."""
        if mode not in ('min', 'max'):
            raise ValueError('mode ' + str(mode) + ' is unknown!')

        if percentage:
            def margin(best):
                # Use the magnitude of `best` so the required improvement
                # stays positive when the best metric is negative.
                return abs(best) * min_delta / 100
        else:
            def margin(best):
                return min_delta

        if mode == 'min':
            self.is_better = lambda a, best: a < best - margin(best)
        else:
            self.is_better = lambda a, best: a > best + margin(best)

    def curr_is_better(self, metrics):
        """Return whether ``metrics`` improves on the best value so far.

        Does not modify any state. Before the first call to :meth:`step`
        there is no baseline, so any value is reported as an improvement.
        """
        if self.best is None:
            return True
        return self.is_better(metrics, self.best)