"""Encoders and CCA-SSG style model wrappers built on PyTorch and DGL."""

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv, SAGEConv


def _standardize(h):
    """Return ``h`` centred and scaled along dim 0.

    No epsilon is added to the denominator, so a column whose standard
    deviation is zero produces non-finite values.
    """
    return (h - h.mean(0)) / h.std(0)


class LogReg(nn.Module):
    """Single linear layer mapping ``hid_dim`` inputs to ``out_dim`` outputs."""

    def __init__(self, hid_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(hid_dim, out_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the raw outputs."""
        return self.fc(x)


class MLP(nn.Module):
    """Two-layer perceptron with optional batch normalisation.

    Args:
        nfeat: Input feature size.
        nhid: Hidden layer size.
        nclass: Output size.
        use_bn: When true, apply ``BatchNorm1d`` after the first layer.
    """

    def __init__(self, nfeat, nhid, nclass, use_bn=True):
        super().__init__()
        self.layer1 = nn.Linear(nfeat, nhid, bias=True)
        self.layer2 = nn.Linear(nhid, nclass, bias=True)
        # Always created (even when use_bn is False) so the set of
        # registered parameters does not depend on the flag.
        self.bn = nn.BatchNorm1d(nhid)
        self.use_bn = use_bn
        self.act_fn = nn.ReLU()

    def forward(self, _, x):
        """Run the MLP on ``x``.

        The first positional argument is ignored; it lets the MLP be called
        as ``backbone(graph, feat)`` like the graph encoders in this file.
        """
        x = self.layer1(x)
        if self.use_bn:
            x = self.bn(x)
        x = self.act_fn(x)
        return self.layer2(x)


class SageConv(nn.Module):
    """Stack of mean-aggregator ``SAGEConv`` layers.

    The stack is: one ``input_dim -> hidden_dim`` layer, ``num_layers - 1``
    ``hidden_dim -> hidden_dim`` layers, and one ``hidden_dim -> out_dim``
    layer. ReLU and dropout (``p=0``) follow every layer, the last included.

    Args:
        input_dim: Input feature size.
        hidden_dim: Hidden feature size.
        out_dim: Output feature size.
        num_layers: Controls the number of hidden-to-hidden layers (see above).
    """

    def __init__(self, input_dim, hidden_dim, out_dim, num_layers):
        super().__init__()
        # Must be an instance: calling the nn.ReLU class on a tensor would
        # construct a module instead of applying the activation.
        self.activation = nn.ReLU()
        self.layers = nn.ModuleList()
        self.layers.append(
            SAGEConv(input_dim, hidden_dim, aggregator_type='mean', bias=True)
        )
        for _ in range(num_layers - 1):
            self.layers.append(
                SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean', bias=True)
            )
        # Appended to self.layers (the only module list this class owns and
        # the one forward() iterates over).
        self.layers.append(
            SAGEConv(hidden_dim, out_dim, aggregator_type='mean', bias=True)
        )
        self.dropout = nn.Dropout(p=0)

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate features ``x`` through every layer.

        Args:
            x: Node features.
            edge_index: Graph object handed to each ``SAGEConv`` as its first
                argument (despite the name, it is used as the graph).
            edge_weight: Accepted for interface compatibility; not used.

        Returns:
            The output of the last layer after activation and dropout.
        """
        z = x
        for conv in self.layers:
            z = conv(edge_index, z)
            z = self.activation(z)
            z = self.dropout(z)
        return z


class GCN(nn.Module):
    """Stack of ``GraphConv`` layers with ReLU between them.

    For ``n_layers > 1`` the stack has ``n_layers`` layers ending in an
    ``out_dim`` projection. For ``n_layers == 1`` only the
    ``in_dim -> hid_dim`` layer exists, so the output width is ``hid_dim``.
    """

    def __init__(self, in_dim, hid_dim, out_dim, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.convs = nn.ModuleList()
        self.convs.append(GraphConv(in_dim, hid_dim, norm='both'))
        if n_layers > 1:
            for _ in range(n_layers - 2):
                self.convs.append(GraphConv(hid_dim, hid_dim, norm='both'))
            self.convs.append(GraphConv(hid_dim, out_dim, norm='both'))

    def forward(self, graph, x):
        """Return the last layer's output; no activation follows the last layer."""
        for i in range(self.n_layers - 1):
            x = F.relu(self.convs[i](graph, x))
        return self.convs[-1](graph, x)


class CCA_SSG(nn.Module):
    """Two-view model with a ``GCN`` backbone (or ``MLP`` when ``use_mlp``)."""

    def __init__(self, in_dim, hid_dim, out_dim, n_layers, use_mlp=False):
        super().__init__()
        if not use_mlp:
            self.backbone = GCN(in_dim, hid_dim, out_dim, n_layers)
        else:
            self.backbone = MLP(in_dim, hid_dim, out_dim)

    def get_embedding(self, graph, feat):
        """Return the backbone output for ``(graph, feat)``, detached."""
        out = self.backbone(graph, feat)
        return out.detach()

    def forward(self, graph1, feat1, graph2, feat2):
        """Encode both views and return their standardised embeddings."""
        h1 = self.backbone(graph1, feat1)
        h2 = self.backbone(graph2, feat2)
        return _standardize(h1), _standardize(h2)


class CCA_SSG_inductive(nn.Module):
    """Two-view model with a ``SageConv`` backbone (or ``MLP`` when ``use_mlp``)."""

    def __init__(self, in_dim, hid_dim, out_dim, n_layers, use_mlp=False):
        super().__init__()
        # Remembered so _encode can pick the backbone's argument order.
        self.use_mlp = use_mlp
        if not use_mlp:
            self.backbone = SageConv(in_dim, hid_dim, out_dim, n_layers)
        else:
            self.backbone = MLP(in_dim, hid_dim, out_dim)

    def _encode(self, graph, feat):
        """Call the backbone with the argument order it expects."""
        if self.use_mlp:
            return self.backbone(graph, feat)
        # SageConv.forward is (features, graph) -- the reverse of MLP/GCN.
        return self.backbone(feat, graph)

    def get_embedding(self, graph, feat):
        """Return the backbone output for ``(graph, feat)``, detached."""
        out = self._encode(graph, feat)
        return out.detach()

    def forward(self, graph1, feat1, graph2, feat2):
        """Encode both views and return their standardised embeddings."""
        h1 = self._encode(graph1, feat1)
        h2 = self._encode(graph2, feat2)
        return _standardize(h1), _standardize(h2)


class CCA_SSG_white(nn.Module):
    """Variant of ``CCA_SSG`` whose forward also returns the raw embeddings."""

    def __init__(self, in_dim, hid_dim, out_dim, n_layers, use_mlp=False):
        super().__init__()
        if not use_mlp:
            self.backbone = GCN(in_dim, hid_dim, out_dim, n_layers)
        else:
            self.backbone = MLP(in_dim, hid_dim, out_dim)

    def get_embedding(self, graph, feat):
        """Return the backbone output for ``(graph, feat)``, detached."""
        out = self.backbone(graph, feat)
        return out.detach()

    def forward(self, graph1, feat1, graph2, feat2):
        """Encode both views.

        Returns:
            ``(z1, z2, h1_np, h2_np)``: the standardised embeddings followed
            by the un-standardised backbone outputs as detached CPU NumPy
            arrays.
        """
        h1 = self.backbone(graph1, feat1)
        h2 = self.backbone(graph2, feat2)
        z1 = _standardize(h1)
        z2 = _standardize(h2)
        return z1, z2, h1.detach().cpu().numpy(), h2.detach().cpu().numpy()