# Source: CodeBERT-Attack / oj-attack / model.py
# Copyright (c) Microsoft Corporation. 
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch
from torch.autograd import Variable
import copy
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

class RobertaClassificationHead(nn.Module):
    """Head for sentence-pair classification on top of a RoBERTa-style encoder.

    Consecutive rows of the batch are treated as the two members of a pair:
    rows ``0, 1`` form the first pair, rows ``2, 3`` the second, and so on.
    The first-token (``<s>``) vectors of each pair are combined by element-wise
    absolute difference and passed through ``dense -> tanh -> out_proj``.
    """

    def __init__(self, config, num_labels=2):
        """Build the head.

        Args:
            config: Object exposing ``hidden_size`` and ``hidden_dropout_prob``
                (e.g. a transformers ``RobertaConfig``).
            num_labels: Number of output logits per pair. Defaults to 2,
                which matches the previously hard-coded binary classifier.
        """
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Kept so the module's attribute layout is unchanged; it is currently
        # not applied anywhere in forward().
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):
        """Return one row of logits per pair of input rows.

        Args:
            features: Encoder output, indexed as ``features[:, 0, :]``, so
                presumably ``(batch, seq_len, hidden_size)`` — TODO confirm
                with caller. ``batch`` must be even (two rows per pair),
                otherwise the reshape below fails.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            Tensor of shape ``(batch // 2, num_labels)``.
        """
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # Fold each consecutive pair of rows into one row of width 2*hidden,
        # then split that row back into the pair's two members.
        x = x.reshape(-1, x.size(-1) * 2)
        x1, x2 = torch.split(x, x.size(-1) // 2, dim=1)
        # Symmetric pair feature: |a - b| does not depend on member order.
        x = torch.abs(x1 - x2)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x
        
class Model(nn.Module):
    """Pair classifier: a pretrained encoder followed by RobertaClassificationHead.

    ``input_ids`` are flattened to rows of ``block_size`` token ids. The
    classification head treats consecutive rows as the two members of a pair.
    """

    def __init__(self, encoder, config, tokenizer, block_size):
        """Store the encoder and build the classification head.

        Args:
            encoder: Module called as ``encoder(input_ids=..., attention_mask=...)``;
                element ``[0]`` of its output is fed to the classifier.
            config: Encoder config, passed to ``RobertaClassificationHead``.
            tokenizer: Tokenizer that produced the inputs. If it exposes a
                non-None ``pad_token_id``, that id is masked out of attention.
            block_size: Number of token ids per encoded row.
        """
        super().__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = RobertaClassificationHead(config)
        self.block_size = block_size

    def forward(self, input_ids=None, labels=None):
        """Encode ``input_ids`` and classify each consecutive pair of rows.

        Args:
            input_ids: Token-id tensor, reshaped with ``view(-1, block_size)``,
                so its element count must be a multiple of ``block_size``.
            labels: Optional targets for ``CrossEntropyLoss``, presumably one
                class index per pair — TODO confirm against the data loader.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``.

        Raises:
            ValueError: If ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        input_ids = input_ids.view(-1, self.block_size)

        # Mask out padding. Fall back to 1 (the RoBERTa/CodeBERT pad id, which
        # was previously hard-coded) when the tokenizer does not expose one.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 1
        attention_mask = input_ids.ne(pad_id)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
        logits = self.classifier(outputs)
        if labels is None:
            return logits
        loss = CrossEntropyLoss()(logits, labels)
        return loss, logits