import ecole as ec
import numpy as np
import torch
from torchrl.modules import MaskedCategorical
from torch.distributions import Categorical
from .model import GNNActor
from .observations import SearchTree
class ObservationFunction:
    """Observation function that exposes a ``SearchTree`` built from the solver model.

    A fresh tree is created in ``before_reset`` and refreshed from the model
    every time ``extract`` is called on a non-terminal state.
    """

    def __init__(self, problem):
        """Create the observation function; called once for each problem benchmark.

        Args:
            problem: Benchmark identifier, stored so problem-specific
                observations can be devised.
        """
        self.problem = problem  # to devise problem-specific observations
        # No tree exists until before_reset() runs. Initialising the attribute
        # lets extract() return None in that case instead of raising
        # AttributeError on the missing attribute.
        self.search_tree = None

    def before_reset(self, model):
        """Build a new search tree; called when a new episode is about to start.

        Args:
            model: Solver model handed to the ``SearchTree`` constructor.
        """
        self.search_tree = SearchTree(model)

    def extract(self, model, done):
        """Refresh the search tree from ``model`` and return it.

        Args:
            model: Solver model passed to ``SearchTree.update_tree``.
            done: Whether the episode has finished.

        Returns:
            ``None`` when ``done`` is true or when no tree has been created
            yet; otherwise the stored search tree.
        """
        if done:
            return None
        tree = self.search_tree
        # NOTE(review): truthiness check kept from the original. If SearchTree
        # defines __len__/__bool__, an "empty" tree skips the update here --
        # confirm that is intended.
        if tree:
            tree.update_tree(model)
        return tree
class Policy:
    """Stochastic branching policy that samples an action from a ``GNNActor``.

    The actor scores the observation; scores of actions outside the current
    action set are masked out and one action is sampled from the resulting
    distribution.
    """

    # Columns removed from ``observation.variable_features`` before the actor
    # sees them.
    # NOTE(review): presumably the actor was trained without these two
    # features -- confirm against the training pipeline.
    _DROPPED_VARIABLE_COLUMNS = (13, 14)

    def __init__(self, problem):
        """Create the policy; called once for each problem benchmark.

        Args:
            problem: Benchmark identifier, stored so problem-specific
                policies can be devised.
        """
        self.rng = np.random.RandomState()
        self._torch_rng = self._make_generator()
        self.problem = problem  # to devise problem-specific policies
        self.actor = GNNActor()

    @staticmethod
    def _make_generator(seed=None):
        """Return a torch generator seeded with ``seed`` (random if ``None``)."""
        generator = torch.Generator()
        if seed is None:
            generator.seed()
        else:
            generator.manual_seed(int(seed))
        return generator

    def seed(self, seed):
        """Reseed the policy; called before each episode.

        Action sampling draws from a private torch generator, so reseeding it
        here is what makes the sampled actions reproducible. The NumPy
        ``rng`` attribute is reseeded as well, as before.

        Args:
            seed: Seed value for the episode.
        """
        self.rng = np.random.RandomState(seed)
        self._torch_rng = self._make_generator(seed)

    def load_weights(self, path_to_weights):
        """Load actor parameters from ``path_to_weights`` and switch to eval mode.

        Args:
            path_to_weights: Path to a serialized ``state_dict`` for the actor.
        """
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # CPU-only hosts; __call__ feeds the actor CPU tensors.
        state_dict = torch.load(
            path_to_weights, map_location="cpu", weights_only=True
        )
        self.actor.load_state_dict(state_dict)
        self.actor.eval()

    def __call__(self, action_set, observation):
        """Sample one action from ``action_set`` according to the actor.

        Args:
            action_set: Indices of the currently valid actions.
            observation: Object exposing ``row_features``,
                ``edge_features.indices``, ``edge_features.values`` and
                ``variable_features`` as NumPy arrays.

        Returns:
            The sampled action as a Python int; it is always a member of
            ``action_set``.

        Raises:
            ValueError: If ``action_set`` is empty.
        """
        # Normalise to int64 so lists and unsigned NumPy arrays index alike.
        candidates = torch.from_numpy(np.asarray(action_set, dtype=np.int64))
        if candidates.numel() == 0:
            raise ValueError("action_set must contain at least one action")

        variable_features = np.delete(
            observation.variable_features,
            list(self._DROPPED_VARIABLE_COLUMNS),
            axis=1,
        )
        edge_features = observation.edge_features
        actor_inputs = (
            torch.from_numpy(observation.row_features.astype(np.float32)),
            torch.from_numpy(edge_features.indices.astype(np.int64)),
            torch.from_numpy(edge_features.values.astype(np.float32)).view(-1, 1),
            torch.from_numpy(variable_features.astype(np.float32)),
        )

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            # Assumes the actor returns one logit per action index (1-D) --
            # the mask indexing and .item() below rely on it.
            logits = self.actor(*actor_inputs)

            valid = torch.zeros_like(logits, dtype=torch.bool)
            valid[candidates] = True

            # Invalid actions get -inf logits, i.e. zero probability.
            masked_logits = logits.masked_fill(~valid, float("-inf"))
            probabilities = torch.softmax(masked_logits, dim=-1)
            sampled = torch.multinomial(
                probabilities, num_samples=1, generator=self._torch_rng
            )
        return sampled.item()