import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.pytorch import GATConv, GraphConv
import datetime

# When True, HANLayer.forward prints how long meta-path graph construction takes.
echo_time = True


class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()
        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False)
        )

    def forward(self, z):
        w = self.project(z).mean(0)                     # (M, 1)
        beta = torch.softmax(w, dim=0)                  # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)
        return (beta * z).sum(1)                        # (N, D * K)


# def func(x):
#     print(x.shape, x.mean().item(), x.std().item(), x.max().item(), x.min().item(), flush=True)


def edge_dropout(g, weight, p, training=True):
    if not training or p >= 1:
        return torch.ones_like(weight)
    with torch.no_grad():
        ones = torch.ones_like(weight)
        mask = F.dropout(ones, p=p, training=training)
        g.edata.update({'mask': torch.cat((ones, mask), dim=1)})
        g.update_all(fn.copy_e('mask', 'm'),
                     fn.sum('m', 'num_msgs'))
        num_msgs = g.dstdata.pop('num_msgs')
        num_msgs = num_msgs[:, 0] / num_msgs[:, 1]
        num_msgs[torch.isnan(num_msgs) | torch.isinf(num_msgs)] = -1
        g.dstdata.update({'num_msgs': num_msgs})
        g.apply_edges(lambda edges: {'mask_value': edges.dst['num_msgs']})
        mask_value = g.edata.pop('mask_value').reshape(weight.shape)
        final_mask = mask * mask_value * (mask_value > 0).float() + (mask_value < 0).float()
        # g.edata.update({'mask': weight * final_mask})
        # g.update_all(fn.copy_e('mask', 'm'),
        #              fn.sum('m', 'num_msgs'))
        # num_msgs = g.dstdata.pop('num_msgs')
        # # func(num_msgs)
        # print(torch.unique(num_msgs, return_counts=True), flush=True)
        g.edata.pop('mask')
    return final_mask


class MeanAggregator(nn.Module):
    def __init__(self, in_feats, out_feats, feat_drop, aggr_drop, norm='right',
                 weight=True, bias=True, activation=None, allow_zero_in_degree=False):
        super(MeanAggregator, self).__init__()
        self.feat_drop = feat_drop
        self.aggr_drop = aggr_drop
        self.gcnconv = GraphConv(in_feats, out_feats, norm=norm, weight=weight, bias=bias,
                                 activation=activation, allow_zero_in_degree=allow_zero_in_degree)

    def forward(self, graph, feat, edge_weights):
        with graph.local_scope():
            feat_src, feat_dst = dgl.utils.expand_as_pair(feat, graph)
            weight = self.gcnconv.weight
            if self.gcnconv._in_feats > self.gcnconv._out_feats:
                # Project first, then aggregate (cheaper when the output is smaller).
                if weight is not None:
                    if self.feat_drop > 0:
                        feat_src = F.dropout(feat_src, p=self.feat_drop, training=self.training)
                    feat_src = torch.matmul(feat_src, weight)
                graph.srcdata.update({'ft': feat_src})
                graph.edata.update({'a': F.dropout(
                    edge_weights, p=self.aggr_drop, training=self.training)})
                # graph.edata.update({'a': edge_weights * edge_dropout(
                #     graph, edge_weights, p=self.aggr_drop, training=self.training)})
                graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                                 fn.sum('m', 'h'))
                rst = graph.dstdata['h']
            else:
                # Aggregate first, then project.
                graph.srcdata.update({'ft': feat_src})
                graph.edata.update({'a': F.dropout(
                    edge_weights, p=self.aggr_drop, training=self.training)})
                # graph.edata.update({'a': edge_weights * edge_dropout(
                #     graph, edge_weights, p=self.aggr_drop, training=self.training)})
                graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                                 fn.sum('m', 'h'))
                rst = graph.dstdata['h']
                if weight is not None:
                    if self.feat_drop > 0:
                        rst = F.dropout(rst, p=self.feat_drop, training=self.training)
                    rst = torch.matmul(rst, weight)
            if self.gcnconv.bias is not None:
                rst = rst + self.gcnconv.bias
            if self.gcnconv._activation is not None:
                rst = self.gcnconv._activation(rst)
            return rst
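
# Illustrative usage sketch (not part of the original model code): builds a tiny
# homogeneous graph with right-normalized edge weights and runs MeanAggregator once.
# The graph, feature sizes, and weight construction below are assumptions chosen
# purely for demonstration; it is not invoked anywhere in the training pipeline.
def _mean_aggregator_example():
    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([1, 2, 3, 0])
    g = dgl.graph((src, dst), num_nodes=4)
    feat = torch.randn(4, 8)
    # One weight per edge: 1 / in-degree of the destination node, shape (E, 1),
    # mirroring the 'right' normalization computed in HANLayer below.
    ew = (1.0 / g.in_degrees().float().clamp(min=1))[dst].unsqueeze(-1)
    layer = MeanAggregator(8, 16, feat_drop=0.0, aggr_drop=0.0,
                           activation=F.elu, allow_zero_in_degree=True)
    return layer(g, feat, ew)  # shape (4, 16)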
class HANLayer(nn.Module):
    def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout,
                 remove_sa=False, use_gat=True, attn_dropout=0., adj_norm='right'):
        super(HANLayer, self).__init__()
        self.use_gat = use_gat
        self.remove_sa = remove_sa
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.adj_norm = adj_norm
        self.gat_layers = nn.ModuleList()
        for i in range(len(meta_paths)):
            if self.use_gat:
                self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
                                               dropout, attn_dropout, activation=F.elu,
                                               allow_zero_in_degree=True))
            else:
                self.gat_layers.append(MeanAggregator(in_size, out_size, dropout, attn_dropout,
                                                      norm='right', activation=F.elu,
                                                      allow_zero_in_degree=True))
        self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
        self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths)

        self._cached_graph = None
        self._cached_coalesced_graph = {}
        self._cached_edge_weight = {}

    def forward(self, g, h):
        semantic_embeddings = []

        if self._cached_graph is None or self._cached_graph is not g:
            # Rebuild and cache the meta-path reachable graphs (and, for the mean
            # aggregator, their normalized edge weights) whenever the input graph changes.
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            self._cached_edge_weight.clear()
            if echo_time:
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                tic = datetime.datetime.now()
            with torch.no_grad():
                for meta_path in self.meta_paths:
                    temp_graph = dgl.metapath_reachable_graph(g, meta_path)
                    self._cached_coalesced_graph[meta_path] = temp_graph
                    if not self.use_gat:
                        if self.adj_norm == 'none':
                            edge_weights = torch.ones(temp_graph.num_edges())
                        else:
                            src, dst, eid = temp_graph.cpu()._graph.edges(0)
                            from torch_sparse import SparseTensor
                            adj = SparseTensor(row=dst, col=src, value=eid)
                            perm = torch.argsort(adj.storage.value())
                            base_ew = 1 / adj.storage.rowcount()
                            base_ew[torch.isnan(base_ew) | torch.isinf(base_ew)] = 0
                            if self.adj_norm == 'right':
                                edge_weights = base_ew[adj.storage.row()]
                            elif self.adj_norm == 'both':
                                edge_weights = torch.pow(base_ew, 0.5)[adj.storage.row()]
                            else:
                                assert False, f'Unknown adj_norm method {self.adj_norm}'
                            edge_weights = edge_weights[perm]
                        self._cached_edge_weight[meta_path] = \
                            edge_weights.unsqueeze(-1).to(device=temp_graph.device)
            if echo_time:
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                toc = datetime.datetime.now()
                print('Time used for make graph', toc - tic,
                      self.meta_paths, self._cached_coalesced_graph)

        # if echo_time:
        #     torch.cuda.synchronize()
        #     tic = datetime.datetime.now()
        for i, meta_path in enumerate(self.meta_paths):
            new_g = self._cached_coalesced_graph[meta_path]
            if self.use_gat:
                semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
            else:
                new_ew = self._cached_edge_weight[meta_path]
                semantic_embeddings.append(self.gat_layers[i](new_g, h, new_ew).flatten(1))
        # if echo_time:
        #     torch.cuda.synchronize()
        #     toc = datetime.datetime.now()
        semantic_embeddings = torch.stack(semantic_embeddings, dim=1)  # (N, M, D * K)

        if self.remove_sa:
            out = semantic_embeddings.mean(1)
        else:
            out = self.semantic_attention(semantic_embeddings)  # (N, D * K)
        # if echo_time:
        #     torch.cuda.synchronize()
        #     toc2 = datetime.datetime.now()
        #     print(toc2 - toc, toc - tic)
        return out


class HAN(nn.Module):
    def __init__(self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout,
                 remove_sa=False, use_gat=True, attn_dropout=0.):
        super(HAN, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(HANLayer(
            meta_paths, in_size, hidden_size, num_heads[0], dropout,
            remove_sa=remove_sa, use_gat=use_gat, attn_dropout=attn_dropout))
        for l in range(1, len(num_heads)):
            self.layers.append(HANLayer(
                meta_paths, hidden_size * num_heads[l - 1], hidden_size, num_heads[l], dropout,
                remove_sa=remove_sa, use_gat=use_gat, attn_dropout=attn_dropout))
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
        for gnn in self.layers:
            h = gnn(g, h)
        return self.predict(h)
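
# Hedged usage sketch (an assumption, not taken from the original repository): the
# node/edge types and meta-paths below follow an ACM-style paper/author/field schema
# and are only meant to show how HAN expects a heterograph, meta-paths, and node
# features to be wired together. It is not invoked anywhere in the training pipeline.
def _han_usage_example():
    g = dgl.heterograph({
        ('paper', 'pa', 'author'): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ('author', 'ap', 'paper'): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ('paper', 'pf', 'field'): (torch.tensor([0, 1]), torch.tensor([0, 0])),
        ('field', 'fp', 'paper'): (torch.tensor([0, 0]), torch.tensor([0, 1])),
    })
    feats = torch.randn(g.num_nodes('paper'), 32)
    model = HAN(meta_paths=[['pa', 'ap'], ['pf', 'fp']], in_size=32, hidden_size=8,
                out_size=3, num_heads=[4], dropout=0.0)
    return model(g, feats)  # logits over the assumed 3 classes, shape (num_papers, 3)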
class HAN_freebase(nn.Module):
    def __init__(self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout,
                 use_gat=True):
        super(HAN_freebase, self).__init__()
        self.layers = nn.ModuleList()
        self.fc = nn.Linear(in_size, hidden_size)
        self.layers.append(HANLayer(meta_paths, hidden_size, hidden_size, num_heads[0],
                                    dropout, use_gat=use_gat))
        for l in range(1, len(num_heads)):
            self.layers.append(HANLayer(meta_paths, hidden_size * num_heads[l - 1], hidden_size,
                                        num_heads[l], dropout, use_gat=use_gat))
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
        h = self.fc(h)
        # Cast to a dense float32 tensor on the graph's device before message passing.
        h = torch.FloatTensor(h.cpu()).to(g.device)
        for gnn in self.layers:
            h = gnn(g, h)
        return self.predict(h)
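
if __name__ == '__main__':
    # Smoke test for the illustrative sketches above; the printed shapes depend only
    # on the assumed toy graphs, not on any real dataset.
    print(_mean_aggregator_example().shape)
    print(_han_usage_example().shape)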