# -*- coding: utf-8 -*-
"""Smoke test for the CodeBERT / CodeGPT model wrappers.

When run as a script it loads the CodeBERT masked-LM and clone-detection
models from the checkpoint paths given on the command line, then prints the
clone-detection logits and softmax probabilities for one hard-coded pair of
C snippets. A CodeGPT check is available but disabled (see ``main``).
"""
import argparse
import os

import torch
import torch.nn.functional as F

# Keep the ``test_*`` helpers out of ``from ... import *`` so test runners do
# not mistake them for test cases.
__all__ = [
    "JAVA_SAMPLES",
    "SAMPLE_INPUTS",
    "build_parser",
    "main",
    "normalize_whitespace",
    "select_device",
]

# Token-separated C snippets (four programs repeated seven times, i.e. a
# 28-item batch). Not used by the active check; kept for manual experiments
# such as ``mlm_model.tokenize(SAMPLE_INPUTS)`` or ``cls_model.run(SAMPLE_INPUTS)``.
SAMPLE_INPUTS = [
    "int main ( ) { int n , i ; n = 1 ; return 0 }",
    "int main ( ) { int , i ; = 1 ; return 0 }",
    "void main ( ) { double x ; }",
    "int main ( ) { int aVeryLongIntegerVar = 0 ; return aVeryLongIntegerVar ; }",
] * 7

# Java methods scored by the (disabled) CodeGPT check. Layout is irrelevant:
# test_codegpt collapses all whitespace with normalize_whitespace before use.
JAVA_SAMPLES = [
    '''
public String[][] getProjecttreeData() {
    String[][] treeData = null;
    String filename = dms_home + FS + "temp" + FS + username + "adminprojects.xml";
    String urlString = dms_url + "/servlet/com.ufnasoft.dms.server.ServerGetAdminProjects";
    try {
        String urldata = urlString + "?username=" + URLEncoder.encode(username, "UTF-8") + "&key=" + URLEncoder.encode(key, "UTF-8") + "&filename=" + URLEncoder.encode(username, "UTF-8") + "adminprojects.xml";
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        factory.setValidating(false);
        DocumentBuilder parser = factory.newDocumentBuilder();
        URL u = new URL(urldata);
        DataInputStream is = new DataInputStream(u.openStream());
        FileOutputStream os = new FileOutputStream(filename);
        int iBufSize = is.available();
        byte inBuf[] = new byte[20000 * 1024];
        int iNumRead;
        while ((iNumRead = is.read(inBuf, 0, iBufSize)) > 0) os.write(inBuf, 0, iNumRead);
        os.close();
        is.close();
        File f = new File(filename);
        InputStream inputstream = new FileInputStream(f);
        Document document = parser.parse(inputstream);
        NodeList nodelist = document.getElementsByTagName("proj");
        int num = nodelist.getLength();
        treeData = new String[num][3];
        for (int i = 0; i < num; i++) {
            treeData[i][0] = new String(DOMUtil.getSimpleElementText((Element) nodelist.item(i), "pid"));
            treeData[i][1] = new String(DOMUtil.getSimpleElementText((Element) nodelist.item(i), "ppid"));
            treeData[i][2] = new String(DOMUtil.getSimpleElementText((Element) nodelist.item(i), "p"));
        }
    } catch (MalformedURLException ex) {
        System.out.println(ex);
    } catch (ParserConfigurationException ex) {
        System.out.println(ex);
    } catch (NullPointerException e) {
    } catch (Exception ex) {
        System.out.println(ex);
    }
    return treeData;
}
''',
    '''
private static void copy(String sourceName, String destName) throws IOException {
    File source = new File(sourceName);
    File dest = new File(destName);
    FileChannel in = null, out = null;
    try {
        in = new FileInputStream(source).getChannel();
        out = new FileOutputStream(dest).getChannel();
        long size = in.size();
        MappedByteBuffer buf = in.map(FileChannel.MapMode.READ_ONLY, 0, size);
        out.write(buf);
    } finally {
        if (in != null) in.close();
        if (out != null) out.close();
    }
}
''',
]


