import torch
from torch_geometric.nn import MessagePassing, global_mean_pool


class BipartiteGraphConvolution(MessagePassing):
    """
    The bipartite graph convolution is already provided by PyTorch Geometric and
    we merely need to provide the exact form of the messages being passed.
    """

    def __init__(self):
        super().__init__(aggr="add")
        emb_size = 64

        self.feature_module_left = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size)
        )
        self.feature_module_edge = torch.nn.Sequential(
            torch.nn.Linear(1, emb_size, bias=False)
        )
        self.feature_module_right = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size, bias=False)
        )
        self.feature_module_final = torch.nn.Sequential(
            torch.nn.LayerNorm(emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )

        self.post_conv_module = torch.nn.Sequential(torch.nn.LayerNorm(emb_size))

        # Output layers
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(2 * emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )

    def forward(self, left_features, edge_indices, edge_features, right_features):
        """
        This method sends the messages computed in the message() method.
        """
        output = self.propagate(
            edge_indices,
            size=(left_features.shape[0], right_features.shape[0]),
            node_features=(left_features, right_features),
            edge_features=edge_features,
        )
        return self.output_module(
            torch.cat([self.post_conv_module(output), right_features], dim=-1)
        )

    def message(self, node_features_i, node_features_j, edge_features):
        # node_features_i holds the target-side features, node_features_j the
        # source-side features; the edge embedding is added in between.
        output = self.feature_module_final(
            self.feature_module_left(node_features_i)
            + self.feature_module_edge(edge_features)
            + self.feature_module_right(node_features_j)
        )
        return output


class GNNActor(torch.nn.Module):
    def __init__(self, n_layers=1):
        super().__init__()
        emb_size = 64
        cons_nfeats = 5
        edge_nfeats = 1
        var_nfeats = 19

        # CONSTRAINT EMBEDDING
        self.cons_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(cons_nfeats),
            torch.nn.Linear(cons_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # EDGE EMBEDDING
        self.edge_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(edge_nfeats),
        )

        # VARIABLE EMBEDDING
        self.var_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(var_nfeats),
            torch.nn.Linear(var_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # Stack n_layers pairs of BipartiteGraphConvolution layers
        # (variable-to-constraint, then constraint-to-variable); with the
        # default n_layers=1 this matches a single round of half-convolutions.
        self.conv_layers_v_to_c = torch.nn.ModuleList(
            [BipartiteGraphConvolution() for _ in range(n_layers)]
        )
        self.conv_layers_c_to_v = torch.nn.ModuleList(
            [BipartiteGraphConvolution() for _ in range(n_layers)]
        )

        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, 1, bias=False),
        )

    def forward(
        self, constraint_features, edge_indices, edge_features, variable_features
    ):
        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)

        # First step: linear embedding layers to a common dimension (64)
        constraint_features = self.cons_embedding(constraint_features)
        edge_features = self.edge_embedding(edge_features)
        variable_features = self.var_embedding(variable_features)

        # Two half-convolutions per layer: variables to constraints, then
        # constraints back to variables.
        for conv_v_to_c, conv_c_to_v in zip(
            self.conv_layers_v_to_c, self.conv_layers_c_to_v
        ):
            constraint_features = conv_v_to_c(
                variable_features,
                reversed_edge_indices,
                edge_features,
                constraint_features,
            )
            variable_features = conv_c_to_v(
                constraint_features, edge_indices, edge_features, variable_features
            )

        # A final MLP on the variable features
        output = self.output_module(variable_features).squeeze(-1)
        return output
class GNNCritic(torch.nn.Module):
    def __init__(self, n_layers=1):
        super().__init__()
        emb_size = 64
        cons_nfeats = 5
        edge_nfeats = 1
        var_nfeats = 19

        # CONSTRAINT EMBEDDING
        self.cons_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(cons_nfeats),
            torch.nn.Linear(cons_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # EDGE EMBEDDING
        self.edge_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(edge_nfeats),
        )

        # VARIABLE EMBEDDING
        self.var_embedding = torch.nn.Sequential(
            torch.nn.LayerNorm(var_nfeats),
            torch.nn.Linear(var_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # Stack n_layers pairs of BipartiteGraphConvolution layers, as in the
        # actor.
        self.conv_layers_v_to_c = torch.nn.ModuleList(
            [BipartiteGraphConvolution() for _ in range(n_layers)]
        )
        self.conv_layers_c_to_v = torch.nn.ModuleList(
            [BipartiteGraphConvolution() for _ in range(n_layers)]
        )

        self.value_head = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, 1, bias=False),
        )

    def forward(
        self,
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
        batch_indices_c,
        batch_indices_v,
        candidates,
    ):
        # batch_indices_c is accepted for API symmetry but not used here.
        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)

        # First step: linear embedding layers to a common dimension (64)
        constraint_features = self.cons_embedding(constraint_features)
        edge_features = self.edge_embedding(edge_features)
        variable_features = self.var_embedding(variable_features)

        # Two half-convolutions per layer, as in the actor.
        for conv_v_to_c, conv_c_to_v in zip(
            self.conv_layers_v_to_c, self.conv_layers_c_to_v
        ):
            constraint_features = conv_v_to_c(
                variable_features,
                reversed_edge_indices,
                edge_features,
                constraint_features,
            )
            variable_features = conv_c_to_v(
                constraint_features, edge_indices, edge_features, variable_features
            )

        # Graph-level readout: global mean pooling over the candidate
        # variables only. batch_indices_v indicates which graph in the batch
        # each variable node belongs to.
        variable_features_filtered = variable_features[candidates]
        batch_indices_v_filtered = batch_indices_v[candidates]
        graph_representation_v = global_mean_pool(
            variable_features_filtered, batch=batch_indices_v_filtered
        )

        # Value prediction; output shape: [num_graphs]
        value = self.value_head(graph_representation_v).squeeze(-1)
        return value
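
# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, assuming torch and torch_geometric are
# installed. The feature widths (5 constraint features, 1 edge feature, 19
# variable features) follow the constants hard-coded above; the graph sizes,
# tensors, and candidate indices below are random/hypothetical dummies.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_cons, n_vars, n_edges = 4, 6, 10

    constraint_features = torch.rand(n_cons, 5)
    variable_features = torch.rand(n_vars, 19)
    edge_features = torch.rand(n_edges, 1)
    # Row 0 holds constraint endpoints, row 1 variable endpoints, matching
    # the convention assumed by the forward passes above.
    edge_indices = torch.stack(
        [
            torch.randint(0, n_cons, (n_edges,)),
            torch.randint(0, n_vars, (n_edges,)),
        ]
    )

    actor = GNNActor(n_layers=1)
    logits = actor(constraint_features, edge_indices, edge_features, variable_features)
    print(logits.shape)  # torch.Size([6]): one score per variable

    # Single-graph batch: every node belongs to graph 0.
    critic = GNNCritic(n_layers=1)
    batch_indices_c = torch.zeros(n_cons, dtype=torch.long)
    batch_indices_v = torch.zeros(n_vars, dtype=torch.long)
    candidates = torch.tensor([0, 2, 3])  # hypothetical candidate variables
    value = critic(
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
        batch_indices_c,
        batch_indices_v,
        candidates,
    )
    print(value.shape)  # torch.Size([1]): one value per graph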