# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function

import logging

import torch
from torch.nn import CrossEntropyLoss
from transformers import GPT2LMHeadModel, GPT2Tokenizer

logger = logging.getLogger(__name__)

class codegpt(object):
    """Wrapper around a GPT-2-style causal language model (e.g. CodeGPT) that
    scores code snippets by per-sequence perplexity."""

    def __init__(self, model_path, device, block_size=512):
        # The checkpoint's tokenizer is expected to define bos/eos/pad tokens
        # (as CodeGPT checkpoints do); a vanilla GPT-2 tokenizer has no pad token.
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
        self.block_size = block_size
        self.device = device

    def tokenize(self, inputs, cut_and_pad=False, ret_id=False):
        """Tokenize a string or a list of strings.

        With cut_and_pad=True, each sequence is truncated to block_size - 2
        tokens, wrapped in bos/eos, and right-padded to block_size so that the
        results can be stacked into a single batch tensor. With ret_id=True,
        token ids are returned instead of token strings.
        """
        rets = []
        if isinstance(inputs, str):
            inputs = [inputs]
        for sent in inputs:
            if cut_and_pad:
                # Reserve two positions for the bos/eos tokens.
                tokens = self.tokenizer.tokenize(sent)[:self.block_size - 2]
                tokens = [self.tokenizer.bos_token] + tokens + [self.tokenizer.eos_token]
                padding_length = self.block_size - len(tokens)
                tokens += [self.tokenizer.pad_token] * padding_length
            else:
                tokens = self.tokenizer.tokenize(sent)
                tokens = [self.tokenizer.bos_token] + tokens + [self.tokenizer.eos_token]
            if ret_id:
                rets.append(self.tokenizer.convert_tokens_to_ids(tokens))
            else:
                rets.append(tokens)
        return rets

    def _run_batch(self, batch):
        """Return the perplexity of every sequence in a padded batch of ids."""
        self.model.eval()
        # Trim the batch to the longest non-padded sequence it contains.
        batch_max_length = batch.ne(self.tokenizer.pad_token_id).sum(-1).max().item()
        inputs = batch[:, :batch_max_length].to(self.device)
        batch_size = inputs.size(0)
        with torch.no_grad():
            attention_mask = inputs.ne(self.tokenizer.pad_token_id)
            outputs = self.model(inputs, attention_mask=attention_mask)
            logits = outputs[0]
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction="none")
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1)).view(batch_size, -1)
            # Average only over real (non-pad) target positions so that shorter
            # sequences are not penalized for their padding.
            label_mask = shift_labels.ne(self.tokenizer.pad_token_id).float()
            mean_loss = (loss * label_mask).sum(-1) / label_mask.sum(-1)
        return torch.exp(mean_loss)
    
    def run(self, inputs, batch_size=16):
        """Compute perplexities for a string or a list of strings, in mini-batches."""
        input_ids = self.tokenize(inputs, cut_and_pad=True, ret_id=True)
        outputs = None
        batch_num = (len(input_ids) - 1) // batch_size + 1
        for step in range(batch_num):
            batch = torch.tensor(input_ids[step * batch_size: (step + 1) * batch_size])
            _outputs = self._run_batch(batch)
            outputs = _outputs if outputs is None else torch.cat([outputs, _outputs], 0)
        return outputs
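

# Minimal usage sketch (not part of the original script): it shows how the
# class can be driven end to end. The checkpoint path is a hypothetical
# placeholder -- any GPT-2-style causal LM whose tokenizer defines
# bos/eos/pad tokens (e.g. a CodeGPT checkpoint) should work here.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lm = codegpt("path/to/codegpt-checkpoint", device, block_size=512)
    snippets = [
        "int main() { int a = 0; return a; }",
        "for (int i = 0; i < n; ++i) sum += a[i];",
    ]
    perplexities = lm.run(snippets, batch_size=2)
    for code, ppl in zip(snippets, perplexities.tolist()):
        logger.info("ppl=%.2f  %s", ppl, code)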