# -*- coding: utf-8 -*-
"""
Created on Sun Nov 8 20:41:34 2020

@author: DrLC
"""

from gensim.models import Word2Vec
from transformers import RobertaTokenizer
import tqdm
import argparse
import pickle

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained', type=str,
                        default="/var/data/lushuai/bertvsbert/save/poj-classifier/checkpoint-51000-0.986",
                        help="Path to the downstream OJ classifier")
    parser.add_argument('--trainset', type=str, default="../data/train.pkl",
                        help="Path to the train set")
    parser.add_argument('--embed_size', type=int, default=300,
                        help="Width of the embedding vectors")
    parser.add_argument('--window_size', type=int, default=5,
                        help="Size of the sliding window")
    parser.add_argument('--w2v', type=str, default="../data/w2v.model",
                        help="Path where the trained word2vec model is saved")
    opt = parser.parse_args()

    # Tokenize the normalized training programs with the classifier's RoBERTa
    # tokenizer, so the word2vec vocabulary matches the subword tokens that
    # the downstream model sees.
    tokenizer = RobertaTokenizer.from_pretrained(opt.pretrained)
    with open(opt.trainset, "rb") as f:
        data = pickle.load(f)
    sents = []
    for s in tqdm.tqdm(data['norm']):
        _s = [t.strip() for t in s]
        sents.append(tokenizer.tokenize(" ".join(_s)))

    # Train word2vec on the tokenized corpus.  Note: this uses the
    # gensim < 4.0 API; with gensim >= 4.0, pass vector_size= instead of
    # size=, and read the vocabulary via w2v.wv.key_to_index instead of
    # w2v.wv.vocab.
    w2v = Word2Vec(sentences=sents, size=opt.embed_size,
                   window=opt.window_size, min_count=1, workers=4)
    print(w2v.wv.vectors.shape)
    w2v.save(opt.w2v)

    # Reload the saved model as a sanity check that the embedding matrix
    # and vocabulary round-trip correctly.
    w2v = Word2Vec.load(opt.w2v)
    print(w2v.wv.vectors.shape)
    print(w2v.wv.vocab.keys())