# -*- coding: utf-8 -*-
"""
Created on Sun Nov  8 20:41:34 2020

@author: DrLC
"""

import argparse
import pickle

import tqdm
from gensim.models import Word2Vec
from transformers import RobertaTokenizer

if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-classifier/checkpoint-51000-0.986",
                        help="Path to the downstream OJ classifier checkpoint (only its tokenizer is used)")
    parser.add_argument('--trainset', type=str, default="../data/train.pkl",
                        help="Path to the train set")
    parser.add_argument('--embed_size', type=int, default=300, 
                        help="Width of the embedding vectors")
    parser.add_argument('--window_size', type=int, default=5,
                        help="Size of the sliding window")
    parser.add_argument('--w2v', type=str, default="../data/w2v.model",
                        help="Path where the trained word2vec model is saved")
    
    opt = parser.parse_args()
    
    # Load the same tokenizer as the victim classifier, so the word2vec
    # vocabulary matches the subword units the model operates on.
    tokenizer = RobertaTokenizer.from_pretrained(opt.pretrained)
    
    with open(opt.trainset, "rb") as f:
        data = pickle.load(f)

    # Build the word2vec training corpus: strip each normalized source token,
    # rejoin, and let the tokenizer split the result into subwords.
    sents = []
    for s in tqdm.tqdm(data['norm']):
        _s = [t.strip() for t in s]
        sents.append(tokenizer.tokenize(" ".join(_s)))
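    # Illustrative peek (an added example, assuming a non-empty train set):
    # show how one normalized program is split into subword units.
    if sents:
        print(sents[0][:10])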
        
    # NOTE: this uses the gensim 3.x API; under gensim >= 4.0 the `size`
    # argument is `vector_size` and `wv.vocab` becomes `wv.key_to_index`.
    w2v = Word2Vec(sentences=sents,
                   size=opt.embed_size,
                   window=opt.window_size,
                   min_count=1,   # keep every subword, even singletons
                   workers=4)
    print(w2v.wv.vectors.shape)
    w2v.save(opt.w2v)
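    # A minimal sketch (an added example, not part of the original pipeline):
    # w2v.wv.vectors is a (vocab_size, embed_size) matrix whose rows follow
    # w2v.wv.index2word, so a token-to-row lookup is enough to reuse it as an
    # RNN embedding table (gensim 3.x attributes).
    tok2idx = {tok: i for i, tok in enumerate(w2v.wv.index2word)}
    assert w2v.wv.vectors.shape == (len(tok2idx), opt.embed_size)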
    
    # Sanity check: reload the saved model and confirm the embedding matrix
    # and vocabulary round-trip intact.
    w2v = Word2Vec.load(opt.w2v)
    print(w2v.wv.vectors.shape)
    print(w2v.wv.vocab.keys())
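    # Illustrative query (an added example; "int" is a hypothetical token and
    # may not survive BPE tokenization into the vocabulary, hence the guard):
    if "int" in w2v.wv.vocab:
        print(w2v.wv.most_similar("int", topn=5))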