# -*- coding: utf-8 -*-
"""Smoke test for the CodeBERT masked-LM and OJ-classifier wrappers.

Tokenizes a small batch of C snippets with both models, runs each model on the
batch, and prints the tokenizations and the size of each model's output.
"""
import argparse
import os

import torch

from codebert import codebert_mlm, codebert_cls

# Four short C snippets, including malformed ones (e.g. "int , i ; = 1").
_SAMPLE_SNIPPETS = (
    "int main ( ) { int n , i ; n = 1 ; return 0 }",
    "int main ( ) { int , i ; = 1 ; return 0 }",
    "void main ( ) { double x ; }",
    "int main ( ) { int aVeryLongIntegerVar = 0 ; return aVeryLongIntegerVar ; }",
)
# The batch is the snippets above repeated this many times (28 inputs).
_SAMPLE_REPEATS = 7


def _build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default='-1',
                        help="Comma-separated GPU id(s) exported as "
                             "CUDA_VISIBLE_DEVICES (e.g. '0' or '0,1'); "
                             "a negative id selects the CPU")
    parser.add_argument('--mlm_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj/checkpoint-9000-1.0555",
                        help="Path to the masked language model")
    parser.add_argument('--cls_path', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-classifier/checkpoint-51000-0.986",
                        help="Path to the OJ classifier")
    return parser


def _select_device(gpu):
    """Return the torch device requested by the ``--gpu`` option.

    *gpu* is a comma-separated list of integer GPU ids. If any id is negative
    the CPU is used. Otherwise the ids are exported via CUDA_VISIBLE_DEVICES
    and the ``cuda`` device is returned.

    Raises:
        ValueError: if *gpu* is empty or contains a non-integer entry.
    """
    try:
        ids = [int(part) for part in gpu.split(",")]
    except ValueError:
        raise ValueError(
            "invalid --gpu value {!r}: expected comma-separated integer ids".format(gpu)
        ) from None
    if any(i < 0 for i in ids):
        return torch.device("cpu")
    # Must be set before CUDA is first initialised; nothing above touches CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in ids)
    return torch.device("cuda")


def main(argv=None):
    """Load both models, tokenize and run them on the sample batch, print results.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = _build_parser()
    opt = parser.parse_args(argv)
    try:
        device = _select_device(opt.gpu)
    except ValueError as err:
        parser.error(str(err))  # exits with a usage message

    mlm_model = codebert_mlm(opt.mlm_path, device)
    cls_model = codebert_cls(opt.cls_path, device)

    inputs = list(_SAMPLE_SNIPPETS) * _SAMPLE_REPEATS

    tokens = cls_model.tokenize(inputs)
    print(tokens)
    tokens = mlm_model.tokenize(inputs)
    print(tokens)

    pred = cls_model.run(inputs)
    print(pred.size())
    # NOTE(review): re-prints the MLM tokenization from above; kept as-is.
    print(tokens)
    # Second argument is presumably a batch size (whole batch at once) — TODO confirm.
    pred = mlm_model.run(inputs, len(inputs))
    print(pred.size())
    print(tokens)


if __name__ == "__main__":
    main()