import ecole as ec
import numpy as np
import torch
from torch.distributions import Categorical
from torchrl.modules import MaskedCategorical

from .model import GNNActor
from .observations import SearchTree


class ObservationFunction:
    """Observation function that maintains a ``SearchTree`` over one episode.

    Implements the ``before_reset`` / ``extract`` hooks: a fresh tree is built
    when an episode starts and is updated from the model at every step.
    """

    def __init__(self, problem):
        # Called once for each problem benchmark.
        self.problem = problem  # kept to allow problem-specific observations
        # No tree exists until ``before_reset`` runs; initialising it here makes
        # ``extract`` safe to call first instead of raising AttributeError.
        self.search_tree = None

    # NOTE(review): no ``seed`` hook is defined. An earlier draft seeded the
    # tree in ``seed``, but the tree is only created in ``before_reset`` --
    # confirm whether SearchTree needs seeding before re-adding such a hook.

    def before_reset(self, model):
        """Start a new episode by building a fresh search tree for ``model``."""
        self.search_tree = SearchTree(model)

    def extract(self, model, done):
        """Return the updated search tree, or ``None`` once the episode is done."""
        if done:
            return None
        # ``is not None`` rather than truthiness: a tree object that defines
        # ``__len__``/``__bool__`` must not be skipped just because it is "empty".
        if self.search_tree is not None:
            self.search_tree.update_tree(model)
        return self.search_tree


class Policy:
    """Branching policy that samples a candidate variable from a GNN actor."""

    # Columns of ``observation.variable_features`` removed before the actor sees them.
    _DROPPED_VARIABLE_COLUMNS = (13, 14)

    def __init__(self, problem):
        # Called once for each problem benchmark.
        self.rng = np.random.RandomState()
        # Sampling uses a dedicated torch generator (not torch's global RNG) so
        # that ``seed`` actually controls which action is drawn. Until ``seed``
        # is called it is seeded non-deterministically, like ``self.rng`` above.
        self.torch_rng = torch.Generator()
        self.torch_rng.seed()
        self.problem = problem  # kept to allow problem-specific policies
        self.actor = GNNActor()

    def seed(self, seed):
        """Reseed the policy's RNGs; called before each episode for determinism."""
        self.rng = np.random.RandomState(seed)
        self.torch_rng.manual_seed(seed)

    def load_weights(self, path_to_weights):
        """Load a saved ``state_dict`` into the actor and switch it to eval mode."""
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
        # hosts; load_state_dict then copies the tensors into the actor's params.
        state_dict = torch.load(path_to_weights, map_location="cpu", weights_only=True)
        self.actor.load_state_dict(state_dict)
        self.actor.eval()

    def _build_actor_inputs(self, observation):
        """Convert an observation into the tensor tuple passed to ``self.actor``."""
        variable_features = np.delete(
            observation.variable_features, self._DROPPED_VARIABLE_COLUMNS, axis=1
        )
        return (
            torch.from_numpy(observation.row_features.astype(np.float32)),
            torch.from_numpy(observation.edge_features.indices.astype(np.int64)),
            torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1),
            torch.from_numpy(variable_features.astype(np.float32)),
        )

    def __call__(self, action_set, observation):
        """Sample the index of a variable to branch on.

        Args:
            action_set: indices of the variables that may be chosen (presumably
                ecole's candidate array -- TODO confirm dtype).
            observation: object exposing ``row_features``,
                ``edge_features.indices`` / ``.values`` and
                ``variable_features`` (the ``SearchTree`` from
                ``ObservationFunction.extract``).

        Returns:
            int: the sampled index, always a member of ``action_set``.

        Raises:
            ValueError: if ``action_set`` is empty.
        """
        # int64 indices: torch cannot index with every numpy integer dtype
        # (e.g. unsigned ones), so convert explicitly.
        candidates = torch.as_tensor(np.asarray(action_set, dtype=np.int64))
        if candidates.numel() == 0:
            raise ValueError("action_set is empty; there is no variable to branch on")

        observation_tuple = self._build_actor_inputs(observation)

        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            # assumes the actor returns one logit per variable (1-D) -- TODO confirm
            logits = self.actor(*observation_tuple)

            mask = torch.zeros_like(logits, dtype=torch.bool)
            mask[candidates] = True

            # Same distribution as (Masked)Categorical: actions outside the
            # action set get probability zero, the rest a softmax of their logits.
            masked_logits = logits.masked_fill(~mask, float("-inf"))
            probs = torch.softmax(masked_logits, dim=-1)

            # Draw with the policy's own generator so ``seed`` makes this reproducible.
            action = torch.multinomial(probs, num_samples=1, generator=self.torch_rng)
        return action.item()