# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import os
import pickle
import random
import re
import shutil
import json

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

logger = logging.getLogger(__name__)


class codegpt(object):
    """Wraps a GPT-2-style checkpoint and scores inputs by per-sequence perplexity."""

    def __init__(self, model_path, device, block_size=512):
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
        self.block_size = block_size
        self.device = device

    def tokenize(self, inputs, cut_and_pad=False, ret_id=False):
        """Tokenize one string or a list of strings, adding bos/eos and optional padding."""
        rets = []
        if isinstance(inputs, str):
            inputs = [inputs]
        for sent in inputs:
            if cut_and_pad:
                # Reserve two positions for bos/eos, then pad up to block_size.
                tokens = self.tokenizer.tokenize(sent)[:self.block_size - 2]
                tokens = [self.tokenizer.bos_token] + tokens + [self.tokenizer.eos_token]
                padding_length = self.block_size - len(tokens)
                tokens += [self.tokenizer.pad_token] * padding_length
            else:
                tokens = self.tokenizer.tokenize(sent)
                tokens = [self.tokenizer.bos_token] + tokens + [self.tokenizer.eos_token]
            if not ret_id:
                rets.append(tokens)
            else:
                ids = self.tokenizer.convert_tokens_to_ids(tokens)
                rets.append(ids)
        return rets

    def _run_batch(self, batch):
        """Return the perplexity of each sequence in a padded batch of token ids."""
        self.model.eval()
        # Drop trailing padding columns that are unused by every sequence in the batch.
        batch_max_length = batch.ne(self.tokenizer.pad_token_id).sum(-1).max().item()
        inputs = batch[:, :batch_max_length]
        batch_size = inputs.size(0)
        inputs = inputs.to(self.device)
        with torch.no_grad():
            outputs = self.model(inputs, attention_mask=inputs.ne(self.tokenizer.pad_token_id))
            logits = outputs[0]
            # Shift so that tokens < n predict token n (standard causal LM loss).
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction="none")
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1)).view(batch_size, -1)
        # Per-sequence perplexity: exp of the mean token-level loss.
        return torch.exp(torch.mean(loss, dim=-1))

    def run(self, inputs, batch_size=16):
        """Score `inputs` in mini-batches and return a 1-D tensor of perplexities."""
        input_ids = self.tokenize(inputs, cut_and_pad=True, ret_id=True)
        outputs = None
        batch_num = (len(input_ids) - 1) // batch_size + 1
        for step in range(batch_num):
            batch = torch.tensor(input_ids[step * batch_size: (step + 1) * batch_size])
            if outputs is None:
                outputs = self._run_batch(batch)
            else:
                _outputs = self._run_batch(batch)
                outputs = torch.cat([outputs, _outputs], 0)
        return outputs
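

# Usage sketch (not part of the original module): scores a few snippets with the class
# above. The checkpoint name below is only an example of a GPT-2-family model path;
# substitute whatever CodeGPT checkpoint the surrounding project actually uses. The
# tokenizer is assumed to define bos/eos/pad special tokens, as the CodeGPT checkpoints do.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scorer = codegpt("microsoft/CodeGPT-small-java", device)  # example checkpoint path
    snippets = [
        "public int add(int a, int b) { return a + b; }",
        "public int add(int a, int b) { return a - b; }",
    ]
    perplexities = scorer.run(snippets, batch_size=2)
    for code, ppl in zip(snippets, perplexities.tolist()):
        print("%.2f\t%s" % (ppl, code))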