import torch
from abc import ABC, abstractmethod


class Loss(ABC):
    """Abstract base for losses evaluated on an anchor and a sample.

    Concrete losses override :meth:`compute`. Instances are meant to be
    called directly, ``loss_fn(anchor, sample, ...)``, which hands every
    argument over to :meth:`compute` unchanged and returns whatever it
    produces. ``Loss`` itself cannot be instantiated because ``compute``
    is abstract.
    """

    @abstractmethod
    def compute(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs) -> torch.FloatTensor:
        """Compute the loss value; must be implemented by subclasses.

        Args:
            anchor: First input. Its shape and dtype are defined by the
                subclass (not constrained here).
            sample: Second input compared against ``anchor``. Its shape
                and dtype are also defined by the subclass.
            pos_mask: Presumably selects positive pairs. It may be
                ``None`` when the caller omitted it via ``__call__``.
                The base class never inspects it.
            neg_mask: Presumably selects negative pairs. It may be
                ``None`` when the caller omitted it via ``__call__``.
                The base class never inspects it.
            *args: Extra positional arguments forwarded from
                ``__call__``.
            **kwargs: Extra keyword arguments forwarded from
                ``__call__``.

        Returns:
            The loss, annotated as ``torch.FloatTensor``.
        """

    def __call__(self, anchor, sample, pos_mask=None, neg_mask=None, *args, **kwargs) -> torch.FloatTensor:
        """Evaluate the loss by delegating straight to :meth:`compute`.

        Omitted masks default to ``None``. All other positional and
        keyword arguments are passed through in their original order.
        """
        return self.compute(anchor, sample, pos_mask, neg_mask, *args, **kwargs)