# CodeBERT-Attack / codebert / run_demo.py
# -*- coding: utf-8 -*-

import torch
from codebert import codebert_mlm, codebert_cls

import argparse
import os

# The four C snippets exercised by the demo. Two of them contain "<mask>"
# placeholders; the others are plain code.
DEMO_SNIPPETS = (
    "int main ( ) { int n , i ; n = 1 ; return 0 }",
    "int main ( ) { int <mask>, i ; <mask> = 1 ; return 0 }",
    "void main ( ) { double x ; }",
    "int main ( ) { int aVeryLongIntegerVar = 0 ; return aVeryLongIntegerVar ; }",
)

# The demo feeds the snippets several times over, giving 4 * 7 = 28 inputs.
DEMO_REPEATS = 7


def select_device(gpu):
    """Return the torch device selected by the ``--gpu`` command-line value.

    Args:
        gpu: String from the command line. A negative integer (or an empty
            string) selects the CPU. Anything else is exported verbatim
            (after stripping surrounding whitespace) as
            ``CUDA_VISIBLE_DEVICES`` and selects CUDA, so both a single id
            ("0") and a comma-separated list ("0,1") are accepted.

    Returns:
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    """
    gpu = gpu.strip()
    if not gpu:
        return torch.device("cpu")
    try:
        wants_cpu = int(gpu) < 0
    except ValueError:
        # Not a single integer, e.g. "0,1": treat it as a device list rather
        # than crashing on int().
        wants_cpu = False
    if wants_cpu:
        return torch.device("cpu")
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    return torch.device("cuda")


def build_demo_inputs(repeats=DEMO_REPEATS):
    """Return the demo inputs: ``DEMO_SNIPPETS`` repeated ``repeats`` times.

    The order within each repetition is the order of ``DEMO_SNIPPETS``.
    """
    return list(DEMO_SNIPPETS) * repeats


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1',
                        help="Value for CUDA_VISIBLE_DEVICES (e.g. '0' or "
                             "'0,1'); a negative number runs on the CPU")
    parser.add_argument('--mlm_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj/checkpoint-9000-1.0555",
                        help="Path to the masked language model")
    parser.add_argument('--cls_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-classifier/checkpoint-51000-0.986",
                        help="Path to the OJ classifier")

    opt = parser.parse_args()

    device = select_device(opt.gpu)
    mlm_model = codebert_mlm(opt.mlm_path, device)
    cls_model = codebert_cls(opt.cls_path, device)

    inputs = build_demo_inputs()

    # Show how each model tokenizes the same inputs.
    tokens = cls_model.tokenize(inputs)
    print(tokens)
    tokens = mlm_model.tokenize(inputs)
    print(tokens)

    # Run the classifier. `tokens` still holds the MLM tokenization from
    # above; it is printed again to keep the demo output unchanged.
    pred = cls_model.run(inputs)
    print(pred.size())
    print(tokens)

    # Run the masked language model.
    # NOTE(review): the second argument is presumably a batch size (here the
    # whole input list at once) -- confirm against codebert_mlm.run.
    pred = mlm_model.run(inputs, len(inputs))
    print(pred.size())
    print(tokens)