KGTOSA / GNN-Methods / LinkPrediction / Morse / subgraph.py
subgraph.py
Raw
import pickle
import torch
import numpy as np
from collections import defaultdict as ddict
import lmdb
from tqdm import tqdm
import random
from utils import serialize, get_g, get_hr2t_rt2h_sup_que
import dgl


def gen_subgraph_datasets(args):
    print(f'-----There is no sub-graphs for {args.data_name}, so start generating sub-graphs before meta-training!-----')
    data = pickle.load(open(args.data_path, 'rb'))
    train_g = get_g(data['train_graph']['train'] + data['train_graph']['valid']
                    + data['train_graph']['test'])

    BYTES_PER_DATUM = get_average_subgraph_size(args, args.num_sample_for_estimate_size, train_g) * 2
    map_size = (args.num_train_subgraph + args.num_valid_subgraph) * BYTES_PER_DATUM
    env = lmdb.open(args.db_path, map_size=map_size, max_dbs=2)
    train_subgraphs_db = env.open_db("train_subgraphs".encode())
    valid_subgraphs_db = env.open_db("valid_subgraphs".encode())

    for idx in tqdm(range(args.num_train_subgraph)):
        str_id = '{:08}'.format(idx).encode('ascii')
        datum = sample_one_subgraph(args, train_g)
        with env.begin(write=True, db=train_subgraphs_db) as txn:
            txn.put(str_id, serialize(datum))

    for idx in tqdm(range(args.num_valid_subgraph)):
        str_id = '{:08}'.format(idx).encode('ascii')
        datum = sample_one_subgraph(args, train_g)
        with env.begin(write=True, db=valid_subgraphs_db) as txn:
            txn.put(str_id, serialize(datum))


def sample_one_subgraph(args, bg_train_g):
    # get graph with bi-direction
    bg_train_g_undir = dgl.graph((torch.cat([bg_train_g.edges()[0], bg_train_g.edges()[1]]),
                                  torch.cat([bg_train_g.edges()[1], bg_train_g.edges()[0]])))

    # induce sub-graph by sampled nodes
    while True:
        while True:
            sel_nodes = []
            for i in range(args.rw_0):
                if i == 0:
                    cand_nodes = np.arange(bg_train_g.num_nodes())
                else:
                    cand_nodes = sel_nodes
                rw, _ = dgl.sampling.random_walk(bg_train_g_undir,
                                                 np.random.choice(cand_nodes, 1, replace=False).repeat(args.rw_1),
                                                 length=args.rw_2)
                sel_nodes.extend(np.unique(rw.reshape(-1)))
                sel_nodes = list(np.unique(sel_nodes)) if -1 not in sel_nodes else list(np.unique(sel_nodes))[1:]
            sub_g = dgl.node_subgraph(bg_train_g, sel_nodes)

            if sub_g.num_nodes() >= 50:
                break

        sub_tri = torch.stack([sub_g.edges()[0],
                               sub_g.edata['rel'],
                               sub_g.edges()[1]])
        sub_tri = sub_tri.T.tolist()
        random.shuffle(sub_tri)

        ent_freq = ddict(int)
        rel_freq = ddict(int)
        triples_reidx = []
        ent_reidx = dict()
        entidx = 0
        for tri in sub_tri:
            h, r, t = tri
            if h not in ent_reidx.keys():
                ent_reidx[h] = entidx
                entidx += 1
            if t not in ent_reidx.keys():
                ent_reidx[t] = entidx
                entidx += 1
            ent_freq[ent_reidx[h]] += 1
            ent_freq[ent_reidx[t]] += 1
            rel_freq[r] += 1
            triples_reidx.append([ent_reidx[h], r, ent_reidx[t]])

        # randomly get query triples
        que_tris = []
        sup_tris = []
        for idx, tri in enumerate(triples_reidx):
            h, r, t = tri
            if ent_freq[h] > 2 and ent_freq[t] > 2 and rel_freq[r] > 2:
                que_tris.append(tri)
                ent_freq[h] -= 1
                ent_freq[t] -= 1
                rel_freq[r] -= 1
            else:
                sup_tris.append(tri)

            if len(que_tris) >= int(len(triples_reidx)*0.1):
                break

        sup_tris.extend(triples_reidx[idx+1:])

        if len(que_tris) >= int(len(triples_reidx)*0.05):
            break

    # hr2t, rt2h
    hr2t, rt2h = get_hr2t_rt2h_sup_que(sup_tris, que_tris)

    return sup_tris, que_tris, hr2t, rt2h


def get_average_subgraph_size(args, sample_size, bg_train_g):
    total_size = 0
    for i in range(sample_size):
        datum = sample_one_subgraph(args, bg_train_g)
        total_size += len(serialize(datum))
    return total_size / sample_size