# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import os
import pickle
import random
import re
import shutil
import json

import numpy as np
import torch
from transformers import (
    RobertaConfig,
    RobertaModel,
    RobertaForSequenceClassification,
    RobertaForMaskedLM,
    RobertaTokenizer,
)

from model import Model

logger = logging.getLogger(__name__)


class codebert(object):
    """Thin wrapper around a RoBERTa/CodeBERT checkpoint for batched inference.

    model_type selects the head: "mlm" (masked language model), "cls"
    (sequence classification with n_class labels), or "clone" (the clone
    detection model defined in model.Model).
    """

    def __init__(self, model_type, model_path, device, block_size=512, n_class=104):
        self.model_type = model_type
        # The clone-detection checkpoint stores only model weights, so its
        # tokenizer is loaded from the base "roberta-base" vocabulary.
        if model_type == "clone":
            self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        else:
            self.tokenizer = RobertaTokenizer.from_pretrained(model_path)
        if model_type == "mlm":
            self.model = RobertaForMaskedLM.from_pretrained(model_path).to(device)
        elif model_type == "cls":
            self.model = RobertaForSequenceClassification.from_pretrained(
                model_path, num_labels=n_class
            ).to(device)
        elif model_type == "clone":
            config = RobertaConfig.from_pretrained(model_path)
            encoder = RobertaModel(config)
            self.model = Model(encoder, config, self.tokenizer, block_size=block_size)
            self.model.load_state_dict(
                torch.load(os.path.join(model_path, "model.bin"), map_location="cpu")
            )
            self.model = self.model.to(device)
        self.block_size = block_size
        self.device = device

    def tokenize(self, inputs, cut_and_pad=True, ret_id=True):
        """Tokenize one string or a list of strings.

        With cut_and_pad=True every sequence is truncated to block_size - 2
        tokens, wrapped in <s> ... </s>, and right-padded to block_size so the
        results can be stacked into a single tensor. Returns token ids when
        ret_id=True, otherwise token strings.
        """
        rets = []
        if isinstance(inputs, str):
            inputs = [inputs]
        for sent in inputs:
            if cut_and_pad:
                tokens = self.tokenizer.tokenize(sent)[:self.block_size - 2]
                tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
                padding_length = self.block_size - len(tokens)
                tokens += [self.tokenizer.pad_token] * padding_length
            else:
                tokens = self.tokenizer.tokenize(sent)
                tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
            if not ret_id:
                rets.append(tokens)
            else:
                ids = self.tokenizer.convert_tokens_to_ids(tokens)
                rets.append(ids)
        return rets

    def _run_batch(self, batch):
        self.model.eval()
        # Trim trailing padding columns to the longest sequence in the batch
        # before running the forward pass.
        batch_max_length = batch.ne(self.tokenizer.pad_token_id).sum(-1).max().item()
        inputs = batch[:, :batch_max_length]
        inputs = inputs.to(self.device)
        with torch.no_grad():
            outputs = self.model(inputs, attention_mask=inputs.ne(self.tokenizer.pad_token_id))
            logits = outputs[0]
        return logits

    def run(self, inputs, batch_size=16):
        """Tokenize inputs, run them through the model in mini-batches, and
        return the concatenated logits."""
        input_ids = self.tokenize(inputs, cut_and_pad=True, ret_id=True)
        outputs = None
        batch_num = (len(input_ids) - 1) // batch_size + 1
        for step in range(batch_num):
            batch = torch.tensor(input_ids[step * batch_size: (step + 1) * batch_size])
            if outputs is None:
                outputs = self._run_batch(batch)
            else:
                _outputs = self._run_batch(batch)
                outputs = torch.cat([outputs, _outputs], 0)
        return outputs


class codebert_mlm(codebert):
    def __init__(self, model_path, device):
        super().__init__("mlm", model_path, device)


class codebert_cls(codebert):
    def __init__(self, model_path, device, n_class=104):
        super().__init__("cls", model_path, device, n_class=n_class)


class codebert_clone(codebert):
    def __init__(self, model_path, device, block_size=400):
        super().__init__("clone", model_path, device, block_size)

    def _run_batch(self, batch):
        # The clone-detection Model builds its own attention mask internally,
        # so the full padded batch is passed through unchanged.
        self.model.eval()
        inputs = batch.to(self.device)
        with torch.no_grad():
            logits = self.model(inputs)
        return logits
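

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): one way the wrappers
# above can be driven. The checkpoint id "microsoft/codebert-base-mlm" and the
# example input are assumptions here; substitute your own fine-tuned paths
# (and, for codebert_cls / codebert_clone, the matching checkpoints).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Masked-LM wrapper: run() returns logits of shape
    # (num_inputs, batch_max_length, vocab_size).
    mlm = codebert_mlm("microsoft/codebert-base-mlm", device)
    logits = mlm.run(["def add(a, b): return a + b"], batch_size=2)
    print(logits.shape)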