import torch
import ecole as ec
import numpy as np
import os
import torch.distributions as D
class Policy:
    """Expert branching policy driven by strong-branching scores.

    Greedily selects the branching candidate with the highest
    strong-branching score and also reports the log-probability of that
    choice under a softmax (Categorical) distribution over the
    candidates' scores (useful e.g. as an imitation-learning target).
    """

    def __init__(self, problem):
        # `problem` identifies the problem family; kept for interface
        # parity with learned policies — not used by this expert policy.
        self.rng = np.random.RandomState()
        self.problem = problem

    def seed(self, seed):
        """Reseed the internal RNG.

        Kept for interface compatibility with stochastic policies; the
        greedy selection below is deterministic and never consults `rng`.
        """
        self.rng = np.random.RandomState(seed)

    def __call__(self, action_set, observation):
        """Pick a branching variable from `action_set`.

        Parameters
        ----------
        action_set : integer index array of branching candidates.
        observation : sequence whose LAST element holds per-variable
            strong-branching scores (indexable by `action_set`).
            NOTE(review): the original comment claimed variable features
            are masked here, but no masking happens — scores are used
            as-is.

        Returns
        -------
        tuple
            `(action, log_prob)`: the chosen variable index taken from
            `action_set`, and the float log-probability of that choice
            under `Categorical(logits=candidate_scores)`.
        """
        # Restrict the per-variable scores to the branching candidates.
        strong_branch_scores = observation[-1]
        strong_branch_scores = strong_branch_scores[action_set]
        # Deterministic greedy choice: highest score wins.
        action_idx = strong_branch_scores.argmax()
        action = action_set[action_idx]
        # Softmax over candidate scores; as_tensor avoids copying when
        # the scores are already a float32 numpy array.
        dist = D.Categorical(
            logits=torch.as_tensor(strong_branch_scores, dtype=torch.float32)
        )
        log_prob = dist.log_prob(torch.as_tensor(action_idx, dtype=torch.long))
        return action, log_prob.item()