# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------


class ResultTracker(object):
    """Accumulate lists of metric values, bucketed by named interval.

    Values are stored as ``self.metrics[interval][metric] -> list``. Every
    metric name ever passed to :meth:`add` is remembered in
    ``self.metric_names`` so that resets re-create an empty list for it.
    """

    # Name of the metric whose summed values act as the denominator in mean().
    COUNT_METRIC = 'count'

    def __init__(self, intervals, print_names=None):
        """Create a tracker.

        Args:
            intervals: Iterable of interval keys, e.g. ``['epoch', 'iter']``.
            print_names: Optional collection of short loss names that
                :meth:`loss_str` is allowed to print. ``None`` (the default)
                means every loss is printed.
        """
        self.intervals = intervals  # ['epoch', 'iter', '']
        self.print_names = print_names
        self.metric_names = []  # ['loss', 'count',]
        self.reset_all()

    def reset_all(self):
        """Drop all stored values and rebuild empty lists for every interval."""
        self.metrics = {}
        for inter in self.intervals:
            self.metrics[inter] = {name: [] for name in self.metric_names}

    def reset_interval(self, interval):
        """Empty the value list of every known metric within ``interval``."""
        for name in self.metric_names:
            self.metrics[interval][name] = []

    def add(self, interval, metric, val):
        """Record ``val`` under ``metric`` for one or several intervals.

        Args:
            interval: A single interval key or a list of interval keys.
            metric: Metric name; registered in ``metric_names`` on first use.
            val: A single value (appended) or a list of values (extended).

        Raises:
            KeyError: If an interval key was not given to the constructor.
        """
        targets = interval if isinstance(interval, list) else [interval]
        for inter in targets:
            values = self.metrics[inter].setdefault(metric, [])
            if metric not in self.metric_names:
                self.metric_names.append(metric)
            if isinstance(val, list):
                values.extend(val)
            else:
                values.append(val)

    def mean(self, interval, metric):
        """Return ``sum(metric) / sum('count')`` for ``interval``.

        Returns ``0.0`` when the metric is unknown in that interval, when no
        ``'count'`` values were recorded there, or when they sum to zero.
        """
        bucket = self.metrics[interval]
        if metric not in bucket:
            return 0.0
        # .get() keeps a missing 'count' metric from raising KeyError.
        total_count = sum(bucket.get(self.COUNT_METRIC, ()))
        if not total_count:
            return 0.0
        return sum(bucket[metric]) / total_count

    def get_sum(self, interval, metric):
        """Return the sum of the values stored for ``metric`` in ``interval``."""
        return sum(self.metrics[interval][metric])

    def get_len(self, interval, metric):
        """Return how many values are stored for ``metric`` (0 if unknown)."""
        if metric not in self.metrics[interval]:
            return 0
        return len(self.metrics[interval][metric])

    def get_loss(self, interval, loss_name):
        """Return the per-entry average of ``loss_name``, or 0 if it is empty."""
        n_values = self.get_len(interval, loss_name)
        if n_values > 0:
            return self.get_sum(interval, loss_name) / n_values
        return 0

    def loss_str(self, interval):
        """Format the averaged losses of ``interval`` as ``'L: A=0.123 ...'``.

        A metric is included when its name contains ``'_loss'`` and it has at
        least one value in ``interval``. Its short name is the part after the
        last ``'/'`` with ``'_loss'`` removed; when ``print_names`` is not
        ``None`` the short name must be listed there. The printed label is the
        part after the last ``'/'`` up to the first ``'_'``, upper-cased.
        """
        parts = []
        for name in sorted(set(self.metric_names)):
            if '_loss' not in name or not self.get_len(interval, name):
                continue
            leaf = name.split('/')[-1]
            short_name = leaf.replace('_loss', '')
            # print_names=None is the constructor default: treat it as
            # "no filter" rather than failing on `in None`.
            if self.print_names is not None and short_name not in self.print_names:
                continue
            label = leaf.split('_')[0].upper()
            parts.append('{}={:2.3f} '.format(label, self.get_loss(interval, name)))
        return 'L: ' + ''.join(parts)