# honeyplotnet/utils/ksm_scores.py
import torch
import yake


SIM_PAIRS = [
    ('all_text_kw', 'captions_kw'),
    ('all_text_kw', 'axis_words'),
    ('all_text_kw', 'categorical_words'),
    ('all_text_kw', 'series_words'),
    ('all_text_kw', 'categ_and_series_words'),
    ('captions_kw', 'axis_words'),
    ('captions_kw', 'categorical_words'),
    ('captions_kw', 'series_words'),
    ('captions_kw', 'categ_and_series_words'),
    ('series_words', 'categorical_words')
]

def yake_text(text_str, max_ngram_size=3, dedup_thres=0.9, num_kw=20):
    '''Extracts up to num_kw keywords from text_str with YAKE, returned as (keyword, score) tuples'''
    kw_extractor = yake.KeywordExtractor(
        n=max_ngram_size, dedupLim=dedup_thres, top=num_kw)
    return kw_extractor.extract_keywords(text_str)
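
# Example usage of yake_text (a sketch; the input sentence is illustrative,
# not from the repo):
#   yake_text("Average monthly rainfall in Seattle by season", num_kw=5)
#   # -> [("Average monthly rainfall", <score>), ...] as (keyword, score)
#   #    tuples, where lower YAKE scores indicate more relevant keywords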
    
def tokenize(tokenizer, text, max_source_len=1024):
    '''Tokenizes text (a string or list of strings) into fixed-length PyTorch tensors'''
    inputs = tokenizer(
        text, max_length=max_source_len,
        padding="max_length", truncation=True, return_tensors="pt")
    return inputs

def embed_and_encode(tokens, model, device):
    '''Runs token ids through the model's input embedding layer and encoder'''
    with torch.no_grad():
        embeddings = model.get_input_embeddings()(tokens.input_ids.to(device))
        encoder = model.get_encoder()
        output = encoder(
            inputs_embeds=embeddings,
            attention_mask=tokens.attention_mask.to(device)).last_hidden_state
    return output
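
# Example: composing tokenize + embed_and_encode (a sketch; assumes an
# encoder-decoder model such as T5/BART from transformers, since
# get_encoder() requires one):
#   tok = tokenize(tokenizer, ["rainfall", "temperature"], max_source_len=128)
#   hidden = embed_and_encode(tok, model, device)
#   # hidden has shape (2, 128, hidden_size): one row per input string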

def tokenize_keywords(d, tokenizer, model, max_source_len, device):
    '''Encodes each non-empty keyword list in d into mean-pooled embeddings'''
    extract_list = ['captions_kw', 'all_text_kw', 'categorical_words', 'series_words', 'axis_words']
    embeddings = {}
    for name in extract_list:
        if len(d[name]):
            tok = tokenize(tokenizer, d[name], max_source_len=max_source_len)
            emb = embed_and_encode(tok, model, device)
            # Mean-pool over the sequence dimension as per https://arxiv.org/pdf/2108.08877.pdf
            emb = emb.mean(1)
            embeddings[name] = emb

    # Also encode the combined categorical + series keywords if both exist
    if len(d['categorical_words']) and len(d['series_words']):
        combined = d['categorical_words'] + d['series_words']
        tok = tokenize(tokenizer, combined, max_source_len=max_source_len)
        emb = embed_and_encode(tok, model, device)

        # Mean-pool over the sequence dimension as per https://arxiv.org/pdf/2108.08877.pdf
        emb = emb.mean(1)
        embeddings['categ_and_series_words'] = emb

    return embeddings
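
# Expected input for tokenize_keywords (a sketch; the field values are
# illustrative, not repo fixtures):
#   d = {
#       'captions_kw': ['rainfall by month'],
#       'all_text_kw': ['average rainfall', 'seattle'],
#       'categorical_words': ['january', 'february'],
#       'series_words': ['rainfall'],
#       'axis_words': ['month', 'mm'],
#   }
#   embeddings = tokenize_keywords(d, tokenizer, model, 128, device)
#   # -> dict of tensors, each of shape (num_keywords, hidden_size), plus
#   #    'categ_and_series_words' when both source lists are non-empty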

def calc_similarity(emb1, emb2, dim=-1, eps=1e-6):
    '''Calculates cosine similarity between every pair of rows in emb1 and emb2'''
    cos = torch.nn.CosineSimilarity(dim=dim, eps=eps)

    similarity_matrix = []
    for bidx in range(emb1.size(0)):
        # Compare row bidx of emb1 against every row of emb2
        repeated_emb = emb1[bidx, :].unsqueeze(0).repeat(emb2.size(0), 1)
        sim = cos(repeated_emb, emb2)
        similarity_matrix.append(sim)
    similarity_matrix = torch.stack(similarity_matrix, dim=0)  # (emb1 rows, emb2 rows)
    return similarity_matrix
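
# Example (a sketch with random tensors, just to show the output shape):
#   a, b = torch.randn(3, 512), torch.randn(5, 512)
#   calc_similarity(a, b).shape
#   # -> torch.Size([3, 5]): one cosine similarity per (row of a, row of b) pair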


def calc_similarity_between_all_pairs(embeddings, sim_pairs, to_cpu=True):
    '''
    Computes the mean similarity for each configured pair of keyword-set
    embeddings, skipping pairs where either embedding is missing
    '''
    sim_pair_dict = dict()
    for pair1, pair2 in sim_pairs:
        if embeddings.get(pair1) is not None and embeddings.get(pair2) is not None:
            emb1 = embeddings[pair1]
            emb2 = embeddings[pair2]
            sim_score = calc_similarity(emb1, emb2)
            sim_score = sim_score.mean()
            if to_cpu:
                sim_score = sim_score.detach().cpu().item()

            sim_pair_dict[(pair1, pair2)] = sim_score

    return sim_pair_dict
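
# Example usage (a sketch, reusing the embeddings dict from tokenize_keywords):
#   calc_similarity_between_all_pairs(embeddings, SIM_PAIRS)
#   # -> {('all_text_kw', 'captions_kw'): <mean score>, ...}, one entry per
#   #    pair whose embeddings are both present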

def compute_similarity_scores_for_dataset(sim_pairs, dataset, tokenizer, model, device, max_source_len=128):
    # max_source_len can be low because the inputs are only keywords
    model = model.to(device)

    sim_container = {}
    for d in dataset:
        embeddings = tokenize_keywords(d, tokenizer, model, max_source_len, device)
        similarity_scores = calc_similarity_between_all_pairs(embeddings, sim_pairs)

        for sim_pair, sim_score in similarity_scores.items():
            sim_container.setdefault(sim_pair, []).append(sim_score)
    
    # Average across the dataset for each pair, keeping the sample count
    for sim_pair, scores in sim_container.items():
        sim_container[sim_pair] = (sum(scores) / len(scores), len(scores))

    return sim_container
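
if __name__ == '__main__':
    # Minimal end-to-end sketch (not part of the original module): scores one
    # hand-made record with a small seq2seq checkpoint. The checkpoint name
    # and the example record are illustrative assumptions, not repo fixtures.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained('t5-small')
    model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')

    dataset = [{
        'captions_kw': ['rainfall by month'],
        'all_text_kw': ['average rainfall', 'seattle weather'],
        'categorical_words': ['january', 'february'],
        'series_words': ['rainfall'],
        'axis_words': ['month', 'millimeters'],
    }]
    scores = compute_similarity_scores_for_dataset(
        SIM_PAIRS, dataset, tokenizer, model, device)
    for pair, (mean_sim, n) in scores.items():
        print(pair, round(mean_sim, 4), f'n={n}')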