# DeepGPO / DeepGPO_multiple / multiple_code / MSELoss_for_byBY.py
from fastNLP import LossBase
import torch.nn.functional as F
import torch
class MSELoss_byBY(LossBase):
    r"""
    Combined MSE loss over two prediction/target pairs, ``by`` and ``BY``.

    An MSE term is computed for each pair and the two are combined as
    ``by_scale * byloss + BYloss``. When per-sample ``weights`` are available,
    each term is weighted before the reduction is applied (see
    :meth:`get_weighted_loss`).

    ``weights`` and ``graph_edges`` are always looked up under the field names
    ``"weights"`` and ``"graph_edges"``.

    :param pred_by: mapping for `pred_by` in the parameter map; None means `pred_by` -> `pred_by`
    :param pred_BY: mapping for `pred_BY` in the parameter map; None means `pred_BY` -> `pred_BY`
    :param target_by: mapping for `target_by` in the parameter map; None means `target_by` -> `target_by`
    :param target_BY: mapping for `target_BY` in the parameter map; None means `target_BY` -> `target_BY`
    :param str reduction: one of 'mean', 'sum', 'none'.
    :param by_scale: multiplier applied to the ``by`` term when the two terms
        are summed. Defaults to 50.

    """

    def __init__(self, pred_by=None, pred_BY=None, target_by=None, target_BY=None,
                 reduction='mean', by_scale=50):
        super(MSELoss_byBY, self).__init__()
        self._init_param_map(pred_by=pred_by, pred_BY=pred_BY,
                             target_by=target_by, target_BY=target_BY,
                             weights="weights", graph_edges="graph_edges")
        assert reduction in ('mean', 'sum', 'none')
        self.reduction = reduction
        self.by_scale = by_scale

    def get_weighted_loss(self, pred_by, pred_BY, target_by, target_BY, graph_edges=None, weights=None):
        r"""
        Compute the two MSE terms with per-sample loss weights applied.

        :param pred_by: predictions for the ``by`` term
        :param pred_BY: predictions for the ``BY`` term
        :param target_by: targets for the ``by`` term
        :param target_BY: targets for the ``BY`` term
        :param graph_edges: repeat counts passed to ``torch.repeat_interleave``;
            ``weights[i]`` is repeated ``graph_edges[i]`` times to build the
            ``BY`` weights. Presumably the number of ``BY`` rows belonging to
            sample ``i`` -- confirm against the data pipeline.
        :param weights: loss weights; multiplied directly with the per-row
            ``by`` losses
        :return: tuple ``(byloss, BYloss)``, reduced according to
            ``self.reduction`` (left unreduced for 'none')
        :raises TypeError: if ``weights`` or ``graph_edges`` is None
        """
        # Both are required here even though the signature keeps None
        # defaults; fail with a clear message rather than an opaque torch
        # error further down.
        if weights is None or graph_edges is None:
            raise TypeError(
                "get_weighted_loss requires both `weights` and `graph_edges`"
            )

        # Element-wise squared error averaged over the last dimension, so one
        # loss value per row remains to be weighted.
        byloss = F.mse_loss(input=pred_by.float(), target=target_by.float(), reduction='none').mean(dim=-1)
        # BY is handled the same way as by; the original author noted that
        # this gave better results.
        BYloss = F.mse_loss(input=pred_BY.float(), target=target_BY.float(), reduction='none').mean(dim=-1)

        # Expand the weights so they line up with the BY losses.
        BY_weights = torch.repeat_interleave(weights, graph_edges)

        byloss = weights * byloss
        BYloss = BY_weights * BYloss

        if self.reduction == "mean":
            byloss = torch.mean(byloss, dim=-1)
            BYloss = torch.mean(BYloss, dim=-1)
        elif self.reduction == "sum":
            byloss = torch.sum(byloss, dim=-1)
            BYloss = torch.sum(BYloss, dim=-1)
        # reduction == 'none': return the weighted losses unreduced.
        return byloss, BYloss

    def get_loss(self, pred_by, pred_BY, target_by, target_BY, graph_edges=None, weights=None):
        r"""
        Compute the combined loss ``by_scale * byloss + BYloss``.

        Uses the weighted path when ``weights`` is given, otherwise plain
        ``F.mse_loss`` with ``self.reduction`` for each term.

        :param pred_by: predictions for the ``by`` term
        :param pred_BY: predictions for the ``BY`` term
        :param target_by: targets for the ``by`` term
        :param target_BY: targets for the ``BY`` term
        :param graph_edges: repeat counts for expanding ``weights``; only used
            (and then required) when ``weights`` is given
        :param weights: optional loss weights
        :return: the combined loss. With ``reduction='none'`` the two terms
            are added unreduced, so their shapes must be broadcast-compatible.
        """
        if weights is not None:
            byloss, BYloss = self.get_weighted_loss(pred_by, pred_BY, target_by, target_BY, graph_edges, weights)
        else:
            byloss = F.mse_loss(input=pred_by.float(), target=target_by.float(), reduction=self.reduction)
            BYloss = F.mse_loss(input=pred_BY.float(), target=target_BY.float(), reduction=self.reduction)

        # The by term is up-weighted relative to the BY term.
        loss = self.by_scale * byloss + BYloss
        return loss