import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor
from torch_geometric.nn import GATConv


class Transformer(nn.Module):
    "Self attention layer for `n_channels`."
    def __init__(self, n_channels, num_heads=1, att_drop=0., act='none'):
        super(Transformer, self).__init__()
        self.n_channels = n_channels
        self.num_heads = num_heads
        assert self.n_channels % (self.num_heads * 4) == 0

        self.query = nn.Linear(self.n_channels, self.n_channels//4)
        self.key   = nn.Linear(self.n_channels, self.n_channels//4)
        self.value = nn.Linear(self.n_channels, self.n_channels)

        self.gamma = nn.Parameter(torch.tensor([0.]))
        self.att_drop = nn.Dropout(att_drop)
        if act == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif act == 'relu':
            self.act = torch.nn.ReLU()
        elif act == 'leaky_relu':
            self.act = torch.nn.LeakyReLU(0.2)
        elif act == 'none':
            self.act = lambda x: x
        else:
            assert 0, f'Unrecognized activation function {act} for class Transformer'

    def reset_parameters(self):

        def xavier_uniform_(tensor, gain=1.):
            fan_in, fan_out = tensor.size()[-2:]
            std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
            a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
            return torch.nn.init._no_grad_uniform_(tensor, -a, a)

        gain = nn.init.calculate_gain("relu")
        xavier_uniform_(self.query.weight, gain=gain)
        xavier_uniform_(self.key.weight, gain=gain)
        xavier_uniform_(self.value.weight, gain=gain)
        nn.init.zeros_(self.query.bias)
        nn.init.zeros_(self.key.bias)
        nn.init.zeros_(self.value.bias)

    def forward(self, x, mask=None):
        B, M, C = x.size() # batchsize, num_metapaths, channels
        H = self.num_heads
        if mask is not None:
            assert mask.size() == torch.Size((B, M))

        f = self.query(x).view(B, M, H, -1).permute(0,2,1,3) # [B, H, M, -1]
        g = self.key(x).view(B, M, H, -1).permute(0,2,3,1)   # [B, H, -1, M]
        h = self.value(x).view(B, M, H, -1).permute(0,2,1,3) # [B, H, M, -1]

        beta = F.softmax(self.act(f @ g / math.sqrt(f.size(-1))), dim=-1) # [B, H, M, M(normalized)]
        beta = self.att_drop(beta)
        if mask is not None:
            beta = beta * mask.view(B, 1, 1, M)
            beta = beta / (beta.sum(-1, keepdim=True) + 1e-12)

        o = self.gamma * (beta @ h) # [B, H, M, -1]
        return o.permute(0,2,1,3).reshape((B, M, C)) + x
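

# Minimal usage sketch for the Transformer block above (illustrative only):
# it exercises the [B, num_metapaths, channels] -> same-shape residual
# self-attention contract of forward(). The helper name and all sizes below
# are hypothetical.
def _transformer_shape_demo():
    B, M, C = 4, 6, 64                      # batch, num_metapaths, channels
    attn = Transformer(C, num_heads=2, att_drop=0.1)
    attn.reset_parameters()                 # not called in __init__, so call it explicitly
    x = torch.randn(B, M, C)
    out = attn(x)                           # gamma starts at 0, so out equals x at init
    assert out.shape == (B, M, C)
    return out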


class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128, independent_attn=False):
        super(SemanticAttention, self).__init__()
        self.independent_attn = independent_attn

        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )

    def forward(self, z): # z: (N, M, D)
        w = self.project(z) # (N, M, 1)
        if not self.independent_attn:
            w = w.mean(0, keepdim=True)  # (1, M, 1)
        beta = torch.softmax(w, dim=1)  # (N, M, 1) or (1, M, 1)

        return (beta * z).sum(1)  # (N, D)
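

# Minimal usage sketch (illustrative only): SemanticAttention fuses the M
# metapath views of each node into a single vector with softmax-normalized,
# learned weights. Only the (N, M, D) -> (N, D) contract of forward() is
# assumed; sizes are hypothetical.
def _semantic_attention_demo():
    N, M, D = 8, 5, 32
    fuse = SemanticAttention(D, hidden_size=16)   # shared weights (independent_attn=False)
    z = torch.randn(N, M, D)
    out = fuse(z)
    assert out.shape == (N, D)
    return out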


class Conv1d1x1(nn.Module):
    def __init__(self, cin, cout, groups, bias=True, cformat='channel-first'):
        super(Conv1d1x1, self).__init__()
        self.cin = cin
        self.cout = cout
        self.groups = groups
        self.cformat = cformat
        if not bias:
            self.register_parameter('bias', None)
        if self.groups == 1: # all channels share the same weight matrix
            self.W = nn.Parameter(torch.randn(self.cin, self.cout))
            if bias:
                self.bias = nn.Parameter(torch.zeros(1, self.cout))
        else:
            self.W = nn.Parameter(torch.randn(self.groups, self.cin, self.cout))
            if bias:
                self.bias = nn.Parameter(torch.zeros(self.groups, self.cout))

    def reset_parameters(self):

        def xavier_uniform_(tensor, gain=1.):
            # use the last two dims as fan_in/fan_out so the grouped 3-D weight is handled too
            fan_in, fan_out = tensor.size()[-2:]
            std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
            a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
            return torch.nn.init._no_grad_uniform_(tensor, -a, a)

        gain = nn.init.calculate_gain("relu")
        xavier_uniform_(self.W, gain=gain)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        if self.groups == 1:
            if self.cformat == 'channel-first':
                return torch.einsum('bcm,mn->bcn', x, self.W) + self.bias
            elif self.cformat == 'channel-last':
                return torch.einsum('bmc,mn->bnc', x, self.W) + self.bias.T
            else:
                assert False, f'Unrecognized cformat {self.cformat}'
        else:
            if self.cformat == 'channel-first':
                return torch.einsum('bcm,cmn->bcn', x, self.W) + self.bias
            elif self.cformat == 'channel-last':
                return torch.einsum('bmc,cmn->bnc', x, self.W) + self.bias.T
            else:
                assert False, f'Unrecognized cformat {self.cformat}'
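

# Minimal usage sketch (illustrative only): with groups > 1, Conv1d1x1 applies
# a separate (cin -> cout) linear map to each of the `groups` channels, which
# is how SeHGNN below projects every metapath feature independently.
# 'channel-first' means the input is laid out as [B, groups, cin]; all sizes
# here are hypothetical.
def _conv1d1x1_demo():
    B, C, F_in, F_out = 4, 3, 16, 32
    conv = Conv1d1x1(F_in, F_out, groups=C, cformat='channel-first')
    conv.reset_parameters()
    x = torch.randn(B, C, F_in)
    out = conv(x)
    assert out.shape == (B, C, F_out)
    return out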


class SeHGNN(nn.Module):
    def __init__(self, nfeat, hidden, nclass, feat_keys, label_feat_keys, tgt_type,
                 dropout, input_drop, att_dropout, label_drop,
                 n_layers_1, n_layers_2, act,
                 residual=False, bns=False, data_size=None, drop_metapath=0., num_heads=1,
                 remove_transformer=True, independent_attn=False):
        super(SeHGNN, self).__init__()
        self.feat_keys = sorted(feat_keys)
        self.label_feat_keys = sorted(label_feat_keys)
        self.num_channels = num_channels = len(self.feat_keys) + len(self.label_feat_keys)
        self.tgt_type = tgt_type
        self.residual = residual
        self.remove_transformer = remove_transformer

        self.data_size = data_size
        self.embeding = nn.ParameterDict({})
        for k, v in data_size.items():
            self.embeding[str(k)] = nn.Parameter(
                torch.Tensor(v, nfeat).uniform_(-0.5, 0.5))

        if len(self.label_feat_keys):
            self.labels_embeding = nn.ParameterDict({})
            for k in self.label_feat_keys:
                self.labels_embeding[k] = nn.Parameter(
                    torch.Tensor(nclass, nfeat).uniform_(-0.5, 0.5))
        else:
            self.labels_embeding = {}

        self.layers = nn.Sequential(
            Conv1d1x1(nfeat, hidden, num_channels, bias=True, cformat='channel-first'),
            nn.LayerNorm([num_channels, hidden]),
            nn.PReLU(),
            nn.Dropout(dropout),
            # Conv1d1x1(num_channels, num_channels, hidden, bias=True, cformat='channel-last'),
            # nn.LayerNorm([num_channels, hidden]),
            # nn.PReLU(),
            Conv1d1x1(hidden, hidden, num_channels, bias=True, cformat='channel-first'),
            nn.LayerNorm([num_channels, hidden]),
            nn.PReLU(),
            nn.Dropout(dropout),
            # Conv1d1x1(num_channels, 4, 1, bias=True, cformat='channel-last'),
            # nn.LayerNorm([4, hidden]),
            # nn.PReLU(),
        )

        if self.remove_transformer:
            self.layer_mid = SemanticAttention(hidden, hidden, independent_attn=independent_attn)
            self.layer_final = nn.Linear(hidden, hidden)
        else:
            self.layer_mid = Transformer(hidden, num_heads=num_heads)
            # self.layer_mid = Transformer(hidden, drop_metapath=drop_metapath)
            self.layer_final = nn.Linear(num_channels * hidden, hidden)

        if self.residual:
            self.res_fc = nn.Linear(nfeat, hidden, bias=False)

        def add_nonlinear_layers(nfeats, dropout, bns=False):
            return [
                nn.BatchNorm1d(hidden, affine=bns, track_running_stats=bns),
                nn.PReLU(),
                nn.Dropout(dropout)
            ]

        lr_output_layers = [
            [nn.Linear(hidden, hidden, bias=not bns)] + add_nonlinear_layers(hidden, dropout, bns)
            # [Dense(hidden, hidden, bns)] + add_nonlinear_layers(hidden, dropout, bns)
            for _ in range(n_layers_2-1)]
        self.lr_output = nn.Sequential(*(
            [ele for li in lr_output_layers for ele in li] + [
            nn.Linear(hidden, nclass, bias=False),
            nn.BatchNorm1d(nclass, affine=bns, track_running_stats=bns)]))
        # self.lr_output = FeedForwardNetII(
        #     hidden, hidden, nclass, n_layers_2, dropout, 0.5, bns)

        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(dropout)
        self.input_drop = nn.Dropout(input_drop)
        self.att_drop = nn.Dropout(att_dropout)
        self.label_drop = nn.Dropout(label_drop)

        self.reset_parameters()

    def reset_parameters(self):
        # for k, v in self.embeding:
        #     v.data.uniform_(-0.5, 0.5)
        # for k, v in self.labels_embeding:
        #     v.data.uniform_(-0.5, 0.5)

        for layer in self.layers:
            if isinstance(layer, Conv1d1x1):
                layer.reset_parameters()

        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.layer_final.weight, gain=gain)
        nn.init.zeros_(self.layer_final.bias)
        if self.residual:
            nn.init.xavier_uniform_(self.res_fc.weight, gain=gain)
        for layer in self.lr_output:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=gain)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, batch, feature_dict, label_dict={}, mask=None):
        if isinstance(feature_dict[self.tgt_type], torch.Tensor):
            mapped_feats = {k: self.input_drop(x @ self.embeding[k]) for k, x in feature_dict.items()}
        elif isinstance(feature_dict[self.tgt_type], SparseTensor):
            mapped_feats = {k: self.input_drop(x @ self.embeding[k[-1]]) for k, x in feature_dict.items()}
        else:
            assert 0

        mapped_label_feats = {k: self.input_drop(x @ self.labels_embeding[k]) for k, x in label_dict.items()}

        features = [mapped_feats[k] for k in self.feat_keys] + [mapped_label_feats[k] for k in self.label_feat_keys]

        # if mask is not None:
        #     if self.keys_sort:
        #         mask = [mask[k] for k in sorted_f_keys]
        #     else:
        #         mask = list(mask.values())
        #     mask = torch.stack(mask, dim=1)

        #     if len(label_list):
        #         if mask is not None:
        #             mask = torch.cat((mask, torch.ones((len(mask), len(label_list))).to(device=mask.device)), dim=1)

        B = num_node = mapped_feats[self.tgt_type].shape[0]
        C = self.num_channels
        D = mapped_feats[self.tgt_type].shape[1]

        features = torch.stack(features, dim=1) # [B, C, D]

        features = self.layers(features)
        if self.remove_transformer:
            features = self.layer_mid(features)
        else:
            features = self.layer_mid(features, mask=None).transpose(1,2)
        out = self.layer_final(features.reshape(B, -1))

        if self.residual:
            out = out + self.res_fc(mapped_feats[self.tgt_type])
        out = self.dropout(self.prelu(out))
        out = self.lr_output(out)
        return out
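

# Minimal end-to-end sketch for SeHGNN with dense (torch.Tensor) inputs
# (illustrative only). On this branch of forward(), self.embeding is indexed
# by the feature keys themselves, so the sketch keys data_size by the same
# hypothetical metapath names; all sizes and hyper-parameters are made up.
def _sehgnn_demo():
    nfeat, hidden, nclass = 64, 128, 4
    feat_keys = ['A', 'AP', 'APT']                    # hypothetical metapaths
    data_size = {'A': 334, 'AP': 4231, 'APT': 50}     # raw feature dim per key
    model = SeHGNN(nfeat, hidden, nclass, feat_keys, [], 'A',
                   dropout=0.5, input_drop=0.1, att_dropout=0., label_drop=0.,
                   n_layers_1=2, n_layers_2=2, act='none',
                   residual=False, bns=False, data_size=data_size)
    N = 16
    feature_dict = {k: torch.randn(N, data_size[k]) for k in feat_keys}
    out = model(None, feature_dict)                   # `batch` is unused on this path
    assert out.shape == (N, nclass)
    return out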


class SeHGNN_NA(nn.Module):
    def __init__(self, nfeat, hidden, nclass, feat_keys, label_feat_keys, tgt_type,
                 dropout, input_drop, att_dropout, label_drop,
                 n_layers_1, n_layers_2, act,
                 residual=False, bns=False, data_size=None, drop_metapath=0., num_heads=1):
        super(SeHGNN_NA, self).__init__()
        self.feat_keys = sorted(feat_keys)
        self.label_feat_keys = sorted(label_feat_keys)
        self.num_channels = num_channels = len(self.feat_keys) + 1 + len(self.label_feat_keys)  # +1: the target type's own features form an extra channel
        self.tgt_type = tgt_type
        self.residual = residual

        self.data_size = data_size
        self.embeding = nn.ParameterDict({})
        for k, v in data_size.items():
            self.embeding[str(k)] = nn.Parameter(
                torch.Tensor(v, nfeat).uniform_(-0.5, 0.5))

        if len(self.label_feat_keys):
            self.labels_embeding = nn.ParameterDict({})
            for k in self.label_feat_keys:
                self.labels_embeding[k] = nn.Parameter(
                    torch.Tensor(nclass, nfeat).uniform_(-0.5, 0.5))
        else:
            self.labels_embeding = {}

        self.layers = nn.Sequential(
            Conv1d1x1(nfeat, hidden, num_channels, bias=True, cformat='channel-first'),
            nn.LayerNorm([num_channels, hidden]),
            nn.PReLU(),
            nn.Dropout(dropout),
            # Conv1d1x1(num_channels, num_channels, hidden, bias=True, cformat='channel-last'),
            # nn.LayerNorm([num_channels, hidden]),
            # nn.PReLU(),
            Conv1d1x1(hidden, hidden, num_channels, bias=True, cformat='channel-first'),
            nn.LayerNorm([num_channels, hidden]),
            nn.PReLU(),
            nn.Dropout(dropout),
            # Conv1d1x1(num_channels, 4, 1, bias=True, cformat='channel-last'),
            # nn.LayerNorm([4, hidden]),
            # nn.PReLU(),
        )

        assert hidden % num_heads == 0, f'hidden {hidden} should be divisible by num_heads {num_heads}'
        self.feat_gat_layers = nn.ModuleDict({
            k: GATConv(hidden, hidden//num_heads, heads=num_heads, dropout=0.6) for k in self.feat_keys})

        self.layer_mid = Transformer(hidden, num_heads=1)
        # self.layer_mid = Transformer(hidden, drop_metapath=drop_metapath)
        self.layer_final = nn.Linear(num_channels * hidden, hidden)

        if self.residual:
            self.res_fc = nn.Linear(nfeat, hidden, bias=False)

        def add_nonlinear_layers(nfeats, dropout, bns=False):
            return [
                nn.BatchNorm1d(hidden, affine=bns, track_running_stats=bns),
                nn.PReLU(),
                nn.Dropout(dropout)
            ]

        lr_output_layers = [
            [nn.Linear(hidden, hidden, bias=not bns)] + add_nonlinear_layers(hidden, dropout, bns)
            # [Dense(hidden, hidden, bns)] + add_nonlinear_layers(hidden, dropout, bns)
            for _ in range(n_layers_2-1)]
        self.lr_output = nn.Sequential(*(
            [ele for li in lr_output_layers for ele in li] + [
            nn.Linear(hidden, nclass, bias=False),
            nn.BatchNorm1d(nclass, affine=bns, track_running_stats=bns)]))
        # self.lr_output = FeedForwardNetII(
        #     hidden, hidden, nclass, n_layers_2, dropout, 0.5, bns)

        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(dropout)
        self.input_drop = nn.Dropout(input_drop)
        self.att_drop = nn.Dropout(att_dropout)
        self.label_drop = nn.Dropout(label_drop)

        self.reset_parameters()

    def reset_parameters(self):
        # for k, v in self.embeding:
        #     v.data.uniform_(-0.5, 0.5)
        # for k, v in self.labels_embeding:
        #     v.data.uniform_(-0.5, 0.5)

        for layer in self.layers:
            if isinstance(layer, Conv1d1x1):
                layer.reset_parameters()
        for k, v in self.feat_gat_layers.items():
            v.reset_parameters()

        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.layer_final.weight, gain=gain)
        nn.init.zeros_(self.layer_final.bias)
        if self.residual:
            nn.init.xavier_uniform_(self.res_fc.weight, gain=gain)
        for layer in self.lr_output:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=gain)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, batch, feat_adjs, label_dict={}, mask=None):
        # NOTE: the full raw feature dict is expected to be attached by the caller
        # as `self.feats` (keyed by node type) before forward is called.
        mapped_feats = {str(k): self.input_drop(x @ self.embeding[str(k)]) for k, x in self.feats.items()}

        aggregated_feats = {}
        for k in self.feat_gat_layers.keys():
            tgt, src = k[0], k[-1]
            assert len(k) > 1 and tgt == self.tgt_type
            aggregated_feats[k] = self.feat_gat_layers[k]((
                mapped_feats[src], mapped_feats[self.tgt_type][batch]), feat_adjs[k])
        aggregated_feats[self.tgt_type] = mapped_feats[self.tgt_type][batch]

        mapped_label_feats = {k: self.input_drop(x @ self.labels_embeding[k]) for k, x in label_dict.items()}

        features = ([aggregated_feats[k] for k in self.feat_keys]
                    + [aggregated_feats[self.tgt_type]]
                    + [mapped_label_feats[k] for k in self.label_feat_keys])

        # if mask is not None:
        #     if self.keys_sort:
        #         mask = [mask[k] for k in sorted_f_keys]
        #     else:
        #         mask = list(mask.values())
        #     mask = torch.stack(mask, dim=1)

        B = num_node = len(batch)
        C = self.num_channels
        D = aggregated_feats[self.tgt_type].shape[1]

        features = torch.stack(features, dim=1) # [B, C, D]

        features = self.layers(features)
        features = self.layer_mid(features, mask=None).transpose(1,2)
        out = self.layer_final(features.reshape(B, -1))

        if self.residual:
            out = out + self.res_fc(aggregated_feats[self.tgt_type])
        out = self.dropout(self.prelu(out))
        out = self.lr_output(out)
        return out
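

# Minimal sketch for SeHGNN_NA (illustrative only). It assumes nfeat == hidden,
# because the GAT layers are built with in_channels=hidden but consume the
# nfeat-dim projected features, and it assumes the caller attaches the full raw
# feature dict as `model.feats`. Metapath keys, node counts and the random
# bipartite edge indices are hypothetical.
def _sehgnn_na_demo():
    nfeat = hidden = 64
    nclass, num_heads = 3, 2
    feat_keys = ['AP', 'AT']                          # target type 'A', source types 'P'/'T'
    data_size = {'A': 100, 'P': 200, 'T': 50}         # raw feature dims per node type
    model = SeHGNN_NA(nfeat, hidden, nclass, feat_keys, [], 'A',
                      dropout=0.5, input_drop=0.1, att_dropout=0., label_drop=0.,
                      n_layers_1=2, n_layers_2=2, act='none',
                      residual=False, bns=False, data_size=data_size, num_heads=num_heads)
    model.feats = {'A': torch.randn(60, 100), 'P': torch.randn(150, 200), 'T': torch.randn(80, 50)}
    batch = torch.arange(16)                          # 16 target nodes of type 'A'
    # bipartite edge_index per metapath: row 0 = source node id, row 1 = position in `batch`
    feat_adjs = {'AP': torch.stack([torch.randint(0, 150, (64,)), torch.randint(0, 16, (64,))]),
                 'AT': torch.stack([torch.randint(0, 80, (64,)), torch.randint(0, 16, (64,))])}
    out = model(batch, feat_adjs)
    assert out.shape == (16, nclass)
    return out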


class SeHGNN_block(nn.Module):
    def __init__(self, nfeat, hidden, feat_keys, label_feat_keys, tgt_type, dropout, num_fp_layers=2):
        super(SeHGNN_block, self).__init__()
        self.feat_keys = sorted(feat_keys)
        self.label_feat_keys = sorted(label_feat_keys)
        self.tgt_type = tgt_type
        self.num_channels = num_channels = len(self.feat_keys) + len(self.label_feat_keys)

        def add_fp_layers(cin, cout):
            return [
                Conv1d1x1(cin, cout, num_channels, bias=True, cformat='channel-first'),
                nn.LayerNorm([num_channels, cout]),
                nn.PReLU(),
                nn.Dropout(dropout)
            ]

        layers = add_fp_layers(nfeat, hidden)
        for i in range(1, num_fp_layers):
            layers += add_fp_layers(hidden, hidden)
        self.layers = nn.Sequential(*layers)

        self.layer_mid = Transformer(hidden, num_heads=1)
        # self.layer_mid = Transformer(hidden, drop_metapath=drop_metapath)
        self.layer_final = nn.Linear(num_channels * hidden, hidden)

    def reset_parameters(self):
        gain = nn.init.calculate_gain("relu")
        for layer in self.layers:
            if isinstance(layer, Conv1d1x1):
                layer.reset_parameters()
        self.layer_mid.reset_parameters()
        nn.init.xavier_uniform_(self.layer_final.weight, gain=gain)
        nn.init.zeros_(self.layer_final.bias)

    def forward(self, x, label_dict={}, mask=None):
        x = torch.stack(
            [x[k] for k in self.feat_keys] + [label_dict[k] for k in self.label_feat_keys], dim=1)
        B, C, D = x.size()
        assert C == self.num_channels

        x = self.layers(x)
        x = self.layer_mid(x, mask=mask).transpose(1,2)
        return self.layer_final(x.reshape(B, -1))
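

# Minimal usage sketch (illustrative only): SeHGNN_block bundles the
# per-metapath projection, the Transformer fusion and the final concat+linear
# step into one reusable module; SeHGNN_2L below stacks two levels of it.
# Keys and sizes are hypothetical.
def _sehgnn_block_demo():
    nfeat, hidden, N = 32, 64, 10
    feat_keys = ['A', 'AP']
    block = SeHGNN_block(nfeat, hidden, feat_keys, [], 'A', dropout=0.5)
    block.reset_parameters()                # not called in __init__, so call it explicitly
    x = {k: torch.randn(N, nfeat) for k in feat_keys}
    out = block(x)
    assert out.shape == (N, hidden)
    return out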


class SeHGNN_2L(nn.Module):
    def __init__(self, nfeat, hidden, nclass, feat_keys1, feat_keys2, label_feat_keys, node_types,
                 dropout, input_drop, att_dropout, label_drop,
                 n_layers_1, n_layers_2, act,
                 residual=False, bns=False, data_size=None, drop_metapath=0., num_heads=1):
        super(SeHGNN_2L, self).__init__()
        self.label_feat_keys = sorted(label_feat_keys)
        self.node_types = node_types
        self.tgt_type = self.node_types[0]
        self.residual = residual

        self.feat_keys1 = {}
        for tgt in self.node_types:
            self.feat_keys1[tgt] = [k for k in feat_keys1 if k[0] == tgt] + [tgt]

        self.layer1_blocks = nn.ModuleDict({
            tgt: SeHGNN_block(nfeat, hidden, self.feat_keys1[tgt], {}, tgt, dropout) for tgt in self.node_types})

        self.feat_keys2 = sorted(feat_keys2) + [self.tgt_type]

        self.layer2 = SeHGNN_block(hidden, hidden, self.feat_keys2, self.label_feat_keys, self.tgt_type, dropout)

        self.data_size = data_size
        self.embeding = nn.ParameterDict({})
        for k, v in data_size.items():
            self.embeding[str(k)] = nn.Parameter(
                torch.Tensor(v, nfeat).uniform_(-0.5, 0.5))

        if len(self.label_feat_keys):
            self.labels_embeding = nn.ParameterDict({})
            for k in self.label_feat_keys:
                self.labels_embeding[k] = nn.Parameter(
                    torch.Tensor(nclass, nfeat).uniform_(-0.5, 0.5))
        else:
            self.labels_embeding = {}

        if self.residual:
            self.res_fc = nn.Linear(nfeat, hidden, bias=False)

        def add_nonlinear_layers(nfeats, dropout, bns=False):
            return [
                nn.BatchNorm1d(hidden, affine=bns, track_running_stats=bns),
                nn.PReLU(),
                nn.Dropout(dropout)
            ]

        lr_output_layers = [
            [nn.Linear(hidden, hidden, bias=not bns)] + add_nonlinear_layers(hidden, dropout, bns)
            # [Dense(hidden, hidden, bns)] + add_nonlinear_layers(hidden, dropout, bns)
            for _ in range(n_layers_2-1)]
        self.lr_output = nn.Sequential(*(
            [ele for li in lr_output_layers for ele in li] + [
            nn.Linear(hidden, nclass, bias=False),
            nn.BatchNorm1d(nclass, affine=bns, track_running_stats=bns)]))
        # self.lr_output = FeedForwardNetII(
        #     hidden, hidden, nclass, n_layers_2, dropout, 0.5, bns)

        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(dropout)
        self.input_drop = nn.Dropout(input_drop)
        self.att_drop = nn.Dropout(att_dropout)
        self.label_drop = nn.Dropout(label_drop)

        self.reset_parameters()

    def reset_parameters(self):
        # for k, v in self.embeding:
        #     v.data.uniform_(-0.5, 0.5)
        # for k, v in self.labels_embeding:
        #     v.data.uniform_(-0.5, 0.5)

        for layer in self.layer1_blocks.values():
            layer.reset_parameters()
        self.layer2.reset_parameters()

        gain = nn.init.calculate_gain("relu")
        if self.residual:
            nn.init.xavier_uniform_(self.res_fc.weight, gain=gain)
        for layer in self.lr_output:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=gain)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, layer1_feats, batch1, layer2_feats, batch2, label_dict={}):
        # NOTE: as in SeHGNN_NA, the full raw feature dict is expected to be
        # attached by the caller as `self.feats` before forward is called.
        mapped_raw_feats = {str(k): self.input_drop(x @ self.embeding[str(k)]) for k, x in self.feats.items()}

        aggregated_feats = {k: v @ mapped_raw_feats[k[-1]] for k, v in layer1_feats.items()}
        for k, batch in batch1.items():
            aggregated_feats[k] = mapped_raw_feats[k][batch]

        features = {}
        for tgt in batch1.keys():
            features[tgt] = self.layer1_blocks[tgt]({k: v for k, v in aggregated_feats.items() if k[0] == tgt}, {})

        aggregated_feats = {k: v @ features[k[-1]] for k, v in layer2_feats.items()}
        aggregated_feats[self.tgt_type] = mapped_raw_feats[self.tgt_type][batch2]

        mapped_label_feats = {k: self.input_drop(x @ self.labels_embeding[k]) for k, x in label_dict.items()}

        out = self.layer2(aggregated_feats, mapped_label_feats)

        if self.residual:
            out = out + self.res_fc(aggregated_feats[self.tgt_type])
        out = self.dropout(self.prelu(out))
        out = self.lr_output(out)
        return out
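

# Minimal two-level sketch for SeHGNN_2L (illustrative only). It assumes
# nfeat == hidden, because layer 2 stacks raw target features (dim nfeat) with
# layer-1 outputs (dim hidden), and it assumes the caller attaches the full raw
# feature dict as `model.feats`. The metapath keys, node counts and frontier
# sizes are hypothetical; the layer*_feats values stand in for (normalized)
# metapath adjacencies.
def _sehgnn_2l_demo():
    nfeat = hidden = 64
    nclass = 3
    node_types = ['A', 'P']
    feat_keys1, feat_keys2 = ['AP', 'PA'], ['AP']
    data_size = {'A': 100, 'P': 200}                  # raw feature dims per node type
    model = SeHGNN_2L(nfeat, hidden, nclass, feat_keys1, feat_keys2, [], node_types,
                      dropout=0.5, input_drop=0.1, att_dropout=0., label_drop=0.,
                      n_layers_1=2, n_layers_2=2, act='none',
                      residual=False, bns=False, data_size=data_size)
    nA, nP = 50, 40                                   # full node counts per type
    model.feats = {'A': torch.randn(nA, 100), 'P': torch.randn(nP, 200)}
    mA, mP, B2 = 20, 15, 8                            # layer-1 frontier sizes, final batch size
    layer1_feats = {'AP': torch.rand(mA, nP), 'PA': torch.rand(mP, nA)}
    batch1 = {'A': torch.randint(0, nA, (mA,)), 'P': torch.randint(0, nP, (mP,))}
    layer2_feats = {'AP': torch.rand(B2, mP)}
    batch2 = torch.randint(0, nA, (B2,))
    out = model(layer1_feats, batch1, layer2_feats, batch2)
    assert out.shape == (B2, nclass)
    return out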