# KGTOSA / GNN-Methods / NodeClassifcation / IBS / models / RGCN.py
from copy import copy
import torch
import torch.nn.functional as F
from torch.nn import ModuleList, Linear, ParameterDict, Parameter
from torch_sparse import SparseTensor
from torch_geometric.nn import MessagePassing
from .evaluate import Evaluator as Evaluator
class RGCNConv(MessagePassing):
    """Relational graph convolution with one weight matrix per edge type.

    Messages are aggregated with ``mean`` (the ``aggr`` passed to
    ``MessagePassing``). Every edge type owns a bias-free linear map applied
    to the messages, and every node type owns a biased linear map applied to
    the node's own features (the "root" term).

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.
        num_node_types: Number of distinct node types (root transforms).
        num_edge_types: Number of distinct edge types (relation transforms).
    """

    def __init__(self, in_channels, out_channels, num_node_types,
                 num_edge_types):
        super().__init__(aggr='mean')

        self.in_channels, self.out_channels = in_channels, out_channels
        self.num_node_types = num_node_types
        self.num_edge_types = num_edge_types

        # Relation transforms are created before root transforms so that
        # parameter initialisation consumes the RNG in a fixed order.
        relation_layers = [
            Linear(in_channels, out_channels, bias=False)
            for _ in range(num_edge_types)
        ]
        self.rel_lins = ModuleList(relation_layers)

        root_layers = [
            Linear(in_channels, out_channels, bias=True)
            for _ in range(num_node_types)
        ]
        self.root_lins = ModuleList(root_layers)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all relation transforms, then all root transforms."""
        for layer in (*self.rel_lins, *self.root_lins):
            layer.reset_parameters()

    def forward(self, x, edge_index, edge_type, node_type):
        """Compute the layer output for every row of ``x``.

        Args:
            x: Node feature matrix; one row per node.
            edge_index: Edge index tensor whose columns are selected per
                edge type (indexed as ``edge_index[:, mask]``).
            edge_type: Per-edge integer type, compared against
                ``0 .. num_edge_types - 1``.
            node_type: Per-node integer type, compared against
                ``0 .. num_node_types - 1``.

        Returns:
            Tensor with ``x.size(0)`` rows and ``out_channels`` columns.
        """
        result = x.new_zeros(x.size(0), self.out_channels)

        # Relation term: propagate separately over the edges of each type.
        for rel in range(self.num_edge_types):
            rel_edges = edge_index[:, edge_type == rel]
            result += self.propagate(rel_edges, x=x, edge_type=rel)

        # Root term: transform each node with the map of its own type.
        for kind in range(self.num_node_types):
            selected = node_type == kind
            result[selected] += self.root_lins[kind](x[selected])

        return result

    def message(self, x_j, edge_type: int):
        """Apply the transform of ``edge_type`` to the per-edge features.

        ``x_j`` is supplied by ``propagate``; presumably the source-node
        features per the PyG ``_j`` naming convention — confirm against the
        installed torch_geometric version.
        """
        transform = self.rel_lins[edge_type]
        return transform(x_j)

class RGCN(torch.nn.Module):
    """Stack of ``RGCNConv`` layers for node classification.

    Node types listed in ``x_types`` are expected to bring their own input
    features (passed via ``x_dict``); every other node type in
    ``num_nodes_dict`` gets a learnable embedding table of width
    ``in_channels``.

    Args:
        in_channels: Width of the input features / embeddings.
        hidden_channels: Width of the hidden layers.
        out_channels: Width of the output (number of classes).
        num_layers: Requested number of layers. At least two layers (input
            and output) are always created.
        dropout: Dropout probability applied after every non-final layer
            in ``forward`` while training.
        num_nodes_dict: Mapping node type -> number of nodes of that type.
        x_types: Node types that come with input features.
        num_edge_types: Number of distinct edge types.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, num_nodes_dict, x_types, num_edge_types):
        super(RGCN, self).__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.dropout = dropout

        node_types = list(num_nodes_dict.keys())
        num_node_types = len(node_types)

        self.num_node_types = num_node_types
        self.num_edge_types = num_edge_types

        # Create embeddings for all node types that do not come with features.
        # The storage is uninitialised here and filled by reset_parameters().
        self.emb_dict = ParameterDict({
            f'{key}': Parameter(torch.empty(num_nodes_dict[key], in_channels))
            for key in set(node_types).difference(set(x_types))
        })

        I, H, O = in_channels, hidden_channels, out_channels  # noqa

        # Create `num_layers` many message passing layers (never fewer than
        # two: the input and output layers are appended unconditionally).
        self.convs = ModuleList()
        self.convs.append(RGCNConv(I, H, num_node_types, num_edge_types))
        for _ in range(num_layers - 2):
            self.convs.append(RGCNConv(H, H, num_node_types, num_edge_types))
        self.convs.append(RGCNConv(H, O, num_node_types, num_edge_types))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the embeddings and reset every conv layer."""
        for emb in self.emb_dict.values():
            torch.nn.init.xavier_uniform_(emb)
        for conv in self.convs:
            conv.reset_parameters()

    def group_input(self, x_dict, node_type, local_node_idx):
        """Assemble one feature matrix covering all nodes of the (sub)graph.

        Args:
            x_dict: Mapping integer node type -> feature matrix of that type.
            node_type: Per-node integer type.
            local_node_idx: Per-node row index into the feature matrix or
                embedding table of the node's own type.

        Returns:
            Tensor with ``node_type.size(0)`` rows and ``in_channels``
            columns, on the device of ``node_type``.
        """
        # Create global node feature matrix.
        h = torch.zeros((node_type.size(0), self.in_channels),
                        device=node_type.device)

        for key, x in x_dict.items():
            mask = node_type == key
            h[mask] = x[local_node_idx[mask]]

        # ParameterDict keys are strings; node types are integers.
        for key, emb in self.emb_dict.items():
            mask = node_type == int(key)
            h[mask] = emb[local_node_idx[mask]]

        return h

    def forward(self, x_dict, edge_index, edge_type, node_type,
                local_node_idx):
        """Run all layers and return per-node log-probabilities.

        ReLU and dropout (with probability ``self.dropout``) are applied
        after every layer except the last one.
        """
        x = self.group_input(x_dict, node_type, local_node_idx)

        # Use the real layer count: `convs` holds at least two layers even
        # when `num_layers` is smaller, and the final output must stay raw.
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index, edge_type, node_type)
            if i != last:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        return x.log_softmax(dim=-1)

    def inference(self, x_dict, edge_index_dict, key2int):
        """Full-graph, layer-wise propagation without neighbour sampling.

        Args:
            x_dict: Mapping integer node type -> feature matrix. Not mutated;
                a shallow copy is extended with the learned embeddings.
            edge_index_dict: Mapping relation key -> ``(row, col)`` index
                pair. The first and last items of a key name the source and
                target node types.
            key2int: Mapping from node-type names and relation keys to the
                integer ids used to index features and layer weights.

        Returns:
            Mapping integer node type -> output of the final layer (no
            softmax is applied here).
        """
        device = next(iter(x_dict.values())).device

        x_dict = copy(x_dict)
        for key, emb in self.emb_dict.items():
            x_dict[int(key)] = emb

        # Transposed adjacency per relation: rows index target nodes.
        adj_t_dict = {}
        for key, (row, col) in edge_index_dict.items():
            adj_t_dict[key] = SparseTensor(row=col, col=row).to(device)

        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            out_dict = {}

            for j, x in x_dict.items():
                out_dict[j] = conv.root_lins[j](x)

            for keys, adj_t in adj_t_dict.items():
                src_key, target_key = keys[0], keys[-1]
                out = out_dict[key2int[target_key]]
                tmp = adj_t.matmul(x_dict[key2int[src_key]], reduce='mean')
                msg = conv.rel_lins[key2int[keys]](tmp)
                # The adjacency's size is inferred from the largest index,
                # so `msg` can have fewer rows than `out`. Rows without a
                # message receive nothing. Growing `msg` with `resize_`
                # would add uninitialised memory into those rows instead.
                rows = min(msg.size(0), out.size(0))
                out[:rows].add_(msg[:rows])

            if i != last:
                for j in range(self.num_node_types):
                    F.relu_(out_dict[j])

            x_dict = out_dict

        return x_dict

def rgcn_test(model, x_dict, data, key2int, train_nodes, val_nodes,
              test_nodes, subject_node='paper'):
    """Evaluate ``model`` on the subject nodes and return split accuracies.

    The model is switched to eval mode (and left there) and run once over
    ``data`` with gradient tracking disabled. Predictions and labels are
    restricted to nodes whose type equals ``key2int[subject_node]`` and
    then indexed by the three split index sets.

    Args:
        model: Module called as ``model(x_dict, edge_index, edge_attr,
            node_type, local_node_idx)`` and returning per-node class scores.
        x_dict: Input features passed through to the model.
        data: Object exposing ``edge_index``, ``edge_attr``, ``node_type``,
            ``local_node_idx`` and ``y``.
        key2int: Mapping from node-type name to its integer id.
        train_nodes: Indices into the subject-node subset for the train split.
        val_nodes: Indices into the subject-node subset for the valid split.
        test_nodes: Indices into the subject-node subset for the test split.
        subject_node: Name of the node type being classified.

    Returns:
        Tuple ``(train_acc, valid_acc, test_acc)``.
    """
    evaluator = Evaluator(name='ogbn-mag')
    model.eval()

    subject_mask = data.node_type == key2int[subject_node]

    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        outputs = model(x_dict, data.edge_index, data.edge_attr,
                        data.node_type, data.local_node_idx)

    # Both tensors are passed to the evaluator as column vectors.
    # reshape(-1, 1) also handles the case of a single subject node, where
    # squeeze() followed by len() would fail on a 0-d tensor.
    y_pred = torch.argmax(outputs[subject_mask], dim=1).unsqueeze(-1)
    y_true = data.y[subject_mask].reshape(-1, 1)

    print("y_true[train_nodes]=",y_true[train_nodes].shape,y_true[train_nodes])
    print("y_pred[train_nodes]=",y_pred[train_nodes].shape,y_pred[train_nodes])

    def _accuracy(node_idx):
        # Accuracy of the predictions restricted to one split.
        return evaluator.eval({
            'y_true': y_true[node_idx],
            'y_pred': y_pred[node_idx],
        })['acc']

    train_acc = _accuracy(train_nodes)
    valid_acc = _accuracy(val_nodes)
    test_acc = _accuracy(test_nodes)

    return train_acc, valid_acc, test_acc