def normalize_whitespace(code):
    """Return *code* with every whitespace run collapsed to a single space.

    Leading and trailing whitespace is dropped, so ``"  a\\n  b "`` becomes
    ``"a b"``.
    """
    return " ".join(code.split())


def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1',
                        help="CUDA device id(s) to expose, e.g. '0' or '0,1'; "
                             "a single negative value runs on the CPU")
    # microsoft/codebert-base-mlm
    parser.add_argument('--mlm_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/with_rtd/poj/checkpoint-23000-1.0541",
                        help="Path to the masked language model")
    parser.add_argument('--cls_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-classifier/checkpoint-51000-0.986",
                        help="Path to the OJ classifier")
    # NOTE(review): the help mentions BigCloneBench but the default path is a
    # "poj-clone" checkpoint; confirm which dataset this model was trained on.
    parser.add_argument('--clone_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-clone/old/rtd/checkpoint-best-f1",
                        help="Path to the clone detection on bigclonebench")
    parser.add_argument('--gpt_path', type=str,
                        default="microsoft/CodeGPT-small-java-adaptedGPT2",
                        help="Path to the codegpt")
    return parser


def select_device(gpu):
    """Return the torch device selected by the ``--gpu`` string.

    A single negative integer (the default ``"-1"``) selects the CPU. Otherwise
    *gpu* must be an integer or a comma-separated list of integers; it is
    exported as ``CUDA_VISIBLE_DEVICES`` and the default CUDA device is returned.

    Raises:
        ValueError: if any comma-separated part of *gpu* is not an integer.
    """
    device_ids = [int(part) for part in gpu.split(",")]
    if len(device_ids) == 1 and device_ids[0] < 0:
        return torch.device("cpu")
    # CUDA reads this variable when it initialises, so it must be set before
    # any model is moved to the GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    return torch.device("cuda")


def test_codebert(clone_model):
    """Print clone-detection logits and softmax probabilities for one pair.

    *clone_model* is the ``codebert_clone`` wrapper built in ``main``; its
    ``run`` method is called with the whitespace-normalised pair of snippets.
    """
    clone_inputs = [
        'int main(){ cout << "Hello World"; return 0;}',
        'void main() { ;}'
    ]
    print(clone_inputs[0])
    print(clone_inputs[1])
    true_inputs = [normalize_whitespace(x) for x in clone_inputs]
    pred = clone_model.run(true_inputs)
    print(pred)
    # Softmax without an explicit dim is deprecated. dim=-1 matches the old
    # implicit choice for 1-D or 2-D logits; assumes pred is (batch, classes)
    # or 1-D - TODO confirm against codebert_clone.run.
    prob = F.softmax(pred, dim=-1)
    print(prob)


def test_codegpt(codegpt_model):
    """Print ``codegpt_model.run`` for each whitespace-normalised Java sample.

    The model is passed in explicitly because ``main`` only builds it when the
    CodeGPT lines there are enabled. The printed value is named ``ppl`` and is
    presumably a perplexity - confirm against ``codegpt.run``.
    """
    for sample in JAVA_SAMPLES:
        ppl = codegpt_model.run(normalize_whitespace(sample))
        print(ppl)


def main(argv=None):
    """Parse *argv* (default ``sys.argv[1:]``), load the models, run the checks."""
    # Project-local wrappers are imported here rather than at module level so
    # the helpers above can be imported without them.
    from codebert import codebert_mlm, codebert_cls, codebert_clone  # noqa: F401
    from codegpt import codegpt  # noqa: F401  (needed when CodeGPT is enabled)

    opt = build_parser().parse_args(argv)
    device = select_device(opt.gpu)

    # Loaded but not used by the active check; handy for manual experiments
    # on SAMPLE_INPUTS.
    mlm_model = codebert_mlm(opt.mlm_path, device)  # noqa: F841
    # cls_model = codebert_cls(opt.cls_path, device)
    clone_model = codebert_clone(opt.clone_path, device, block_size=400)
    # To enable the CodeGPT check, uncomment both lines below.
    # codegpt_model = codegpt(opt.gpt_path, device, block_size=512)
    # test_codegpt(codegpt_model)

    test_codebert(clone_model)


if __name__ == "__main__":
    main()