# Source: EECE571F-project / ml4co-competition / submissions / dual / agents / strong_branch.py
import torch
import ecole as ec
import numpy as np
import os
import torch.distributions as D


class Policy():
    """Greedy branching policy driven by strong-branching scores.

    ``__call__`` picks the candidate with the highest strong-branching score.
    It also returns the log-probability of that choice under a categorical
    distribution whose logits are the candidates' raw scores.
    """

    def __init__(self, problem):
        # ``problem`` is stored for callers; this policy does not read it.
        self.rng = np.random.RandomState()
        self.problem = problem

    def seed(self, seed):
        """Reset ``self.rng`` to a ``RandomState`` seeded with ``seed``.

        The greedy selection in ``__call__`` does not draw from ``self.rng``.
        NOTE(review): presumably kept for interface parity with other agents.
        """
        self.rng = np.random.RandomState(seed)

    def __call__(self, action_set, observation):
        """Select the branching candidate with the highest strong-branching score.

        Parameters
        ----------
        action_set : array-like of int
            Indices of the candidate variables at the current node.
        observation : sequence
            Observation whose last element holds per-variable strong-branching
            scores, indexable by the entries of ``action_set``.

        Returns
        -------
        action
            The element of ``action_set`` whose score is highest (the first
            such element on ties, following ``numpy.argmax``).
        log_prob : float
            Log-probability of that choice under
            ``Categorical(logits=candidate_scores)``.

        Raises
        ------
        ValueError
            If ``action_set`` is empty.
        """
        # Restrict the scores to the current candidates. Positions in
        # ``candidate_scores`` line up with positions in ``action_set``.
        # ``np.asarray`` also accepts plain sequences of scores.
        candidate_scores = np.asarray(observation[-1])[action_set]
        if candidate_scores.size == 0:
            raise ValueError("action_set is empty; no branching candidate to choose")

        action_idx = int(candidate_scores.argmax())
        action = action_set[action_idx]

        # Treat the raw scores as logits so the greedy choice can also be
        # reported with a log-probability.
        logits = torch.as_tensor(candidate_scores, dtype=torch.float32)
        dist = D.Categorical(logits=logits)
        log_prob = dist.log_prob(torch.tensor(action_idx, dtype=torch.long))
        return action, log_prob.item()