import ecole as ec
import numpy as np
import torch

from model import GNNPolicy


class ObservationFunction(ec.observation.NodeBipartite):
    """Bipartite-graph observation function used together with ``Policy``.

    Thin subclass of ``ecole.observation.NodeBipartite`` that adds the
    ``problem`` constructor argument and a ``seed`` hook.
    """

    def __init__(self, problem):
        """Create the observation function.

        Args:
            problem: Accepted for interface compatibility; it is not used.
        """
        super().__init__()

    def seed(self, seed):
        """Accept a seed and do nothing with it."""
        pass


class Policy():
    """Branching policy that picks the candidate with the highest GNN score."""

    def __init__(self, problem):
        """Load the trained ``GNNPolicy`` parameters for ``problem``.

        Args:
            problem: Name of the sub-directory of ``agents/trained_models``
                that holds ``best_params.pkl``.
        """
        # Kept for the seeding interface; ``__call__`` is deterministic and
        # does not draw from it.
        self.rng = np.random.RandomState()

        # get parameters
        params_path = f'agents/trained_models/{problem}/best_params.pkl'

        # set up policy
        # Prefer the first GPU but fall back to the CPU, so that the policy can
        # still be constructed on hosts without CUDA.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.policy = GNNPolicy().to(self.device)
        # map_location lets a checkpoint saved from a GPU be loaded on whatever
        # device was selected above.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        self.policy.load_state_dict(
            torch.load(params_path, map_location=self.device)
        )
        # NOTE(review): the model is left in its default (training) mode, as in
        # the original code; confirm whether GNNPolicy needs ``.eval()``.

    def seed(self, seed):
        """Re-create the internal random state from ``seed``."""
        self.rng = np.random.RandomState(seed)

    def __call__(self, action_set, observation):
        """Choose one action from ``action_set`` for the given observation.

        Args:
            action_set: Integer indices of the candidate actions; they are used
                to index the per-variable scores produced by the network.
            observation: Object exposing ``column_features``, ``row_features``
                and ``edge_features`` (with ``indices`` and ``values``).

        Returns:
            The element of ``action_set`` with the highest score, as a 0-d
            ``torch`` integer tensor on ``self.device``.
        """
        # mask variable features (no incumbent info): drop columns 13 and 14
        # in a single pass instead of two successive array copies.
        variable_features = np.delete(
            observation.column_features, [13, 14], axis=1
        )

        constraint_features = torch.FloatTensor(
            observation.row_features
        ).to(self.device)
        edge_index = torch.LongTensor(
            observation.edge_features.indices.astype(np.int64)
        ).to(self.device)
        edge_attr = torch.FloatTensor(
            np.expand_dims(observation.edge_features.values, axis=-1)
        ).to(self.device)
        variable_features = torch.FloatTensor(
            variable_features
        ).to(self.device)
        action_set = torch.LongTensor(
            np.array(action_set, dtype=np.int64)
        ).to(self.device)

        # Inference only: skip autograd bookkeeping so that no computation
        # graph is built for every branching decision.
        with torch.no_grad():
            logits = self.policy(
                constraint_features, edge_index, edge_attr, variable_features
            )
            # Keep only the scores of the admissible candidates.
            logits = logits[action_set]
            action_idx = logits.argmax().item()

        action = action_set[action_idx]
        return action