# -*- coding: utf-8 -*-
"""
Created on Sun Nov  8 20:41:34 2020

@author: DrLC
"""

from gensim.models import Word2Vec
from transformers import RobertaTokenizer
import tqdm
import argparse
import json


def _parse_args():
    """Build and parse the command-line options for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained', type=str,
                        default="/var/data/zhanghz/codebert-base-mlm",
                        help="Path to the downstream clone detector")
    parser.add_argument('--data_path', type=str,
                        default="../data/bcb_data.jsonl")
    parser.add_argument('--trainset', type=str,
                        default="../data/bcb_train.txt",
                        help="Path to the train set")
    parser.add_argument('--embed_size', type=int, default=300,
                        help="Width of the embedding vectors")
    parser.add_argument('--window_size', type=int, default=5,
                        help="Size of the sliding window")
    parser.add_argument('--w2v', type=str,
                        default="../data/bcb_w2v.model",
                        help="Path to the pre-trained word2vec model")
    return parser.parse_args()


def _read_sources(path):
    """Read the jsonl corpus and map each record's 'idx' to its 'func' field."""
    sources = {}
    with open(path, "r") as fp:
        for line in fp:
            record = json.loads(line)
            sources[record['idx']] = record['func']
    return sources


def _train_indices(path):
    """Collect the unique sample indices referenced by the train-set pairs.

    Each line of the train set is "<idx1> <idx2> <label>"; indices are
    returned in first-seen order (same order the original list-based
    dedup produced).
    """
    with open(path, "r") as fp:
        pairs = [line.strip().split() for line in fp]
    seen = set()
    ordered = []
    for left, right, _ in tqdm.tqdm(pairs):
        for idx in (left, right):
            if idx not in seen:
                seen.add(idx)
                ordered.append(idx)
    return ordered


def _build_sentences(sources, indices, tokenizer):
    """Tokenize each selected source into a subword-token "sentence".

    NOTE(review): iterating `sources[idx]` yields single characters if
    'func' is a plain string, in which case the space-join below inserts
    a space between every character — confirm 'func' is a pre-tokenized
    list in the jsonl.
    """
    sentences = []
    for idx in tqdm.tqdm(indices):
        text = " ".join(tok.strip() for tok in sources[idx])
        sentences.append(tokenizer.tokenize(text))
    return sentences


if __name__ == "__main__":
    opt = _parse_args()

    tokenizer = RobertaTokenizer.from_pretrained(opt.pretrained)
    src = _read_sources(opt.data_path)
    src_idx = _train_indices(opt.trainset)
    sents = _build_sentences(src, src_idx, tokenizer)

    # NOTE(review): `size=` and `wv.vocab` are gensim 3.x APIs (renamed to
    # `vector_size` / `key_to_index` in gensim 4.0) — this script assumes
    # gensim < 4 is installed.
    w2v = Word2Vec(sentences=sents, size=opt.embed_size,
                   window=opt.window_size, min_count=1, workers=4)
    print(w2v.wv.vectors.shape)

    # Round-trip through disk to verify the saved model is loadable.
    w2v.save(opt.w2v)
    w2v = Word2Vec.load(opt.w2v)
    print(w2v.wv.vectors.shape)
    print(w2v.wv.vocab.keys())