# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, 2)

    def forward(self, features, **kwargs):
        # features: (2 * batch_size, seq_len, hidden_size) -- the two snippets
        # of each pair are encoded as separate sequences.
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # Re-group each pair into one row of shape (batch_size, 2 * hidden_size),
        # split back into the two <s> embeddings, and take their element-wise
        # absolute difference as the pair representation.
        x = x.reshape(-1, x.size(-1) * 2)
        x1, x2 = torch.split(x, x.size(-1) // 2, dim=1)
        x = torch.abs(x1 - x2)
        # x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


class Model(nn.Module):
    def __init__(self, encoder, config, tokenizer, block_size):
        super(Model, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = RobertaClassificationHead(config)
        self.block_size = block_size

    def forward(self, input_ids=None, labels=None):
        # input_ids: (batch_size, 2 * block_size); reshape so that each snippet
        # of the pair is encoded independently.
        input_ids = input_ids.view(-1, self.block_size)
        # Token id 1 is RoBERTa's <pad> token, so ne(1) gives the attention mask.
        outputs = self.encoder(input_ids=input_ids, attention_mask=input_ids.ne(1))[0]
        logits = self.classifier(outputs)
        # prob = F.softmax(logits, dim=-1)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, logits
        else:
            return logits
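

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): shows how Model is typically
# wired to a Hugging Face RoBERTa-style encoder such as CodeBERT. The
# checkpoint name, block_size, and batch shapes below are illustrative
# assumptions, not values taken from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import RobertaConfig, RobertaModel, RobertaTokenizer

    config = RobertaConfig.from_pretrained("microsoft/codebert-base")
    tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
    encoder = RobertaModel.from_pretrained("microsoft/codebert-base")
    model = Model(encoder, config, tokenizer, block_size=400)

    # Each example is a pair of code snippets, each tokenized/padded to
    # block_size tokens and concatenated, so input_ids has shape
    # (batch_size, 2 * block_size) and labels has shape (batch_size,).
    input_ids = torch.randint(0, config.vocab_size, (2, 2 * 400))
    labels = torch.tensor([0, 1])

    loss, logits = model(input_ids=input_ids, labels=labels)
    print(loss.item(), logits.shape)  # scalar loss, logits of shape (2, 2)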