import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool
class BipartiteGraphConvolution(MessagePassing):
"""
The bipartite graph convolution is already provided by pytorch geometric and we merely need
to provide the exact form of the messages being passed.
"""
def __init__(self):
super().__init__("add")
emb_size = 64
self.feature_module_left = torch.nn.Sequential(
torch.nn.Linear(emb_size, emb_size)
)
self.feature_module_edge = torch.nn.Sequential(
torch.nn.Linear(1, emb_size, bias=False)
)
self.feature_module_right = torch.nn.Sequential(
torch.nn.Linear(emb_size, emb_size, bias=False)
)
self.feature_module_final = torch.nn.Sequential(
torch.nn.LayerNorm(emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
)
self.post_conv_module = torch.nn.Sequential(torch.nn.LayerNorm(emb_size))
        # Output layers
self.output_module = torch.nn.Sequential(
torch.nn.Linear(2 * emb_size, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
)
def forward(self, left_features, edge_indices, edge_features, right_features):
"""
This method sends the messages, computed in the message method.
"""
output = self.propagate(
edge_indices,
size=(left_features.shape[0], right_features.shape[0]),
node_features=(left_features, right_features),
edge_features=edge_features,
)
        # Concatenate the aggregated (and normalized) messages with the
        # original right-hand-side features before the final output MLP.
        return self.output_module(
            torch.cat([self.post_conv_module(output), right_features], dim=-1)
        )
    def message(self, node_features_i, node_features_j, edge_features):
        # Sum the embedded target-node (i), edge, and source-node (j) features,
        # then pass the result through the final LayerNorm/ReLU/Linear block.
        output = self.feature_module_final(
            self.feature_module_left(node_features_i)
            + self.feature_module_edge(edge_features)
            + self.feature_module_right(node_features_j)
        )
        return output
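

# Minimal smoke test for the convolution above. This is an illustrative
# sketch, not part of the original model: the sizes (3 left nodes, 4 right
# nodes, 5 edges) are arbitrary assumptions chosen to exercise the forward
# pass; the embedding width must match emb_size=64 from __init__.
def _demo_bipartite_convolution():
    conv = BipartiteGraphConvolution()
    left_features = torch.randn(3, 64)   # e.g. constraint embeddings
    right_features = torch.randn(4, 64)  # e.g. variable embeddings
    edge_indices = torch.tensor([[0, 1, 2, 0, 1],   # indices into left_features
                                 [0, 1, 2, 3, 3]])  # indices into right_features
    edge_features = torch.randn(5, 1)
    output = conv(left_features, edge_indices, edge_features, right_features)
    assert output.shape == (4, 64)  # one updated embedding per right-hand node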
class GNNActor(torch.nn.Module):
def __init__(self, n_layers=1):
super().__init__()
emb_size = 64
cons_nfeats = 5
edge_nfeats = 1
var_nfeats = 17
# CONSTRAINT EMBEDDING
self.cons_embedding = torch.nn.Sequential(
torch.nn.LayerNorm(cons_nfeats),
torch.nn.Linear(cons_nfeats, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
)
# EDGE EMBEDDING
self.edge_embedding = torch.nn.Sequential(
torch.nn.LayerNorm(edge_nfeats),
)
# VARIABLE EMBEDDING
self.var_embedding = torch.nn.Sequential(
torch.nn.LayerNorm(var_nfeats),
torch.nn.Linear(var_nfeats, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
)
self.conv_v_to_c = BipartiteGraphConvolution()
self.conv_c_to_v = BipartiteGraphConvolution()
self.output_module = torch.nn.Sequential(
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, 1, bias=False),
)
def forward(
self, constraint_features, edge_indices, edge_features, variable_features
):
        # Edges are stored as (constraint, variable) pairs; reverse them for
        # the variable-to-constraint pass.
        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)
# First step: linear embedding layers to a common dimension (64)
constraint_features = self.cons_embedding(constraint_features)
edge_features = self.edge_embedding(edge_features)
variable_features = self.var_embedding(variable_features)
        # Two half-convolutions: update constraints from variables, then
        # update variables from constraints.
        constraint_features = self.conv_v_to_c(
            variable_features, reversed_edge_indices, edge_features, constraint_features
        )
        variable_features = self.conv_c_to_v(
            constraint_features, edge_indices, edge_features, variable_features
        )
        # A final MLP maps each variable embedding to a scalar score
        output = self.output_module(variable_features).squeeze(-1)
return output
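

# Illustrative usage sketch for the actor, not part of the original code. The
# instance sizes (2 constraints, 3 variables, 4 edges) are arbitrary
# assumptions; the feature widths must match cons_nfeats=5, edge_nfeats=1 and
# var_nfeats=17 from __init__.
def _demo_actor():
    actor = GNNActor()
    constraint_features = torch.randn(2, 5)
    edge_indices = torch.tensor([[0, 0, 1, 1],   # constraint indices
                                 [0, 1, 1, 2]])  # variable indices
    edge_features = torch.randn(4, 1)
    variable_features = torch.randn(3, 17)
    logits = actor(constraint_features, edge_indices, edge_features, variable_features)
    assert logits.shape == (3,)  # one score per variable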
class GNNCritic(torch.nn.Module):
def __init__(self, n_layers=1):
super().__init__()
emb_size = 64
cons_nfeats = 5
edge_nfeats = 1
var_nfeats = 17
# CONSTRAINT EMBEDDING
self.cons_embedding = torch.nn.Sequential(
torch.nn.LayerNorm(cons_nfeats),
torch.nn.Linear(cons_nfeats, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
)
# EDGE EMBEDDING
self.edge_embedding = torch.nn.Sequential(
torch.nn.LayerNorm(edge_nfeats),
)
# VARIABLE EMBEDDING
self.var_embedding = torch.nn.Sequential(
torch.nn.LayerNorm(var_nfeats),
torch.nn.Linear(var_nfeats, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
)
self.conv_v_to_c = BipartiteGraphConvolution()
self.conv_c_to_v = BipartiteGraphConvolution()
self.value_head = torch.nn.Sequential(
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, 1, bias=True),
)
    def forward(
        self,
        constraint_features,
        edge_indices,
        edge_features,
        variable_features,
        batch_indices_c,  # unused here; kept for API symmetry with batch_indices_v
        batch_indices_v,
        candidates,
    ):
        # Edges are stored as (constraint, variable) pairs; reverse them for
        # the variable-to-constraint pass.
        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)
# First step: linear embedding layers to a common dimension (64)
constraint_features = self.cons_embedding(constraint_features)
edge_features = self.edge_embedding(edge_features)
variable_features = self.var_embedding(variable_features)
        # Two half-convolutions: update constraints from variables, then
        # update variables from constraints.
        constraint_features = self.conv_v_to_c(
            variable_features, reversed_edge_indices, edge_features, constraint_features
        )
        variable_features = self.conv_c_to_v(
            constraint_features, edge_indices, edge_features, variable_features
        )
        # Graph-level readout: mean-pool the variable embeddings of each graph
        # in the batch into a single vector. Only the candidate variables
        # should be pooled, so filter both the features and the batch
        # assignment before pooling.
        batch_indices_v = batch_indices_v.to(variable_features.device)
        variable_features_filtered = variable_features[candidates]
        batch_indices_v_filtered = batch_indices_v[candidates]
        graph_representation_v = global_mean_pool(
            variable_features_filtered, batch=batch_indices_v_filtered
        )
        # Value prediction: one scalar per graph in the batch
        value = self.value_head(graph_representation_v).squeeze(-1)  # shape: [num_graphs]
        return value
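

# Illustrative usage sketch for the critic, not part of the original code: a
# batch of one graph with 2 constraints, 3 variables and 2 candidate
# variables. All sizes are arbitrary assumptions.
def _demo_critic():
    critic = GNNCritic()
    constraint_features = torch.randn(2, 5)
    edge_indices = torch.tensor([[0, 0, 1, 1],   # constraint indices
                                 [0, 1, 1, 2]])  # variable indices
    edge_features = torch.randn(4, 1)
    variable_features = torch.randn(3, 17)
    batch_indices_c = torch.zeros(2, dtype=torch.long)  # all constraints in graph 0
    batch_indices_v = torch.zeros(3, dtype=torch.long)  # all variables in graph 0
    candidates = torch.tensor([0, 2])                   # pool only these variables
    value = critic(
        constraint_features, edge_indices, edge_features, variable_features,
        batch_indices_c, batch_indices_v, candidates,
    )
    assert value.shape == (1,)  # one value estimate per graph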