import torch
import ecole as ec
import numpy as np
import os
import torch.distributions as D
from submissions.dual.model import GNNPolicy
class ObservationFunction(ec.observation.NodeBipartite):
    """Bipartite-graph observation function used alongside ``Policy``.

    Thin subclass of ``ec.observation.NodeBipartite``: it only adapts the
    constructor and ``seed`` signatures; all observation extraction is
    inherited from the base class.
    """

    def __init__(self, problem):
        """Initialise the base observation function with its defaults.

        Args:
            problem: Accepted for interface compatibility; this class does
                not use it.
        """
        super().__init__()

    def seed(self, seed):
        """Accept a seed and do nothing with it (no random state is kept here)."""
        return None
class Policy():
    """Greedy action-selection policy backed by a pre-trained ``GNNPolicy``.

    The network is loaded once at construction time and evaluated on CPU.
    Calling the instance scores the candidates in ``action_set`` and returns
    the highest-scoring one together with its log-probability.
    """

    # Columns dropped from ``observation.variable_features`` before the
    # network sees them (per the original comment: no incumbent info).
    _MASKED_VARIABLE_COLUMNS = (13, 14)

    def __init__(self, problem, reward):
        """Load the trained actor weights for ``problem`` / ``reward``.

        Args:
            problem: Name of the sub-directory under
                ``submissions/dual/agents/trained_models`` holding the
                checkpoints.
            reward: ``'primal'`` or ``'nodes'`` selects the matching
                checkpoint; any other value selects the default one.
        """
        self.rng = np.random.RandomState()
        # Checkpoint paths are relative, so the working directory matters;
        # it is printed to make path problems easy to diagnose.
        print(os.getcwd())

        # get parameters
        if reward == 'primal':
            suffix = '-primal'
            print("loading primal")
        elif reward == 'nodes':
            suffix = '-nodes'
            print("loading nodes")
        else:
            suffix = ''
        params_path = (
            f'submissions/dual/agents/trained_models/{problem}/'
            f'actor-off_policy-pre_train{suffix}.pt'
        )

        # set up policy
        self.device = torch.device('cpu')
        self.policy = GNNPolicy().to(self.device)
        # map_location remaps the stored tensors onto the target device, so
        # a checkpoint saved from a GPU still loads on a CPU-only host.
        state_dict = torch.load(
            params_path, map_location=self.device, weights_only=True
        )
        self.policy.load_state_dict(state_dict)
        self.policy.eval()

    def seed(self, seed):
        """Replace the policy's NumPy random state with one seeded by ``seed``."""
        self.rng = np.random.RandomState(seed)

    def __call__(self, action_set, observation):
        """Pick the highest-scoring action among ``action_set``.

        Args:
            action_set: Indices of the candidate actions; used to index the
                network output.
            observation: Object exposing ``variable_features``,
                ``row_features`` and ``edge_features`` (the latter with
                ``indices`` and ``values`` attributes).

        Returns:
            A tuple ``(action, log_prob)``: ``action`` is a 0-dim long
            tensor holding the chosen entry of ``action_set`` and
            ``log_prob`` is a Python float, the log-probability of that
            choice under a categorical distribution over the candidates'
            logits.
        """
        device = self.device

        # mask variable features (no incumbent info)
        masked_variables = np.delete(
            observation.variable_features,
            list(self._MASKED_VARIABLE_COLUMNS),
            axis=1,
        )

        constraint_features = torch.as_tensor(
            observation.row_features, dtype=torch.float32, device=device
        )
        edge_index = torch.as_tensor(
            observation.edge_features.indices.astype(np.int64),
            dtype=torch.long,
            device=device,
        )
        # The edge values get a trailing axis of size 1 before being passed
        # to the network.
        edge_attr = torch.as_tensor(
            np.expand_dims(observation.edge_features.values, axis=-1),
            dtype=torch.float32,
            device=device,
        )
        variable_features = torch.as_tensor(
            masked_variables, dtype=torch.float32, device=device
        )
        action_set = torch.as_tensor(
            np.asarray(action_set, dtype=np.int64),
            dtype=torch.long,
            device=device,
        )

        with torch.no_grad():
            logits = self.policy(
                constraint_features, edge_index, edge_attr, variable_features
            )
            # Keep only the scores of the candidate actions.
            logits = logits[action_set]
            action_idx = logits.argmax()
            action = action_set[action_idx]

        dist = D.Categorical(logits=logits)
        log_prob = dist.log_prob(action_idx)
        return action, log_prob.item()