# -*- coding: utf-8 -*-
"""
Created on Sun Nov  8 20:41:34 2020

@author: DrLC
"""

from gensim.models import Word2Vec
from transformers import RobertaTokenizer
import tqdm
import argparse
import json

if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained', type=str,
                        default="/var/data/zhanghz/codebert-base-mlm",
                        help="Path to the pre-trained CodeBERT model (only its tokenizer is used here)")
    parser.add_argument('--data_path', type=str, default="../data/bcb_data.jsonl",
                        help="Path to the jsonl data file with 'idx' and 'func' fields")
    parser.add_argument('--trainset', type=str, default="../data/bcb_train.txt",
                        help="Path to the train set")
    parser.add_argument('--embed_size', type=int, default=300,
                        help="Dimensionality of the embedding vectors")
    parser.add_argument('--window_size', type=int, default=5,
                        help="Size of the sliding window")
    parser.add_argument('--w2v', type=str, default="../data/bcb_w2v.model",
                        help="Path where the trained word2vec model is saved")
    
    opt = parser.parse_args()
    
    # Use the CodeBERT (RoBERTa BPE) tokenizer so the word2vec vocabulary
    # matches the subword units seen by the transformer models.
    tokenizer = RobertaTokenizer.from_pretrained(opt.pretrained)
    
    # Map each function index to its source.
    with open(opt.data_path, "r") as f:
        src = {}
        for l in f:
            r = json.loads(l)
            src[r['idx']] = r['func']
    # Collect every function index referenced by the train set
    # (one "idx1 idx2 label" triple per line), keeping first-seen order
    # and using a set for O(1) membership tests.
    with open(opt.trainset, "r") as f:
        d = [l.strip().split() for l in f]
        src_idx = []
        seen = set()
        for i1, i2, _ in tqdm.tqdm(d):
            for i in (i1, i2):
                if i not in seen:
                    seen.add(i)
                    src_idx.append(i)

    sents = []
    for i in tqdm.tqdm(src_idx):
        # 'func' may be stored either as a raw source string or as a token
        # list, depending on preprocessing; normalize it to a single
        # whitespace-separated string before running the BPE tokenizer.
        func = src[i]
        if not isinstance(func, str):
            func = " ".join(func)
        sents.append(tokenizer.tokenize(" ".join(func.split())))
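    # For reference, RobertaTokenizer.tokenize splits on BPE subwords, e.g.
    # (assuming the standard CodeBERT/RoBERTa vocabulary):
    #   tokenizer.tokenize("public static") -> ['public', 'Ġstatic']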
        
    # Train word2vec on the tokenized functions. Note: gensim >= 4.0
    # renamed the `size` argument to `vector_size`.
    w2v = Word2Vec(sentences=sents,
                   vector_size=opt.embed_size,
                   window=opt.window_size,
                   min_count=1,
                   workers=4)
    print(w2v.wv.vectors.shape)
    w2v.save(opt.w2v)
    
    # Sanity check: reload the saved model and inspect its shape and
    # vocabulary (gensim >= 4.0 exposes it as `wv.key_to_index`, replacing
    # the removed `wv.vocab`).
    w2v = Word2Vec.load(opt.w2v)
    print(w2v.wv.vectors.shape)
    print(list(w2v.wv.key_to_index))
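
    # A minimal sketch of downstream use (not part of the original script):
    # the RNN models can embed a token by direct lookup, e.g.
    #
    #   tok = tokenizer.tokenize("public static")[0]
    #   if tok in w2v.wv.key_to_index:
    #       vec = w2v.wv[tok]   # numpy array of shape (opt.embed_size,)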