import argparse
import random

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch.conv import GINConv, GraphConv, GATConv
from dgl.nn.pytorch.glob import SumPooling

from utils import *

# Since we only build a global representation here, no training is involved;
# labels such as COMPLEX / HYBRID could be added to turn this into a
# classification task.
# numberofglycos does not need to be particularly large.


class MLP(nn.Module):
    """Two-layer MLP-type aggregator for the GIN model."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linears = nn.ModuleList()
        # two-layer MLP
        self.linears.append(nn.Linear(input_dim, hidden_dim, bias=False))
        self.linears.append(nn.Linear(hidden_dim, output_dim, bias=False))
        self.batch_norm = nn.BatchNorm1d(hidden_dim)

    def forward(self, x):
        h = self.linears[0](x)
        h = F.relu(self.batch_norm(h))
        return self.linears[1](h)


class GIN(nn.Module):
    def __init__(self, numberofglycos, hidden_dim, output_dim, init_eps):
        super().__init__()
        self.ginlayers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        # If the input were fed straight through an MLP, different glycan units
        # would stay linearly related, so we use an embedding instead.
        self.glyco_embedding = nn.Embedding(numberofglycos, hidden_dim, padding_idx=None)
        # num_layers = 5
        # GIN with a two-layer MLP aggregator and sum-neighbor-pooling scheme;
        # depth is controlled by GNN_global_num_layers (imported from utils).
        for layer in range(GNN_global_num_layers - 1):  # excluding the input layer
            mlp = MLP(hidden_dim, hidden_dim, hidden_dim)
            self.ginlayers.append(
                GINConv(mlp, init_eps=init_eps, learn_eps=False)  # set learn_eps=True to learn epsilon
            )
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # linear functions for graph sum pooling of the output of each layer
        self.linear_prediction = nn.ModuleList()
        for layer in range(GNN_global_num_layers):
            self.linear_prediction.append(nn.Linear(hidden_dim, output_dim))
        self.drop = nn.Dropout(0.5)
        # change to mean readout (AvgPooling) on social-network-style datasets
        self.pool = SumPooling()

    def forward(self, g, h):
        # u, v = g.edges()
        # g.add_edges(v, u)  # make the graph bidirectional
        # g = g.add_self_loop()  # add self-loops; this would also turn the DAG
        # used for the global representation into an undirected cyclic graph
        h = self.glyco_embedding(h)
        # list of hidden representations at each layer (including the input layer)
        hidden_rep = [h]
        for i, layer in enumerate(self.ginlayers):
            h = layer(g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)
        score_over_layer = 0
        # perform graph sum pooling over all nodes in each layer
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linear_prediction[i](pooled_h))
        return score_over_layer

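# --- Illustration (not part of the original script) ---
# A minimal smoke test of the GIN readout on a tiny synthetic graph. This is a
# sketch under the assumption that GNN_global_num_layers is defined in utils;
# the 4-node chain and the vocabulary size of 20 are made up for illustration.
# Call _gin_smoke_test() manually to run it.
def _gin_smoke_test():
    g = dgl.graph(([0, 1, 2], [1, 2, 3]))        # toy 4-node "glycan tree"
    g = dgl.add_self_loop(dgl.to_bidirected(g))  # undirected + self-loops
    node_ids = torch.randint(0, 20, (g.num_nodes(),))  # integer glyco-unit ids
    model = GIN(numberofglycos=20, hidden_dim=16, output_dim=768, init_eps=0)
    model.eval()  # use BatchNorm running stats for this single tiny graph
    with torch.no_grad():
        out = model(g, node_ids)
    # SumPooling collapses node features to one vector per graph, and the
    # per-layer predictions are summed, so the result is (1, output_dim)
    assert out.shape == (1, 768)
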
class GCN(nn.Module):
    def __init__(self, numberofglycos, hidden_dim, output_dim):
        super().__init__()
        self.ginlayers = nn.ModuleList()  # name kept from the GIN template; these are GraphConv layers
        self.batch_norms = nn.ModuleList()
        # If the input were fed straight through an MLP, different glycan units
        # would stay linearly related, so we use an embedding instead.
        self.glyco_embedding = nn.Embedding(numberofglycos, hidden_dim, padding_idx=None)
        # num_layers = 5  # can also be tuned
        for layer in range(GNN_global_num_layers - 1):  # excluding the input layer
            self.ginlayers.append(
                GraphConv(in_feats=hidden_dim, out_feats=hidden_dim, allow_zero_in_degree=True)
            )
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # linear functions for graph sum pooling of the output of each layer
        self.linear_prediction = nn.ModuleList()
        for layer in range(GNN_global_num_layers):
            self.linear_prediction.append(nn.Linear(hidden_dim, output_dim))
        self.drop = nn.Dropout(0.5)
        # change to mean readout (AvgPooling) on social-network-style datasets
        self.pool = SumPooling()

    def forward(self, g, h):
        # adding self-loops here would also turn the DAG used for the global
        # representation into an undirected cyclic graph
        h = self.glyco_embedding(h)
        # list of hidden representations at each layer (including the input layer)
        hidden_rep = [h]
        for i, layer in enumerate(self.ginlayers):
            h = layer(g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)
        score_over_layer = 0
        # perform graph sum pooling over all nodes in each layer
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linear_prediction[i](pooled_h))
        return score_over_layer


class GAT(nn.Module):
    def __init__(self, numberofglycos, hidden_dim, output_dim, num_heads):
        super().__init__()
        self.ginlayers = nn.ModuleList()  # name kept from the GIN template; these are GATConv layers
        self.batch_norms = nn.ModuleList()
        # If the input were fed straight through an MLP, different glycan units
        # would stay linearly related, so we use an embedding instead.
        self.glyco_embedding = nn.Embedding(numberofglycos, hidden_dim, padding_idx=None)
        # num_layers = 5  # can also be tuned
        for layer in range(GNN_global_num_layers - 1):  # excluding the input layer
            self.ginlayers.append(
                GATConv(
                    in_feats=hidden_dim,
                    out_feats=hidden_dim // num_heads,
                    num_heads=num_heads,
                    allow_zero_in_degree=True,
                )
            )
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # linear functions for graph sum pooling of the output of each layer
        self.linear_prediction = nn.ModuleList()
        for layer in range(GNN_global_num_layers):
            self.linear_prediction.append(nn.Linear(hidden_dim, output_dim))
        self.drop = nn.Dropout(0.5)
        # change to mean readout (AvgPooling) on social-network-style datasets
        self.pool = SumPooling()

    def forward(self, g, h):
        h = self.glyco_embedding(h)
        # list of hidden representations at each layer (including the input layer)
        hidden_rep = [h]
        for i, layer in enumerate(self.ginlayers):
            h = layer(g, h)
            h = h.flatten(1)  # merge attention heads: (N, num_heads, out_feats) -> (N, hidden_dim)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)
        score_over_layer = 0
        # perform graph sum pooling over all nodes in each layer
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linear_prediction[i](pooled_h))
        return score_over_layer


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        help="name of dataset",
    )
    args = parser.parse_args()
    # the (fixed) epsilon lets GIN reweight a node's own contribution
    print("Training with DGL built-in GINConv module with a fixed epsilon")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load and split dataset
    dataset_train = torch.load("/remote-home/yxwang/test/zzb/DeepGlyco/model/20230127_test_model_validata")
    dataset = dataset_train["strct_graph"].values.tolist()

    # TODO: proper random sampling later; the per-batch size should also be adjusted
    batchsize = 2
    sample = [i[0] for i in random.sample(dataset, batchsize)]
    train_loader = [dgl.batch(sample).to(device)]

    # create GIN model
    batched_graph = train_loader[0]
    feat = batched_graph.ndata.pop("attr")
    print("feat", feat)
    number_of_glycos = 20
    print(number_of_glycos)
    out_size = 768
    hidden_size = 16
    model = GIN(number_of_glycos, hidden_size, out_size, init_eps=0).to(device)
    print("batchgraph", batched_graph)
    logits = model(batched_graph, feat)
    print(logits.size())
    print("logits", logits)
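
    # --- Illustration (not part of the original script) ---
    # The GCN and GAT variants expose the same (graph, node-id) interface, so
    # they can be exercised on the same batch; num_heads=4 is an illustrative
    # choice, not a value from the original script.
    gcn = GCN(number_of_glycos, hidden_size, out_size).to(device)
    gat = GAT(number_of_glycos, hidden_size, out_size, num_heads=4).to(device)
    print("gcn logits", gcn(batched_graph, feat).size())
    print("gat logits", gat(batched_graph, feat).size())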