import torch
import torch.nn.functional as F
import torch_geometric
class PreNormException(Exception):
    """Raised during pre-training to interrupt the forward pass once a
    PreNormLayer has recorded statistics for the current batch."""
class PreNormLayer(torch.nn.Module):
    """
    Applies a fixed affine transformation (shift and/or scale) whose
    parameters are estimated from running input statistics during a
    pre-training phase.
    """
    def __init__(self, n_units, shift=True, scale=True, name=None):
super().__init__()
assert shift or scale
self.register_buffer('shift', torch.zeros(n_units) if shift else None)
self.register_buffer('scale', torch.ones(n_units) if scale else None)
self.n_units = n_units
self.waiting_updates = False
self.received_updates = False
def forward(self, input_):
        if self.waiting_updates:
            # pre-training: record statistics, then interrupt the forward pass
            self.update_stats(input_)
            self.received_updates = True
            raise PreNormException
if self.shift is not None:
input_ = input_ + self.shift
if self.scale is not None:
input_ = input_ * self.scale
return input_
def start_updates(self):
self.avg = 0
self.var = 0
self.m2 = 0
self.count = 0
self.waiting_updates = True
self.received_updates = False
def update_stats(self, input_):
"""
Online mean and variance estimation. See: Chan et al. (1979) Updating
Formulae and a Pairwise Algorithm for Computing Sample Variances.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
"""
        assert self.n_units == 1 or input_.shape[-1] == self.n_units, \
            f"Expected input dimension of size {self.n_units}, got {input_.shape[-1]}."
input_ = input_.reshape(-1, self.n_units)
sample_avg = input_.mean(dim=0)
sample_var = (input_ - sample_avg).pow(2).mean(dim=0)
        sample_count = input_.numel() / self.n_units
        delta = sample_avg - self.avg
        # Chan et al. update of the aggregated second moment
        self.m2 = (self.var * self.count + sample_var * sample_count
                   + delta ** 2 * self.count * sample_count / (self.count + sample_count))
self.count += sample_count
self.avg += delta * sample_count / self.count
self.var = self.m2 / self.count if self.count > 0 else 1
def stop_updates(self):
"""
Ends pre-training for that layer, and fixes the layers's parameters.
"""
assert self.count > 0
if self.shift is not None:
self.shift = -self.avg
if self.scale is not None:
            self.var[self.var < 1e-8] = 1  # guard against division by (near-)zero variance
self.scale = 1 / torch.sqrt(self.var)
del self.avg, self.var, self.m2, self.count
self.waiting_updates = False
        self.trainable = False  # informational only; buffers are never trainable in PyTorch
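# Usage sketch (not part of the model itself; data and sizes are illustrative):
# pre-train a standalone PreNormLayer on synthetic batches, then check that
# its output is roughly standardized.
def _demo_prenorm_layer():
    layer = PreNormLayer(n_units=4)
    layer.start_updates()
    for _ in range(10):
        batch = torch.randn(32, 4) * 3.0 + 5.0  # non-standardized synthetic data
        try:
            layer(batch)
        except PreNormException:
            pass  # expected while the layer is collecting statistics
    layer.stop_updates()
    out = layer(torch.randn(32, 4) * 3.0 + 5.0)
    return out.mean(dim=0), out.std(dim=0)  # roughly 0 and 1 per unit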
class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
    """
    Half-convolution over the bipartite graph: messages flow from the
    left-hand-side nodes to the right-hand-side nodes and are summed.
    """
    def __init__(self):
        super().__init__(aggr='add')
emb_size = 64
self.feature_module_left = torch.nn.Sequential(
torch.nn.Linear(emb_size, emb_size)
)
self.feature_module_edge = torch.nn.Sequential(
torch.nn.Linear(1, emb_size, bias=False)
)
self.feature_module_right = torch.nn.Sequential(
torch.nn.Linear(emb_size, emb_size, bias=False)
)
self.feature_module_final = torch.nn.Sequential(
PreNormLayer(1, shift=False),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size)
)
self.post_conv_module = torch.nn.Sequential(
PreNormLayer(1, shift=False)
)
        # output layer
self.output_module = torch.nn.Sequential(
torch.nn.Linear(2*emb_size, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
)
    def forward(self, left_features, edge_indices, edge_features, right_features):
        # propagate messages from left-hand-side nodes to right-hand-side nodes
        output = self.propagate(edge_indices, size=(left_features.shape[0], right_features.shape[0]),
                                node_features=(left_features, right_features), edge_features=edge_features)
return self.output_module(torch.cat([self.post_conv_module(output), right_features], dim=-1))
def message(self, node_features_i, node_features_j, edge_features):
output = self.feature_module_final(self.feature_module_left(node_features_i)
+ self.feature_module_edge(edge_features)
+ self.feature_module_right(node_features_j))
return output
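# Usage sketch (illustrative sizes, random features): one half-convolution on
# a toy bipartite graph with 3 left nodes, 2 right nodes and 4 edges. The
# PreNormLayers start with unit scale, so the module runs without pre-training.
def _demo_bipartite_convolution():
    conv = BipartiteGraphConvolution()
    left_features = torch.randn(3, 64)
    right_features = torch.randn(2, 64)
    edge_indices = torch.tensor([[0, 1, 2, 0],    # indices into left_features
                                 [0, 0, 1, 1]])   # indices into right_features
    edge_features = torch.randn(4, 1)
    output = conv(left_features, edge_indices, edge_features, right_features)
    assert output.shape == (2, 64)  # one updated embedding per right-side node
    return output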
class BaseModel(torch.nn.Module):
"""
Our base model class, which implements pre-training methods.
"""
def pre_train_init(self):
for module in self.modules():
if isinstance(module, PreNormLayer):
module.start_updates()
    def pre_train_next(self):
        # finalize and return the first layer that has collected statistics;
        # None means every PreNormLayer has already been processed
        for module in self.modules():
            if isinstance(module, PreNormLayer) and module.waiting_updates and module.received_updates:
                module.stop_updates()
                return module
        return None
    def pre_train(self, *args, **kwargs):
        # True while some PreNormLayer interrupted the forward pass to collect
        # statistics; False once the whole forward pass runs to completion
        try:
            with torch.no_grad():
                self.forward(*args, **kwargs)
            return False
        except PreNormException:
            return True
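# Sketch of the layer-by-layer pre-training loop (illustrative only: random
# data stands in for a real dataset, and TinyModel is a made-up example model).
def _demo_pre_training_loop():
    class TinyModel(BaseModel):
        def __init__(self):
            super().__init__()
            self.norm = PreNormLayer(8)
        def forward(self, x):
            return self.norm(x)
    model = TinyModel()
    model.pre_train_init()
    while True:
        # feed a few batches; pre_train() returns True as long as some
        # PreNormLayer interrupted the forward pass to collect statistics
        for _ in range(5):
            if not model.pre_train(torch.randn(16, 8) * 2.0 + 3.0):
                break
        # freeze the layer that just finished collecting; None means done
        if model.pre_train_next() is None:
            break
    return model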
class GNNPolicy(BaseModel):
    """
    Bipartite graph neural network policy: constraint, edge and variable
    features are embedded, exchanged through two half-convolutions, and
    mapped to one score per variable.
    """
    def __init__(self):
        super().__init__()
emb_size = 64
cons_nfeats = 5
edge_nfeats = 1
var_nfeats = 17
# CONSTRAINT EMBEDDING
self.cons_embedding = torch.nn.Sequential(
PreNormLayer(cons_nfeats),
torch.nn.Linear(cons_nfeats, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
)
# EDGE EMBEDDING
self.edge_embedding = torch.nn.Sequential(
PreNormLayer(edge_nfeats),
)
# VARIABLE EMBEDDING
self.var_embedding = torch.nn.Sequential(
PreNormLayer(var_nfeats),
torch.nn.Linear(var_nfeats, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
)
self.conv_v_to_c = BipartiteGraphConvolution()
self.conv_c_to_v = BipartiteGraphConvolution()
self.output_module = torch.nn.Sequential(
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, 1, bias=False),
)
    def forward(self, constraint_features, edge_indices, edge_features, variable_features):
        reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)
        constraint_features = self.cons_embedding(constraint_features)
        edge_features = self.edge_embedding(edge_features)
        variable_features = self.var_embedding(variable_features)
        # two half-convolutions: variables -> constraints, then constraints -> variables
        constraint_features = self.conv_v_to_c(
            variable_features, reversed_edge_indices, edge_features, constraint_features)
        variable_features = self.conv_c_to_v(
            constraint_features, edge_indices, edge_features, variable_features)
        # one score per variable
        output = self.output_module(variable_features).squeeze(-1)
        return output
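# End-to-end sketch (illustrative sizes, random features): a toy instance with
# 3 constraints, 4 variables and 6 nonzero coefficients; the policy outputs
# one logit per variable.
def _demo_gnn_policy():
    policy = GNNPolicy()
    constraint_features = torch.randn(3, 5)    # cons_nfeats = 5
    edge_indices = torch.tensor([[0, 0, 1, 1, 2, 2],    # constraint indices
                                 [0, 1, 1, 2, 2, 3]])   # variable indices
    edge_features = torch.randn(6, 1)          # edge_nfeats = 1
    variable_features = torch.randn(4, 17)     # var_nfeats = 17
    logits = policy(constraint_features, edge_indices, edge_features, variable_features)
    assert logits.shape == (4,)  # one score per variable
    return logits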