from fastNLP import LossBase
import torch.nn.functional as F
import torch


class MSELoss_byBY(LossBase):
    r"""
    Combined MSE loss over two prediction/target pairs, ``by`` and ``BY``.

    The returned loss is ``by_coef * loss_by + BY_coef * loss_BY``.

    When per-sample ``weights`` are passed to :meth:`get_loss`, the
    element-wise squared errors are averaged over the last dimension and
    weighted before reduction:

    - the ``by`` term is multiplied by ``weights``;
    - the ``BY`` term is multiplied by ``weights`` expanded with
      ``torch.repeat_interleave(weights, graph_edges)``. Each weight is
      repeated ``graph_edges[i]`` times, one copy per ``BY`` row.

    Without ``weights``, plain ``F.mse_loss`` with ``reduction`` is used for
    both terms.

    :param pred_by: field name mapped to ``pred_by``; None maps ``pred_by`` -> ``pred_by``.
    :param pred_BY: field name mapped to ``pred_BY``; None maps ``pred_BY`` -> ``pred_BY``.
    :param target_by: field name mapped to ``target_by``; None maps ``target_by`` -> ``target_by``.
    :param target_BY: field name mapped to ``target_BY``; None maps ``target_BY`` -> ``target_BY``.
    :param str reduction: one of ``'mean'``, ``'sum'`` and ``'none'``.
    :param by_coef: coefficient applied to the ``by`` loss term (default 50).
    :param BY_coef: coefficient applied to the ``BY`` loss term (default 1).
    :raises ValueError: if ``reduction`` is not one of the supported values.
    """

    def __init__(self, pred_by=None, pred_BY=None, target_by=None, target_BY=None,
                 reduction='mean', by_coef=50, BY_coef=1):
        super().__init__()
        # ``weights`` and ``graph_edges`` are always read from fields of the
        # same name; the four prediction/target names are configurable.
        self._init_param_map(pred_by=pred_by, pred_BY=pred_BY,
                             target_by=target_by, target_BY=target_BY,
                             weights="weights", graph_edges="graph_edges")
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
        self.reduction = reduction
        self.by_coef = by_coef
        self.BY_coef = BY_coef

    def get_weighted_loss(self, pred_by, pred_BY, target_by, target_BY,
                          graph_edges=None, weights=None):
        """Return the weighted ``(byloss, BYloss)`` pair reduced per ``self.reduction``.

        Squared errors are averaged over the last dimension of each
        prediction, multiplied by their weights, and then reduced with
        mean/sum over the last dimension (or left as-is for ``'none'``).

        :param graph_edges: repeat counts used to expand ``weights`` to one
            weight per ``BY`` row; presumably the number of ``BY`` rows per
            sample (TODO confirm against the data pipeline).
        :param weights: per-sample weights; presumably one entry per
            ``pred_by`` sample (TODO confirm).
        :raises ValueError: if ``weights`` or ``graph_edges`` is None.
        """
        if weights is None or graph_edges is None:
            raise ValueError(
                "get_weighted_loss requires both `weights` and `graph_edges`")

        byloss = F.mse_loss(input=pred_by.float(), target=target_by.float(),
                            reduction='none').mean(dim=-1)
        BYloss = F.mse_loss(input=pred_BY.float(), target=target_BY.float(),
                            reduction='none').mean(dim=-1)

        # Expand the per-sample weights so every BY row gets its sample's weight.
        BY_weights = torch.repeat_interleave(weights, graph_edges)
        byloss = weights * byloss
        BYloss = BY_weights * BYloss

        if self.reduction == "mean":
            byloss = torch.mean(byloss, dim=-1)
            BYloss = torch.mean(BYloss, dim=-1)
        elif self.reduction == "sum":
            byloss = torch.sum(byloss, dim=-1)
            BYloss = torch.sum(BYloss, dim=-1)
        # 'none': keep the weighted per-row losses unreduced.
        return byloss, BYloss

    def get_loss(self, pred_by, pred_BY, target_by, target_BY,
                 graph_edges=None, weights=None):
        """Return ``by_coef * loss_by + BY_coef * loss_BY``.

        Uses :meth:`get_weighted_loss` when ``weights`` is given, otherwise
        unweighted ``F.mse_loss`` with ``self.reduction``.
        """
        if weights is not None:
            byloss, BYloss = self.get_weighted_loss(
                pred_by, pred_BY, target_by, target_BY, graph_edges, weights)
        else:
            byloss = F.mse_loss(input=pred_by.float(), target=target_by.float(),
                                reduction=self.reduction)
            BYloss = F.mse_loss(input=pred_BY.float(), target=target_BY.float(),
                                reduction=self.reduction)
        # NOTE(review): with reduction='none' the two terms keep their own
        # (generally different) shapes, so this sum only works if they
        # broadcast — confirm whether 'none' is actually used.
        return self.by_coef * byloss + self.BY_coef * BYloss