# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function

import logging
import os

import torch
from transformers import (
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
    RobertaModel,
    RobertaTokenizer,
)
from model import Model

logger = logging.getLogger(__name__)

class codebert(object):
    """Wrapper around CodeBERT checkpoints for masked LM, classification, and clone detection."""

    def __init__(self, model_type, model_path, device, block_size=512, n_class=104):
        self.model_type = model_type
        # The clone-detection variant uses the base RoBERTa tokenizer; the other
        # variants load their tokenizer from model_path.
        if model_type == "clone":
            self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        else:
            self.tokenizer = RobertaTokenizer.from_pretrained(model_path)
        if model_type == "mlm":
            self.model = RobertaForMaskedLM.from_pretrained(model_path).to(device)
        elif model_type == "cls":
            self.model = RobertaForSequenceClassification.from_pretrained(model_path, num_labels=n_class).to(device)
        elif model_type == "clone":
            # The clone model wraps a RoBERTa encoder (see model.Model) and restores
            # its weights from a raw state_dict saved as model.bin in model_path.
            config = RobertaConfig.from_pretrained(model_path)
            encoder = RobertaModel(config)
            self.model = Model(encoder, config, self.tokenizer, block_size=block_size)
            self.model.load_state_dict(torch.load(os.path.join(model_path, "model.bin"), map_location="cpu"))
            self.model = self.model.to(device)
        else:
            raise ValueError("unknown model_type: {}".format(model_type))
        self.block_size = block_size
        self.device = device

    def tokenize(self, inputs, cut_and_pad=True, ret_id=True):
        """Tokenize one string or a list of strings.

        With cut_and_pad=True, each sequence is truncated to block_size - 2 tokens,
        wrapped in <s> ... </s>, and right-padded to block_size. With ret_id=True,
        token ids are returned instead of token strings.
        """
        rets = []
        if isinstance(inputs, str):
            inputs = [inputs]
        for sent in inputs:
            if cut_and_pad:
                tokens = self.tokenizer.tokenize(sent)[:self.block_size - 2]
                tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
                padding_length = self.block_size - len(tokens)
                tokens += [self.tokenizer.pad_token] * padding_length
            else:
                tokens = self.tokenizer.tokenize(sent)
                tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
            if ret_id:
                rets.append(self.tokenizer.convert_tokens_to_ids(tokens))
            else:
                rets.append(tokens)
        return rets

    def _run_batch(self, batch):
        self.model.eval()
        # Trim trailing padding so the batch only runs up to its longest sequence.
        batch_max_length = batch.ne(self.tokenizer.pad_token_id).sum(-1).max().item()
        inputs = batch[:, :batch_max_length].to(self.device)
        with torch.no_grad():
            outputs = self.model(inputs, attention_mask=inputs.ne(self.tokenizer.pad_token_id))
            logits = outputs[0]
        return logits
    
    def run(self, inputs, batch_size=16):
        """Tokenize the inputs and return the model's logits, computed in mini-batches."""
        input_ids = self.tokenize(inputs, cut_and_pad=True, ret_id=True)
        outputs = None
        batch_num = (len(input_ids) - 1) // batch_size + 1
        for step in range(batch_num):
            # All sequences were padded to block_size, so they stack into a rectangular tensor.
            batch = torch.tensor(input_ids[step * batch_size: (step + 1) * batch_size])
            logits = self._run_batch(batch)
            outputs = logits if outputs is None else torch.cat([outputs, logits], 0)
        return outputs


class codebert_mlm(codebert):
    def __init__(self, model_path, device):
        super().__init__("mlm", model_path, device)

class codebert_cls(codebert):
    def __init__(self, model_path, device, n_class=104):
        super().__init__("cls", model_path, device, n_class=n_class)

class codebert_clone(codebert):
    def __init__(self, model_path, device, block_size=400):
        super().__init__("clone", model_path, device, block_size)
    
    def _run_batch(self, batch):
        self.model.eval()
        # Unlike the base class, the clone-detection Model consumes the full padded
        # batch without an explicit attention mask and returns logits directly.
        inputs = batch.to(self.device)
        with torch.no_grad():
            logits = self.model(inputs)
        return logits
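

# Minimal usage sketch (illustrative, not part of the original attack code). The
# masked-LM checkpoint below is the public microsoft/codebert-base-mlm model;
# "path/to/oj-classifier" is a hypothetical placeholder for a classifier
# fine-tuned on the 104-class OJ dataset.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Masked-LM wrapper: logits over the vocabulary for every token position.
    mlm = codebert_mlm("microsoft/codebert-base-mlm", device)
    mlm_logits = mlm.run(["int main ( ) { return 0 ; }"])
    print(mlm_logits.shape)  # (batch, sequence_length, vocab_size)

    # Classification wrapper: one logit vector per input program.
    cls_model = codebert_cls("path/to/oj-classifier", device, n_class=104)
    cls_logits = cls_model.run(["int main ( ) { return 0 ; }"])
    print(cls_logits.argmax(-1))  # predicted class ids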