# Bertmodel.py -- DeepGPO/DeepGPO_multiple/multiple_code
from transformers.models.bert.modeling_bert import (BertEmbeddings, BertModel,
BertForSequenceClassification,BertPreTrainedModel,BertEncoder,BertPooler)
from transformers.models.roberta.modeling_roberta import RobertaModel,RobertaEncoder
import torch.nn as nn
import torch
import torch.nn.functional as F
from fastNLP.core.metrics import MetricBase,seq_len_to_mask
import GNN_global_representation
import GNN_edge_regression
from utils import *
# Addition: represent the glycan structure as a DGL graph (node_index, masses.glyco_process) -- a GNN then provides the global glycan representation

class Acid_BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""
    def __init__(self, config):
        super().__init__()
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.charge_embedding=nn.Embedding(10,config.hidden_size,padding_idx=0)
        self.a_embedding=nn.Embedding(30,config.hidden_size,padding_idx=0)
        self.phos_embedding=nn.Embedding(10,config.hidden_size)  # three modification types plus padding (all sizes enlarged); other modifications are encoded here as well
        print(f"GNN_global_ablation {GNN_global_ablation}!!!")
        if GNN_global_ablation=="GIN":
            self.gly_embedding=GNN_global_representation.GIN(20, GNN_global_hidden_dim, config.hidden_size,init_eps=0)
            # 16 could also be changed here
        if GNN_global_ablation=="GCN":
            self.gly_embedding=GNN_global_representation.GCN(20, GNN_global_hidden_dim, config.hidden_size)
            # print("input GCN model")
        if GNN_global_ablation=="GAT":
            self.gly_embedding=GNN_global_representation.GAT(20, GNN_global_hidden_dim, config.hidden_size,num_heads=4)
        if GNN_global_ablation=="Nogly":
            pass
        if GNN_global_ablation not in ["GIN","GCN","GAT","Nogly"]:
            raise NameError(f"unknown GNN_global_ablation: {GNN_global_ablation}")
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

        self.config=config
    def forward(
        self, peptide_tokens=None, position_ids=None,decoration=None,charge=None,batched_graph=None, inputs_embeds=None, past_key_values_length=0
    ):  
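        # Expected inputs (inferred from the callers below):
        #   peptide_tokens: (N, T) token ids, where T is the tokenized length (2L for the cls/sep layout)
        #   decoration:     (N, T) modification ids; id 5 appears to mark a glycosylation site
        #   charge:         (N,) precursor charge
        #   batched_graph:  batched DGL glycan graphs, one per glycosylation site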
        N=peptide_tokens.size(0)
        L=peptide_tokens.size(1)
        sequence=peptide_tokens
        charge_embed=self.charge_embedding(charge.unsqueeze(1).expand(N,L))
        feat=batched_graph.ndata["attr"]
        if GNN_global_ablation=="GIN" or GNN_global_ablation =="GCN" or GNN_global_ablation =="GAT":
            gly_embedding=self.gly_embedding(batched_graph,feat)
            mask = decoration == 5
            num_fives = mask.sum().item()
            positions = mask.nonzero(as_tuple=False)  # [[i1, j1], [i2, j2], ..., [in, jn]]

            # create an all-zero output tensor
            output = torch.zeros(N, L, self.config.hidden_size).to(gly_embedding.device)

            # split positions into row and column indices
            row_idx, col_idx = positions[:, 0], positions[:, 1]

            # batched assignment: output[row_idx, col_idx] = gly_embedding[:num_fives]
            output[row_idx, col_idx] = gly_embedding[:num_fives]
            gly_embedding=output
        if GNN_global_ablation=="Nogly":
            pass
        assert sequence.size(0) == decoration.size(0)
        # ipdb.set_trace()
        # decoration receives decoration_ids, which already include the decoration_ACE information; see preprocess.py
        phos_embed = self.phos_embedding(decoration)
        
        if loc==False:
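            # 'loc' presumably comes from utils (star import); when False, the glycosite marker (id 5) is stripped before the modification embedding lookup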
            phos_embed = self.phos_embedding(decoration-5*(decoration==5).to(int))

        if peptide_tokens is not None:
            input_shape = peptide_tokens.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        inputs_embeds = self.a_embedding(peptide_tokens)
        if GNN_global_ablation != "Nogly":
            embeddings = inputs_embeds + phos_embed + charge_embed + gly_embedding
        else:
            embeddings = inputs_embeds + phos_embed + charge_embed

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

# ----------------------- byBY ------------------------------#
class ModelbyBYms2_bert(BertModel):#input:sequence:N*L(N:batch)##########bertmodel
    ##### The input is cls A sep B sep C ..., so its length is 2L; all SEP positions are taken out for prediction. Because of pretraining, num_col has to be hard-coded here.
    def __init__(self,config):
        super().__init__(config)
        self.config = config
        self.embeddings = Acid_BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.activation=nn.GELU()
        self.num_col=num_col
        self.mslinear=nn.Linear(config.hidden_size,num_col)
        print(f"GNN_edge_ablation {GNN_edge_ablation}!!!")
        if GNN_edge_ablation=="GIN":
            self.BY_pred=GNN_edge_regression.GIN(20, GNN_edge_hidden_dim, config.hidden_size,init_eps=0)
        if GNN_edge_ablation=="GCN":
            self.BY_pred=GNN_edge_regression.GCN(20, GNN_edge_hidden_dim, config.hidden_size)
        if GNN_edge_ablation=="GAT":
            self.BY_pred=GNN_edge_regression.GAT(20, GNN_edge_hidden_dim, config.hidden_size,num_heads=4)
        self.peptide_rep_linear=nn.Linear(config.hidden_size,GNN_edge_hidden_dim)
        self.pooler = BertPooler(config)
        self.init_weights()

    def forward(self,input_ids,peptide_tokens,peptide_length,
                charge,decoration,decoration_ids,_id,decoration_ACE,PlausibleStruct,
                peptide,graph_edges,ions_by_p,ions_BY_p,return_dict=None,head_mask=None):
        #input:input_ids:N*2L(N:batch)
        # length:N*1(N:batch)
        # phos:N*2L(N:batch)
        #charge:N*1
        
        # ninput=self.dropout(ninput)
        # import ipdb
        # ipdb.set_trace()
        key_padding_mask=seq_len_to_mask(peptide_length*2)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device 
        # past_key_values_length = 0
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(key_padding_mask, input_shape, device)
        batched_graph=PlausibleStruct
        # import sys
        # sys.path.append('..')
        from masses import glyco_process
        import dgl
        graph_list=[]
        # import ipdb
        # ipdb.set_trace()
        node2idx={"P":0,"H" : 2, "N" : 1, "A" : 4, "G" : 5 , "F" :3}
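        # the letters presumably denote monosaccharides (H=Hex, N=HexNAc, A=NeuAc, G=NeuGc, F=Fuc); "P" marks the peptide root node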
        for i in batched_graph:
            struct_list = i.split(";")
            for struct in struct_list:
                _, _, Struct_tokens, nodef = glyco_process(struct)
                nodef = "P" + nodef
                Struct_feature = [node2idx[c] for c in nodef]
                glycan_graph = dgl.graph(Struct_tokens)
                glycan_graph.ndata["attr"] = torch.Tensor(Struct_feature).to(int)
                u, v = glycan_graph.edges()
                glycan_graph.add_edges(v, u)  # make the graph bidirectional
                glycan_graph = glycan_graph.add_self_loop()
                graph_list.append(glycan_graph)
        batched_graph = dgl.batch(graph_list).to(device)
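        # one bidirected, self-looped subgraph per glycan; multiple glycans of a sample are ';'-separated in PlausibleStruct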
        embedding_output = self.embeddings(peptide_tokens=input_ids, batched_graph=batched_graph,
        position_ids=None,
        decoration=decoration_ids,charge=charge, inputs_embeds=None, past_key_values_length=0
        )
        # import ipdb
        # ipdb.set_trace()
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            use_cache=False,
            return_dict=return_dict,
        )
        # import ipdb
        # ipdb.set_trace()
        sequence_output = encoder_outputs[0]#####N*2L*E
        E=sequence_output.size(-1)
        # pooled_output = self.pooler(sequence_output) if self.pooler is not None else None  #### CLS output; could be used to predict iRT
        # outputmean=output.mean(dim=0)#N*E
        # outputrt=self.activation(self.rtlinear(outputmean))#N*1
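        # take the hidden states at the SEP positions (indices 2, 4, ..., 2L-2), one per backbone cleavage site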
        output=sequence_output[:,2:-1:2,:]
        assert output.size(1)==int(seq_length/2-1)
        #by
        outputms=self.dropout(self.mslinear(output))#N*(L-1)*num_col
        outputms=outputms[:,:,:self.num_col]
        outputms=self.activation(outputms)
        outputms=outputms.reshape(batch_size,-1)#N*((L-1)*num_col)
        masks = seq_len_to_mask(seq_len=(peptide_length - 1) * self.num_col)  ## apply the length mask
        outputms = outputms.masked_fill(masks.eq(False), 0)
        # print(outputms)
        # print(torch.sum(outputms))

        #BY
        # import ipdb
        # ipdb.set_trace()
        batched_graph=PlausibleStruct
        graph_list_BY=[]
        for i in batched_graph:
            struct_list = i.split(";")
            for struct in struct_list:
                _, _, Struct_tokens, nodef = glyco_process(struct)
                nodef = "P" + nodef
                Struct_feature = [node2idx[c] for c in nodef]
                glycan_graph = dgl.graph(Struct_tokens)
                glycan_graph.ndata["attr"] = torch.Tensor(Struct_feature).to(int)
                graph_list_BY.append(glycan_graph)
        # import ipdb
        # ipdb.set_trace()
        
        batched_graph = dgl.batch(graph_list_BY).to(device)
        feat = batched_graph.ndata["attr"]
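        # peptide_ind: node indices of the peptide root nodes ("P", attr == 0); peptide_rep: the CLS hidden state projected into the GNN hidden space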
        peptide_ind=torch.nonzero(feat==0).squeeze()
        peptide_rep=self.peptide_rep_linear(sequence_output[:,0,:])
        glycan_counts = [s.count(';') + 1 for s in PlausibleStruct]
        peptide_rep = torch.cat([peptide_rep[i].repeat(glycan_counts[i], 1) for i in range(len(glycan_counts))],dim=0)
        # import ipdb
        # ipdb.set_trace()
        # replicate peptide_rep according to the number of glycans in each PlausibleStruct entry, e.g. counts 2 1 2 repeat the rows 2, 1 and 2 times
        BY_pred=self.BY_pred(batched_graph,feat,peptide_rep,peptide_ind)  # the edge-regression GNN predicts the fragmentation probability of each glycan edge from the graph and the root-node indices
        # ipdb.set_trace()
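        # split the flat per-edge predictions into one chunk per glycan (graph_edges holds each glycan's edge count)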
        pred_BY = torch.split(BY_pred, graph_edges.tolist())
        # ipdb.set_trace()
        ##padding BY
        contents=[BY.reshape(-1,1).squeeze() for BY in pred_BY]
        max_len = max(map(len, contents))
        tensor = torch.full((len(contents), max_len), fill_value=0,dtype=torch.float)
        for i, content_i in enumerate(contents):
            tensor[i, :len(content_i)] = content_i
        pred_BY=tensor.to(device)
        concatbyBY=torch.cat([outputms,pred_BY],dim=-1)  # e.g. (128, 512): batch_size x (padded by-ion length + padded BY-ion length)
        # ipdb.set_trace()
        #target
        target_BY = torch.split(ions_BY_p, graph_edges.tolist())
        # ipdb.set_trace()
        ##padding BY
        contents=[BY.reshape(-1,1).squeeze() for BY in target_BY]
        max_len = max(map(len, contents))
        tensor = torch.full((len(contents), max_len), fill_value=0,dtype=torch.float)
        for i, content_i in enumerate(contents):
            tensor[i, :len(content_i)] = content_i
        target_BY=tensor.to(device)
        target_byBY=torch.cat([ions_by_p,target_BY],dim=-1)  # e.g. (128, 512): batch_size x (padded by-ion length + padded BY-ion length)
        # import ipdb
        # ipdb.set_trace()
        return {'pred_by':outputms,'pred_BY':BY_pred,"pred":concatbyBY,
                'target_by':ions_by_p,'target_BY':ions_BY_p,"target":target_byBY,
                'sequence':peptide_tokens,'charge':charge,
                "decoration":decoration,"seq_len":peptide_length,"_id":_id,
                "PlausibleStruct":PlausibleStruct,"peptide":peptide,"graph_edges":graph_edges}


# ----------------------- by ------------------------------#
class _2deepchargeModelms2_bert(BertModel):#input:sequence:N*L(N:batch)##########bertmodel
    ##### The input is cls A sep B sep C ..., so its length is 2L; all SEP positions are taken out for prediction. Because of pretraining, num_col has to be hard-coded here.
    def __init__(self,config):
        super().__init__(config)
        self.config = config
        self.embeddings = Acid_BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.activation=nn.GELU()
        self.num_col=num_col
        self.mslinear=nn.Linear(config.hidden_size,num_col)
        # model_ablation="DeepFLR" #DeepFLR
        # if model_ablation=="DeepFLR":
        #     self.mslinear=nn.Linear(config.hidden_size,36)
        self.pooler = BertPooler(config)
        self.init_weights()

    def forward(self,input_ids,peptide_tokens,peptide_length,charge,decoration,decoration_ids,_id,decoration_ACE,PlausibleStruct,return_dict=None,head_mask=None):
        #input:input_ids:N*2L(N:batch)
        # length:N*1(N:batch)
        # phos:N*2L(N:batch)
        #charge:N*1
        
        # ninput=self.dropout(ninput)
        # import ipdb
        # ipdb.set_trace()
        key_padding_mask=seq_len_to_mask(peptide_length*2)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device 
        # past_key_values_length = 0
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(key_padding_mask, input_shape, device)
        batched_graph=PlausibleStruct
        # import sys
        # sys.path.append('..')
        from masses import glyco_process
        import dgl
        graph_list=[]
        node2idx = {"P": 0, "H": 2, "N": 1, "A": 4, "G": 5, "F": 3}
        for i in batched_graph:
            _, _, Struct_tokens, nodef = glyco_process(i)
            nodef = "P" + nodef
            Struct_feature = [node2idx[c] for c in nodef]
            glycan_graph = dgl.graph(Struct_tokens)
            glycan_graph.ndata["attr"] = torch.Tensor(Struct_feature).to(int)
            graph_list.append(glycan_graph)
        batched_graph = dgl.batch(graph_list).to(device)
        embedding_output = self.embeddings(peptide_tokens=input_ids, batched_graph=batched_graph,
        position_ids=None,
        decoration=decoration_ids,charge=charge, inputs_embeds=None, past_key_values_length=0
        )
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            use_cache=False,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]#####N*2L*E
        E=sequence_output.size(-1)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None  #### CLS output; could be used to predict iRT
        # outputmean=output.mean(dim=0)#N*E
        # outputrt=self.activation(self.rtlinear(outputmean))#N*1
        # import ipdb
        # ipdb.set_trace()
        output=sequence_output[:,2:-1:2,:]
        assert output.size(1)==int(seq_length/2-1)
        
        outputms=self.dropout(self.mslinear(output))#N*(L-1)*num_col
        ##
        outputms=outputms[:,:,:self.num_col]
        outputms=self.activation(outputms)

        outputms=outputms.reshape(batch_size,-1)#N*((L-1)*num_col)

        masks = seq_len_to_mask(seq_len=(peptide_length - 1) * self.num_col)  ## apply the length mask
        outputms=outputms.masked_fill(masks.eq(False), 0)

        # print(outputms)
        # print(torch.sum(outputms))
        return {'pred':outputms,'sequence':peptide_tokens,'charge':charge,"decoration":decoration,"seq_len":peptide_length,"_id":_id}


# ----------------------- rt ------------------------------#
class _2deepchargeModelms2_bert_irt(BertModel):#input:sequence:N*L(N:batch)##########bertmodel
    ##### The input is cls A sep B sep C ..., so its length is 2L; all SEP positions are taken out for prediction. Because of pretraining, num_col has to be hard-coded here.
    def __init__(self,config):
        super().__init__(config)
        self.config = config
        self.embeddings = Acid_BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.activation=nn.GELU()
        self.num_col=num_col  # number of fragment-ion columns
        # self.mslinear=nn.Linear(config.hidden_size,36)  # probably no longer needed
        self.pooler = BertPooler(config)
        self.irtlinear=nn.Linear(config.hidden_size,1)
        self.init_weights()

    def forward(self,input_ids,peptide_tokens,peptide_length,charge,decoration,decoration_ids,_id,
    decoration_ACE,PlausibleStruct,return_dict=None,head_mask=None):
    # added _id, decoration_ACE and PlausibleStruct to the signature
        key_padding_mask=seq_len_to_mask(peptide_length*2)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device 
        past_key_values_length = 0
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(key_padding_mask, input_shape, device)
        batched_graph=PlausibleStruct
        import sys
        sys.path.append('..')
        from masses import glyco_process
        import dgl
        graph_list=[]
        node2idx = {"P": 0, "H": 2, "N": 1, "A": 4, "G": 5, "F": 3}
        for i in batched_graph:
            _, _, Struct_tokens, nodef = glyco_process(i)
            nodef = "P" + nodef
            Struct_feature = [node2idx[c] for c in nodef]
            glycan_graph = dgl.graph(Struct_tokens)
            glycan_graph.ndata["attr"] = torch.Tensor(Struct_feature).to(int)
            graph_list.append(glycan_graph)
        batched_graph = dgl.batch(graph_list).to(device)
        embedding_output = self.embeddings(peptide_tokens=input_ids, batched_graph=batched_graph,
        position_ids=None,decoration=decoration_ids,charge=charge, inputs_embeds=None, 
        past_key_values_length=0
        )    
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            use_cache=False,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]#####N*2L*E
        E=sequence_output.size(-1)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None  #### CLS output; used here to predict iRT
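        # iRT head: a single linear layer on the pooled CLS representation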
        predirt=self.irtlinear(pooled_output).squeeze()##
        output=sequence_output[:,2:-1:2,:]
        assert output.size(1)==int(seq_length/2-1)
        # is outputms really needed here? check the sizes; the loss is currently NaN
        # outputms=self.dropout(self.mslinear(output))#N*(L-1)*num_col
        # outputms=outputms[:,:,:self.num_col]
        # outputms=self.activation(outputms)
        # outputms=outputms.reshape(batch_size,-1)#N*((L-1)*num_col)
        # masks = seq_len_to_mask(seq_len=(peptide_length - 1) * self.num_col)##加上mask
        # outputms=outputms.masked_fill(masks.eq(False), 0)
        return {'sequence':peptide_tokens,'charge':charge,"decoration":decoration,
        "seq_len":peptide_length,"_id":_id,"predirt":predirt}



# ----------------------- models not in use ------------------------------#              
class _2deepchargeModelms2_roberta(RobertaModel):#input:sequence:N*L(N:batch)##########bertmodel
    ##### The input is cls A sep B sep C ..., so its length is 2L; all SEP positions are taken out for prediction. Because of pretraining, num_col has to be hard-coded here.
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    def __init__(self,config):
        super().__init__(config)
        self.config = config

        self.embeddings = Acid_BertEmbeddings(config)
        self.encoder = RobertaEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.activation=nn.GELU()
        self.num_col=36
        self.mslinear=nn.Linear(config.hidden_size,36)
        self.pooler = BertPooler(config)

        self.init_weights()

    def forward(self,input_ids,peptide_tokens,peptide_length,charge,decoration,decoration_ids,pnumber,return_dict=None,head_mask=None):
        #input:input_ids:N*2L(N:batch)
        # length:N*1(N:batch)
        # phos:N*2L(N:batch)
        #charge:N*1
        
        # ninput=self.dropout(ninput)
        key_padding_mask=seq_len_to_mask(peptide_length*2)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device 
        past_key_values_length = 0
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(key_padding_mask, input_shape, device)
        embedding_output = self.embeddings(peptide_tokens=input_ids, 
        position_ids=None,
        decoration=decoration_ids,charge=charge, inputs_embeds=None, past_key_values_length=0

        )
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,

            use_cache=False,

            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]#####N*2L*E
        E=sequence_output.size(-1)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None  #### CLS output; could be used to predict iRT
        # outputmean=output.mean(dim=0)#N*E
        # outputrt=self.activation(self.rtlinear(outputmean))#N*1
        
        output=sequence_output[:,2:-1:2,:]
        assert output.size(1)==int(seq_length/2-1)
        
        outputms=self.dropout(self.mslinear(output))#N*(L-1)*num_col
        outputms=self.activation(outputms)
        outputms=outputms.reshape(batch_size,-1)#N*((L-1)*num_col)

        masks = seq_len_to_mask(seq_len=(peptide_length - 1) * self.num_col)  ## apply the length mask
        outputms=outputms.masked_fill(masks.eq(False), 0)

        # print(outputms)
        # print(torch.sum(outputms))
        return {'pred':outputms,'sequence':peptide_tokens,'charge':charge,"decoration":decoration,"seq_len":peptide_length,
                'pnumber':pnumber}
class _2deepchargeModelms2_bert_ss(BertModel):#input:sequence:N*L(N:batch)##########bertmodel
    ##### The input is cls A sep B sep C ..., so its length is 2L; all SEP positions are taken out for prediction. Because of pretraining, num_col has to be hard-coded here.
    def __init__(self,config):
        super().__init__(config)
        self.config = config

        self.embeddings = Acid_BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.activation=nn.GELU()
        self.num_col=36
        self.sslinear=nn.Linear(config.hidden_size,config.hidden_size)
        self.mslinear=nn.Linear(config.hidden_size,36)
        self.pooler = BertPooler(config)

        self.init_weights()

    def forward(self,input_ids,peptide_tokens,peptide_length,charge,decoration,decoration_ids,pnumber,return_dict=None,head_mask=None):
        #input:input_ids:N*2L(N:batch)
        # length:N*1(N:batch)
        # phos:N*2L(N:batch)
        #charge:N*1
        
        # ninput=self.dropout(ninput)
        key_padding_mask=seq_len_to_mask(peptide_length*2)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device 
        past_key_values_length = 0
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(key_padding_mask, input_shape, device)
        embedding_output = self.embeddings(peptide_tokens=input_ids, 
        position_ids=None,
        decoration=decoration_ids,charge=charge, inputs_embeds=None, past_key_values_length=0

        )
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,

            use_cache=False,

            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]#####N*2L*E
        E=sequence_output.size(-1)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None  #### CLS output; could be used to predict iRT
        # outputmean=output.mean(dim=0)#N*E
        # outputrt=self.activation(self.rtlinear(outputmean))#N*1
        
        output=sequence_output[:,2:-1:2,:]
        assert output.size(1)==int(seq_length/2-1)
        outputss=self.sslinear(output)
        outputms=self.dropout(self.mslinear(outputss))#N*(L-1)*num_col

        outputms=self.activation(outputms)

        outputms=outputms.reshape(batch_size,-1)#N*((L-1)*num_col)

        masks = seq_len_to_mask(seq_len=(peptide_length - 1) * self.num_col)  ## apply the length mask
        outputms=outputms.masked_fill(masks.eq(False), 0)

        # print(outputms)
        # print(torch.sum(outputms))
        return {'pred':outputms,'sequence':peptide_tokens,'charge':charge,"decoration":decoration,"seq_len":peptide_length,
                'pnumber':pnumber}
class _2deepchargeModelms2_bert_ss_contrast(BertModel):#input:sequence:N*L(N:batch)##########bertmodel
    ##### The input is cls A sep B sep C ..., so its length is 2L; all SEP positions are taken out for prediction. Because of pretraining, num_col has to be hard-coded here.
    def __init__(self,config):
        super().__init__(config)
        self.config = config

        self.embeddings = Acid_BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.activation=nn.GELU()
        self.num_col=36
        self.sslinear=nn.Linear(config.hidden_size,config.hidden_size)
        self.mslinear=nn.Linear(config.hidden_size,36)
        self.pooler = BertPooler(config)

        self.init_weights()

    def forward(self,input_ids,peptide_tokens,peptide_length,charge,decoration,decoration_ids,pnumber,return_dict=None,head_mask=None):
        #input:input_ids:N*2L(N:batch)
        # length:N*1(N:batch)
        # phos:N*2L(N:batch)
        #charge:N*1
        
        # ninput=self.dropout(ninput)
        key_padding_mask=seq_len_to_mask(peptide_length*2)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device 
        past_key_values_length = 0
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(key_padding_mask, input_shape, device)
        embedding_output = self.embeddings(peptide_tokens=input_ids, 
        position_ids=None,
        decoration=decoration_ids,charge=charge, inputs_embeds=None, past_key_values_length=0

        )
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,

            use_cache=False,

            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]#####N*2L*E
        E=sequence_output.size(-1)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None  #### CLS output; could be used to predict iRT
        # outputmean=output.mean(dim=0)#N*E
        # outputrt=self.activation(self.rtlinear(outputmean))#N*1
        
        output=sequence_output[:,2:-1:2,:]
        assert output.size(1)==int(seq_length/2-1)
        ############no_grad_no_linear_contrast, simsiam
        encoder_outputs_ss = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,

            use_cache=False,

            return_dict=return_dict,
        )
        sequence_output_ss=encoder_outputs_ss[0]
        output_nograd=sequence_output_ss[:,2:-1:2,:].detach()
        outputss=self.sslinear(output)
        #### contrast outputss against output_nograd; both have size B*(L-1)*E
        contrastmask=seq_len_to_mask(seq_len=(peptide_length - 1) *E )
        outputss_masked=outputss.reshape(batch_size,-1).masked_fill(contrastmask.eq(False), 0)
        output_nograd_masked=output_nograd.reshape(batch_size,-1).masked_fill(contrastmask.eq(False), 0)###B*((L-1)*E)
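        # SimSiam-style objective: negative cosine similarity between the projected view p and the stop-gradient view z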
        p=F.normalize(outputss_masked,p=2,dim=1)
        z=F.normalize(output_nograd_masked,p=2,dim=1)
        sscontrastloss=-(p*z).sum(dim=1).mean()
        
        
        outputms=self.dropout(self.mslinear(outputss))#N*(L-1)*num_col

        outputms=self.activation(outputms)

        outputms=outputms.reshape(batch_size,-1)#N*((L-1)*num_col)

        masks = seq_len_to_mask(seq_len=(peptide_length - 1) * self.num_col)  ## apply the length mask
        outputms=outputms.masked_fill(masks.eq(False), 0)

        # print(outputms)
        # print(torch.sum(outputms))
        return {"sscontrastloss":sscontrastloss,'pred':outputms,'sequence':peptide_tokens,'charge':charge,"decoration":decoration,"seq_len":peptide_length,
                'pnumber':pnumber}
class _2deepchargeModelms2_roberta_ss(RobertaModel):#input:sequence:N*L(N:batch)##########bertmodel
    ##### The input is cls A sep B sep C ..., so its length is 2L; all SEP positions are taken out for prediction. Because of pretraining, num_col has to be hard-coded here.
    def __init__(self,config):
        super().__init__(config)
        self.config = config

        self.embeddings = Acid_BertEmbeddings(config)
        self.encoder = RobertaEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.activation=nn.GELU()
        self.num_col=36
        self.sslinear=nn.Linear(config.hidden_size,config.hidden_size)
        self.mslinear=nn.Linear(config.hidden_size,36)
        self.pooler = BertPooler(config)

        self.init_weights()

    def forward(self,input_ids,peptide_tokens,peptide_length,charge,decoration,decoration_ids,pnumber,return_dict=None,head_mask=None):
        #input:input_ids:N*2L(N:batch)
        # length:N*1(N:batch)
        # phos:N*2L(N:batch)
        #charge:N*1
        
        # ninput=self.dropout(ninput)
        key_padding_mask=seq_len_to_mask(peptide_length*2)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device 
        past_key_values_length = 0
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(key_padding_mask, input_shape, device)
        embedding_output = self.embeddings(peptide_tokens=input_ids, 
        position_ids=None,
        decoration=decoration_ids,charge=charge, inputs_embeds=None, past_key_values_length=0

        )
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,

            use_cache=False,

            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]#####N*2L*E
        E=sequence_output.size(-1)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None  #### CLS output; could be used to predict iRT
        # outputmean=output.mean(dim=0)#N*E
        # outputrt=self.activation(self.rtlinear(outputmean))#N*1
        
        output=sequence_output[:,2:-1:2,:]
        assert output.size(1)==int(seq_length/2-1)
        outputss=self.sslinear(output)
        outputms=self.dropout(self.mslinear(outputss))#N*(L-1)*num_col

        outputms=self.activation(outputms)

        outputms=outputms.reshape(batch_size,-1)#N*((L-1)*num_col)

        masks = seq_len_to_mask(seq_len=(peptide_length - 1) * self.num_col)  ## apply the length mask
        outputms=outputms.masked_fill(masks.eq(False), 0)

        # print(outputms)
        # print(torch.sum(outputms))
        return {'pred':outputms,'sequence':peptide_tokens,'charge':charge,"decoration":decoration,"seq_len":peptide_length,
                'pnumber':pnumber}
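

# ----------------------- usage sketch (illustrative) ------------------------#
# A minimal smoke test, not part of the original training/inference pipeline:
# it only checks that the byBY model can be instantiated from a small BertConfig.
# The hidden sizes below are illustrative, and num_col plus the GNN_* settings
# are assumed to be provided by utils (star import at the top of this file).
if __name__ == "__main__":
    from transformers import BertConfig

    cfg = BertConfig(
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=8,
        intermediate_size=512,
        max_position_embeddings=128,
    )
    model = ModelbyBYms2_bert(cfg)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{model.__class__.__name__}: {n_params} parameters")
    # A full forward pass additionally needs peptide tokens, decoration ids, charges,
    # PlausibleStruct strings and ion targets; see ModelbyBYms2_bert.forward above.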