# CodeBERT-Attack / oj-attack / run_demo.py
# -*- coding: utf-8 -*-

import torch
from codebert import codebert_mlm, codebert_cls, codebert_clone
from codegpt import codegpt
import torch.nn.functional as F

import argparse
import os

if __name__ == "__main__":

    # Command-line options: device selection plus one checkpoint path per model.
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1',
                        help="GPU id(s) exported as CUDA_VISIBLE_DEVICES, "
                             "e.g. '0' or '0,1'; a negative value runs on CPU")
    # microsoft/codebert-base-mlm
    parser.add_argument('--mlm_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/with_rtd/poj/checkpoint-23000-1.0541",
                        help="Path to the masked language model")
    parser.add_argument('--cls_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-classifier/checkpoint-51000-0.986",
                        help="Path to the OJ classifier")
    parser.add_argument('--clone_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-clone/old/rtd/checkpoint-best-f1",
                        help="Path to the clone dectection on bigclonebench")
    parser.add_argument('--gpt_path', type=str,
                        default="microsoft/CodeGPT-small-java-adaptedGPT2",
                        help="Path to the codegpt")

    opt = parser.parse_args()

    # --gpu is exported verbatim as CUDA_VISIBLE_DEVICES, which may hold a
    # comma-separated list such as "0,1".  Only the first id decides between
    # CPU and GPU, so a list no longer makes int() raise ValueError.
    first_gpu_id = opt.gpu.split(",")[0]
    if int(first_gpu_id) < 0:
        device = torch.device("cpu")
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
        device = torch.device("cuda")

    # Models used by the demo functions defined further down.  The two
    # commented-out loaders are optional and can be re-enabled when needed.
    mlm_model = codebert_mlm(opt.mlm_path, device)
    # cls_model = codebert_cls(opt.cls_path, device)
    clone_model = codebert_clone(opt.clone_path, device, block_size=400)
    # codegpt_model = codegpt(opt.gpt_path, device, block_size=512)
    
    # Demo batch: the same four whitespace-tokenised snippets repeated seven
    # times, giving 28 entries in the original order.
    inputs = [
        "int main ( ) { int n , i ; n = 1 ; return 0 }",
        "int main ( ) { int <mask>, i ; <mask> = 1 ; return 0 }",
        "void main ( ) { double x ; }",
        "int main ( ) { int aVeryLongIntegerVar = 0 ; return aVeryLongIntegerVar ; }",
    ] * 7

    def test_codebert():
        """Smoke-test the clone model on a hard-coded pair of snippets.

        Prints both raw snippets, collapses their whitespace, passes the
        resulting list to ``clone_model.run`` and prints the returned tensor
        followed by its softmax.  Nothing is returned; the output is meant to
        be inspected by eye.
        """
        clone_inputs = [
            'int main(){ cout << "Hello World"; return 0;}',
            'void main() { ;}'
        ]

        for snippet in clone_inputs:
            print(snippet)

        # Collapse every run of whitespace to a single space before inference.
        true_inputs = [" ".join(x.split()) for x in clone_inputs]

        pred = clone_model.run(true_inputs)
        print(pred)
        # Pass the axis explicitly: calling softmax without ``dim`` is
        # deprecated and emits a warning.  For 1-D and 2-D tensors the last
        # axis is exactly what the implicit form selected.
        # NOTE(review): assumes class scores lie on the last axis of ``pred``
        # -- confirm against codebert_clone.run.
        prob = F.softmax(pred, dim=-1)
        print(prob)


    def test_codegpt():
        """Run the CodeGPT model on two hard-coded Java methods and print the results.

        Each snippet is whitespace-normalised, passed on its own to the
        model's ``run`` method, and the returned value is printed (the
        original code names it ``ppl``, presumably a perplexity -- not
        verifiable from this file).  Nothing is returned.
        """
        # Adjacent literals are concatenated by the parser, so each snippet is
        # still a single string with the same content as before; the first
        # one used to be split across two physical lines, which is a
        # SyntaxError for a single-quoted literal.
        snippets = [
            (
                '    public String[][] getProjecttreeData() {\n'
                '        String[][] treeData = null;\n'
                '        String filename = dms_home + FS + "temp" + FS + username + "adminprojects.xml";\n'
                '        String urlString = dms_url + "/servlet/com.ufnasoft.dms.server.ServerGetAdminProjects";\n'
                '        try {\n'
                '            String urldata = urlString + "?username=" + URLEncoder.encode(username, "UTF-8") + "&key=" + URLEncoder.encode(key, "UTF-8") + "&filename=" + URLEncoder.encode(username, "UTF-8") + "adminprojects.xml";\n'
                '            DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();\n'
                '            factory.setValidating(false);\n'
                '            DocumentBuilder parser = factory.newDocumentBuilder();\n'
                '            URL u = new URL(urldata);\n'
                '            DataInputStream is = new DataInputStream(u.openStream());\n'
                '            FileOutputStream os = new FileOutputStream(filename);\n'
                '            int iBufSize = is.available();\n'
                '            byte inBuf[] = new byte[20000 * 1024];\n'
                '            int iNumRead;\n'
                '            while ((iNumRead = is.read(inBuf, 0, iBufSize)) > 0) os.write(inBuf, 0, iNumRead);\n'
                '            os.close();\n'
                '            is.close();\n'
                '            File f = new File(filename);\n'
                '            InputStream inputstream = new FileInputStream(f);\n'
                '            Document document = parser.parse(inputstream);\n'
                '            NodeList nodelist = document.getElementsByTagName("proj");\n'
                '            int num = nodelist.getLength();\n'
                '            treeData = new String[num][3];\n'
                '            for (int i = 0; i < num; i++) {\n'
                '                treeData[i][0] = new String(DOMUtil.getSimpleElementText((Element) nodelist.item(i), "pid"));\n'
                '                treeData[i][1] = new String(DOMUtil.getSimpleElementText((Element) nodelist.item(i), "ppid"));\n'
                '                treeData[i][2] = new String(DOMUtil.getSimpleElementText((Element) nodelist.item(i), "p"));\n'
                '            }\n'
                '        } catch (MalformedURLException ex) {\n'
                '            System.out.println(ex);\n'
                '        } catch (ParserConfigurationException ex) {\n'
                '            System.out.println(ex);\n'
                '        } catch (NullPointerException e) {\n'
                '        } catch (Exception ex) {\n'
                '            System.out.println(ex);\n'
                '        }\n'
                '        return treeData;\n'
                '    }\n'
            ),
            (
                '    private static void copy(String sourceName, String destName) throws IOException {\n'
                '        File source = new File(sourceName);\n'
                '        File dest = new File(destName);\n'
                '        FileChannel in = null, out = null;\n'
                '        try {\n'
                '            in = new FileInputStream(source).getChannel();\n'
                '            out = new FileOutputStream(dest).getChannel();\n'
                '            long size = in.size();\n'
                '            MappedByteBuffer buf = in.map(FileChannel.MapMode.READ_ONLY, 0, size);\n'
                '            out.write(buf);\n'
                '        } finally {\n'
                '            if (in != null) in.close();\n'
                '            if (out != null) out.close();\n'
                '        }\n'
                '    }\n'
            ),
        ]

        # ``codegpt_model`` only exists when the script's loader line for it
        # is enabled; otherwise build the model here with the same arguments
        # that line uses, instead of failing with a NameError.
        try:
            model = codegpt_model
        except NameError:
            model = codegpt(opt.gpt_path, device, block_size=512)

        # Collapse all whitespace (including the embedded newlines) first.
        normalised = [" ".join(x.split()) for x in snippets]
        for x in normalised:
            ppl = model.run(x)
            print(ppl)

    # Select which demo to execute: only the CodeBERT clone check runs, the
    # CodeGPT one is left disabled.
    # test_codegpt()
    test_codebert()