import collections
import time
from typing import Dict, List, NamedTuple, Optional

import numpy as np
import transformers.trainer_seq2seq
from datasets.arrow_dataset import Dataset
from datasets.metric import Metric
from transformers.trainer_utils import PredictionOutput, speed_metrics


class EvalPrediction(NamedTuple):
    """Decoded predictions, reference label ids, and per-example metadata."""

    predictions: List[str]
    label_ids: np.ndarray
    metas: List[dict]


class Seq2SeqTrainer(transformers.trainer_seq2seq.Seq2SeqTrainer):
    """Seq2seq trainer that post-processes generated predictions and computes
    task-specific metrics via hooks that subclasses must implement."""

    def __init__(
        self,
        metric: Metric,
        *args,
        eval_examples: Optional[Dataset] = None,
        ignore_pad_token_for_loss: bool = True,
        target_with_db_id: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.metric = metric
        self.eval_examples = eval_examples
        self.compute_metrics = self._compute_metrics
        self.ignore_pad_token_for_loss = ignore_pad_token_for_loss
        self.target_with_db_id = target_with_db_id

    def _compute_metrics(self, eval_prediction: EvalPrediction) -> dict:
        """Compute task-specific metrics from a post-processed EvalPrediction."""
        raise NotImplementedError()

    def _post_process_function(
        self, examples: Dataset, features: Dataset, predictions: np.ndarray, stage: str
    ) -> EvalPrediction:
        """Turn raw model outputs into an EvalPrediction (decoded strings plus metadata)."""
        raise NotImplementedError()

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        eval_examples: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        max_time: Optional[int] = None,
        num_beams: Optional[int] = None,
    ) -> Dict[str, float]:
        self._max_length = max_length
        self._max_time = max_time
        self._num_beams = num_beams

        # Memory metrics - must be set up as early as possible.
        self._memory_tracker.start()

        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_examples = self.eval_examples if eval_examples is None else eval_examples
        start_time = time.time()

        # Temporarily disable metric computation; we will do it after the evaluation loop.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            output: PredictionOutput = self.evaluation_loop(
                eval_dataloader,
                description="Evaluation",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only.
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            self.compute_metrics = compute_metrics

        # We might have removed columns from the dataset, so we put them back.
        if isinstance(eval_dataset, Dataset):
            eval_dataset.set_format(
                type=eval_dataset.format["type"],
                columns=list(eval_dataset.features.keys()),
            )

        if eval_examples is not None and eval_dataset is not None and self.compute_metrics is not None:
            eval_preds = self._post_process_function(
                eval_examples,
                eval_dataset,
                output.predictions,
                "eval_{}".format(self.state.epoch),
            )
            output.metrics.update(self.compute_metrics(eval_preds))

        n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
        output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))

        # Prefix all keys with metric_key_prefix + '_'.
        for key in list(output.metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(key)

        self.log(output.metrics)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

    def predict(
        self,
        test_dataset: Dataset,
        test_examples: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        max_time: Optional[int] = None,
        num_beams: Optional[int] = None,
    ) -> PredictionOutput:
        self._max_length = max_length
        self._max_time = max_time
        self._num_beams = num_beams

        # Memory metrics - must be set up as early as possible.
        self._memory_tracker.start()

        if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
            raise ValueError("test_dataset must implement __len__")
        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        # Temporarily disable metric computation; we will do it after the evaluation loop.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            output: PredictionOutput = self.evaluation_loop(
                test_dataloader,
                description="Prediction",
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            self.compute_metrics = compute_metrics

        if self.compute_metrics is not None:
            # We might have removed columns from the dataset, so we put them back.
            if isinstance(test_dataset, Dataset):
                test_dataset.set_format(
                    type=test_dataset.format["type"],
                    columns=list(test_dataset.features.keys()),
                )

            eval_preds = self._post_process_function(
                test_examples, test_dataset, output.predictions, metric_key_prefix
            )
            output.metrics.update(self.compute_metrics(eval_preds))

        output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))

        # Prefix all keys with metric_key_prefix + '_'.
        for key in list(output.metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(key)

        self.log(output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output
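

# --- Illustrative sketch only, not part of the trainer above ---
# One plausible way a task-specific subclass might fill in the two abstract hooks.
# The class name, the dataset columns ("labels", "db_id", "query"), the availability of
# self.tokenizer, and the metric's `compute(predictions=..., references=...)` interface
# are assumptions made for this example, not guarantees about any particular setup.
class ExampleText2SQLTrainer(Seq2SeqTrainer):
    def _post_process_function(
        self, examples: Dataset, features: Dataset, predictions: np.ndarray, stage: str
    ) -> EvalPrediction:
        # Replace the loss-masking value (-100) with the pad token id so the
        # tokenizer can decode the generated ids (assumes a tokenizer was passed).
        predictions = np.where(predictions != -100, predictions, self.tokenizer.pad_token_id)
        decoded_predictions = self.tokenizer.batch_decode(predictions, skip_special_tokens=True)
        # Keep the (possibly ragged) reference label ids and per-example metadata
        # around for metric computation; the column names here are assumed.
        label_ids = np.asarray(features["labels"], dtype=object)
        metas = [
            {"db_id": example.get("db_id"), "query": example.get("query")}
            for example in examples
        ]
        return EvalPrediction(predictions=decoded_predictions, label_ids=label_ids, metas=metas)

    def _compute_metrics(self, eval_prediction: EvalPrediction) -> dict:
        predictions, _, metas = eval_prediction
        if self.target_with_db_id:
            # If targets were formatted as "<db_id> | <query>", strip the prefix.
            predictions = [prediction.split("|", 1)[-1].strip() for prediction in predictions]
        references = [meta["query"] for meta in metas]
        return self.metric.compute(predictions=predictions, references=references)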