# Source: kglids/kg_governor/data_profiling/src/column_embeddings/utils.py
import torch

# Default value for the single constructor argument that load_pretrained_model
# passes to ``model_class``.
EMBEDDING_SIZE = 300


def load_pretrained_model(
    model_class, model_path, embedding_size=EMBEDDING_SIZE, *, device='cpu'
):
    """Build ``model_class(embedding_size)``, load saved weights, and return it in eval mode.

    Args:
        model_class: Callable constructed as ``model_class(embedding_size)``;
            presumably a ``torch.nn.Module`` subclass (the result must support
            ``load_state_dict``, ``to`` and ``eval``).
        model_path: Path or file-like object accepted by ``torch.load``. It is
            expected to hold a state dict matching the constructed model.
        embedding_size: Positional argument forwarded to ``model_class``.
        device: Device that the checkpoint tensors are mapped to and that the
            model is moved to. Defaults to ``'cpu'``, so checkpoints saved on a
            GPU can still be loaded on CPU-only hosts.

    Returns:
        The constructed model with the loaded weights, on ``device``, after
        ``model.eval()`` has been called.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) when the
            checkpoint's keys or shapes do not match the model.
    """
    model = model_class(embedding_size)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources. torch >= 1.13 accepts
    # weights_only=True to restrict this.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # map_location only places the loaded tensors. load_state_dict copies them
    # into the model's existing parameters, so move the module explicitly.
    # This is a no-op for the default CPU device.
    model.to(device)
    model.eval()
    return model