from collections import Counter

try:
    # The ABC aliases in ``collections`` were removed in Python 3.10; the
    # canonical home since Python 3.3 is ``collections.abc``.
    from collections.abc import Mapping
except ImportError:  # Python 2 fallback (this module still supports py2 via ``six``).
    from collections import Mapping
from concurrent.futures import ProcessPoolExecutor
import logging
from multiprocessing import cpu_count

from six import string_types

from gensim.models import Word2Vec

# NOTE(review): ``Vocab`` and the ``size`` keyword used below are gensim 3.x
# APIs; gensim 4 removed ``Vocab`` and renamed ``size`` to ``vector_size``.
# Confirm the pinned gensim version before upgrading.
from gensim.models.word2vec import Vocab

logger = logging.getLogger("deepwalk")


class Skipgram(Word2Vec):
    """A subclass to allow more customization of the Word2Vec internals."""

    def __init__(self, vocabulary_counts=None, **kwargs):
        """Build a ``Word2Vec`` model forced into skip-gram / hierarchical-softmax mode.

        Args:
            vocabulary_counts: Optional value stored on the instance as
                ``self.vocabulary_counts``. It is NOT forwarded to
                ``Word2Vec``. Presumably token -> count data consumed
                elsewhere — TODO confirm against callers.
            **kwargs: Forwarded to ``Word2Vec.__init__``. When absent, these
                defaults are filled in: ``min_count=0``,
                ``workers=cpu_count()``, ``size=128``, ``sentences=None``,
                ``window=10``. ``sg`` and ``hs`` are always set to 1
                (skip-gram with hierarchical softmax), overriding any
                caller-supplied value.
        """
        # Assigned before the base constructor runs, so the attribute already
        # exists during ``Word2Vec.__init__``.
        # ``is not None`` (not ``!= None``) avoids element-wise comparison if
        # an array-like is passed.
        self.vocabulary_counts = (
            vocabulary_counts if vocabulary_counts is not None else None
        )

        # Caller-overridable defaults.
        kwargs.setdefault("min_count", 0)
        kwargs.setdefault("workers", cpu_count())
        kwargs.setdefault("size", 128)
        kwargs.setdefault("sentences", None)
        kwargs.setdefault("window", 10)

        # Fixed model choices for this subclass: skip-gram + hierarchical softmax.
        kwargs["sg"] = 1
        kwargs["hs"] = 1

        super(Skipgram, self).__init__(**kwargs)