import torch
import torch as th
from torch import nn
import torch.nn.functional as F

from dgl import function as fn
from dgl.nn.pytorch import edge_softmax
from dgl._ffi.base import DGLError
from dgl.nn.pytorch.utils import Identity
from dgl.utils import expand_as_pair


class myGATConv(nn.Module):
    """GAT convolution with learnable edge-type embeddings added to the
    attention logits and an optional residual attention connection
    (mixed in with weight ``alpha`` via ``res_attn``)."""

    def __init__(self,
                 edge_feats,
                 num_etypes,
                 in_feats,
                 out_feats,
                 num_heads,
                 feat_drop=0.,
                 attn_drop=0.,
                 negative_slope=0.2,
                 residual=False,
                 activation=None,
                 allow_zero_in_degree=False,
                 bias=False,
                 alpha=0.):
        super(myGATConv, self).__init__()
        self._edge_feats = edge_feats
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self.edge_emb = nn.Embedding(num_etypes, edge_feats)
        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)
            self.fc_dst = nn.Linear(
                self._in_dst_feats, out_feats * num_heads, bias=False)
        else:
            self.fc = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)
        self.fc_e = nn.Linear(edge_feats, edge_feats * num_heads, bias=False)
        self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_e = nn.Parameter(th.FloatTensor(size=(1, num_heads, edge_feats)))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if residual:
            if self._in_dst_feats != out_feats:
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=False)
            else:
                self.res_fc = Identity()
        else:
            self.register_buffer('res_fc', None)
        self.reset_parameters()
        self.activation = activation
        self.bias = bias
        if bias:
            self.bias_param = nn.Parameter(th.zeros((1, num_heads, out_feats)))
        self.alpha = alpha

    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        if hasattr(self, 'fc'):
            nn.init.xavier_normal_(self.fc.weight, gain=gain)
        else:
            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        nn.init.xavier_normal_(self.attn_e, gain=gain)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
        nn.init.xavier_normal_(self.fc_e.weight, gain=gain)

    def set_allow_zero_in_degree(self, set_value):
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat, e_feat0, res_attn=None):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')
            # linear projection of (possibly bipartite) node features
            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, 'fc_src'):
                    self.fc_src, self.fc_dst = self.fc, self.fc
                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    -1, self._num_heads, self._out_feats)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            # additive attention logits: source, destination and edge-type terms
            e_feat = self.edge_emb(e_feat0)
            e_feat = self.fc_e(e_feat).view(-1, self._num_heads, self._edge_feats)
            ee = (e_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({'ft': feat_src, 'el': el})
            graph.dstdata.update({'er': er})
            graph.edata.update({'ee': ee})
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e') + graph.edata.pop('ee'))
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            # residual attention from the previous layer, mixed in with weight alpha
            if res_attn is not None:
                graph.edata['a'] = graph.edata['a'] * (1 - self.alpha) + res_attn * self.alpha
            # message passing: attention-weighted sum of source features
            graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            # residual connection on node features
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
                rst = rst + resval
            if self.bias:
                rst = rst + self.bias_param
            if self.activation:
                rst = self.activation(rst)
            return rst, graph.edata.pop('a').detach()
class myGATConv2(nn.Module):
    """Variant of ``myGATConv`` that can optionally replace or rescale the
    learned attention weights per relation: ``average_na`` substitutes each
    edge's attention with the mean attention of same-relation edges at its
    destination node; ``average_sa`` rescales attention so every relation
    present at a destination node receives an equal total attention mass."""

    def __init__(self,
                 edge_feats,
                 num_etypes,
                 in_feats,
                 out_feats,
                 num_heads,
                 feat_drop=0.,
                 attn_drop=0.,
                 negative_slope=0.2,
                 residual=False,
                 activation=None,
                 allow_zero_in_degree=False,
                 bias=False,
                 alpha=0.,
                 num_relations=None,
                 layer_idx=-1):
        super(myGATConv2, self).__init__()
        self._edge_feats = edge_feats
        self._num_heads = num_heads
        self._num_relations = num_relations
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self.edge_emb = nn.Embedding(num_etypes, edge_feats)
        self.layer_idx = layer_idx
        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)
            self.fc_dst = nn.Linear(
                self._in_dst_feats, out_feats * num_heads, bias=False)
        else:
            self.fc = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=False)
        self.fc_e = nn.Linear(edge_feats, edge_feats * num_heads, bias=False)
        self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_e = nn.Parameter(th.FloatTensor(size=(1, num_heads, edge_feats)))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if residual:
            if self._in_dst_feats != out_feats:
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=False)
            else:
                self.res_fc = Identity()
        else:
            self.register_buffer('res_fc', None)
        self.reset_parameters()
        self.activation = activation
        self.bias = bias
        if bias:
            self.bias_param = nn.Parameter(th.zeros((1, num_heads, out_feats)))
        self.alpha = alpha

    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        if hasattr(self, 'fc'):
            nn.init.xavier_normal_(self.fc.weight, gain=gain)
        else:
            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        nn.init.xavier_normal_(self.attn_e, gain=gain)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
        nn.init.xavier_normal_(self.fc_e.weight, gain=gain)

    def set_allow_zero_in_degree(self, set_value):
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat, e_feat0, res_attn=None,
                average_na=False, average_sa=False, save=False):  # `save` is currently unused
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')
            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, 'fc_src'):
                    self.fc_src, self.fc_dst = self.fc, self.fc
                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    -1, self._num_heads, self._out_feats)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            e_feat = self.edge_emb(e_feat0)
            e_feat = self.fc_e(e_feat).view(-1, self._num_heads, self._edge_feats)
            ee = (e_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({'ft': feat_src, 'el': el})
            graph.dstdata.update({'er': er})
            graph.edata.update({'ee': ee})
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e') + graph.edata.pop('ee'))
            graph.edata['a'] = edge_softmax(graph, e)
            if average_na:
                # neighbor-level averaging: give every edge the mean attention of
                # same-relation edges arriving at the same destination node
                with torch.no_grad():
                    print('Processing neighbor weight average ...', flush=True)
                    device = graph.edata['a'].device
                    if self._num_relations is not None:
                        num_relations = self._num_relations
                    else:
                        num_relations = max(e_feat0).item() + 1
                    src, dst, eid = graph._graph.edges(0)
                    num_edges = len(eid)
                    # explicit num_classes keeps the mask aligned with num_relations
                    mask = F.one_hot(e_feat0, num_classes=num_relations).bool()
                    graph.edata['mask'] = mask.float().to(device)
                    tmp = th.zeros((num_edges * num_relations, self._num_heads)).to(device)
                    tmp[mask.flatten()] = graph.edata['a'].squeeze(-1)
                    graph.edata['tmp'] = tmp.view(num_edges, num_relations, self._num_heads)
                    graph.update_all(fn.copy_e('mask', 'm'), fn.sum('m', 'label_counts'))
                    graph.update_all(fn.copy_e('tmp', 'm'), fn.sum('m', 'weight_sum'))
                    tmp = graph.ndata['weight_sum'] / (graph.ndata['label_counts'].unsqueeze(-1).to(device) + 1e-6)
                    aa = dst * num_relations + e_feat0
                    graph.edata['a'] = tmp.view(-1, self._num_heads)[aa].unsqueeze(-1)
            elif average_sa:
                # semantic-level averaging: rescale attention so each relation present
                # at a destination node gets an equal share of the total attention mass
                with torch.no_grad():
                    print('Processing semantic weight average ...', flush=True)
                    device = graph.edata['a'].device
                    if self._num_relations is not None:
                        num_relations = self._num_relations
                    else:
                        num_relations = max(e_feat0).item() + 1
                    src, dst, eid = graph._graph.edges(0)
                    num_edges = len(eid)
                    # explicit num_classes keeps the mask aligned with num_relations
                    mask = F.one_hot(e_feat0, num_classes=num_relations).bool()
                    graph.edata['mask'] = mask.float().to(device)
                    tmp = th.zeros((num_edges * num_relations, self._num_heads)).to(device)
                    tmp[mask.flatten()] = graph.edata['a'].squeeze(-1)
                    graph.edata['tmp'] = tmp.view(num_edges, num_relations, self._num_heads)
                    graph.update_all(fn.copy_e('mask', 'm'), fn.sum('m', 'label_counts'))
                    graph.update_all(fn.copy_e('tmp', 'm'), fn.sum('m', 'weight_sum'))
                    basic = graph.ndata['weight_sum'].sum(-1, keepdim=True) > 1e-6
                    basic = (graph.ndata['weight_sum'] > 0) / (basic.sum(-2, keepdim=True) * graph.ndata['weight_sum'] + 1e-12)
                    aa = dst * num_relations + e_feat0
                    graph.edata['a'] = graph.edata['a'] * basic.view(-1, self._num_heads)[aa].unsqueeze(-1)
            graph.edata['a'] = self.attn_drop(graph.edata['a'])
            if res_attn is not None:
                graph.edata['a'] = graph.edata['a'] * (1 - self.alpha) + res_attn * self.alpha
            graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
                rst = rst + resval
            if self.bias:
                rst = rst + self.bias_param
            if self.activation:
                rst = self.activation(rst)
            return rst, graph.edata.pop('a').detach()
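

# Minimal smoke-test sketch (not part of the original module). It assumes a tiny
# random homogeneous DGL graph with one integer edge-type id per edge; all sizes,
# variable names and hyperparameters below are illustrative only. It shows how the
# two layers are typically stacked, passing the detached attention returned by one
# call as `res_attn` to the next.
if __name__ == '__main__':
    import dgl

    num_nodes, num_edges, num_etypes = 6, 12, 3          # illustrative sizes
    src = th.randint(0, num_nodes, (num_edges,))
    dst = th.randint(0, num_nodes, (num_edges,))
    g = dgl.add_self_loop(dgl.graph((src, dst), num_nodes=num_nodes))
    # one edge-type id per edge; arange % num_etypes guarantees every type occurs
    e_types = th.arange(g.number_of_edges()) % num_etypes
    x = th.randn(num_nodes, 16)

    layer1 = myGATConv(edge_feats=8, num_etypes=num_etypes, in_feats=16,
                       out_feats=8, num_heads=4, residual=True,
                       activation=F.elu, alpha=0.05)
    layer2 = myGATConv2(edge_feats=8, num_etypes=num_etypes, in_feats=8 * 4,
                        out_feats=8, num_heads=4, residual=True,
                        num_relations=num_etypes, alpha=0.05)

    h, attn = layer1(g, x, e_types)
    h = h.flatten(1)                     # merge the head dimension: (N, 4 * 8)
    out, _ = layer2(g, h, e_types, res_attn=attn, average_sa=True)
    print(out.shape)                     # expected: (num_nodes, 4, 8)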