import torch
import ecole as ec
import numpy as np
import os
import torch.distributions as D


class Policy():
    """Expert branching policy driven by strong-branching scores.

    Given an observation whose last entry holds per-variable strong-branching
    scores, the policy greedily selects the candidate in ``action_set`` with
    the highest score and reports the log-probability of that choice under a
    Categorical distribution over the candidates' scores (used as logits).
    """

    def __init__(self, problem):
        # Own RNG so the policy can be re-seeded independently of global state.
        self.rng = np.random.RandomState()
        self.problem = problem

    def seed(self, seed):
        """Reset the internal RNG from ``seed`` for reproducibility."""
        self.rng = np.random.RandomState(seed)

    def __call__(self, action_set, observation):
        # Variable features are masked upstream (no incumbent information);
        # the last observation component carries the strong-branching scores.
        scores = observation[-1]
        # Restrict scoring to the branching candidates only.
        candidate_scores = scores[action_set]
        best_idx = candidate_scores.argmax()
        chosen_action = action_set[best_idx]
        # Treat the candidates' scores as logits of a Categorical so we can
        # attach a log-probability to the greedy (argmax) pick.
        logits = torch.tensor(candidate_scores, dtype=torch.float32)
        score_dist = D.Categorical(logits=logits)
        logp = score_dist.log_prob(torch.tensor(best_idx, dtype=torch.long))
        return chosen_action, logp.item()