import ecole
import ecole.observation
from ecole.typing import ObservationFunction
import numpy as np
class StrongBranch(ObservationFunction):
    """Observation function delegating to ``ecole.observation.StrongBranchingScores``.

    Both lifecycle hooks (``before_reset`` and ``extract``) are forwarded
    unchanged to a single wrapped ``StrongBranchingScores`` instance.
    """

    def __init__(self) -> None:
        super().__init__()
        # The wrapped ecole observation function that does the actual work.
        self.branching_function = ecole.observation.StrongBranchingScores()

    def before_reset(self, model: ecole.scip.Model) -> None:
        """Forward the reset hook for *model* to the wrapped function."""
        delegate = self.branching_function
        delegate.before_reset(model)

    def extract(self, model: ecole.scip.Model, done: bool) -> ecole.typing.Observation:
        """Return the wrapped function's observation for *model*.

        Parameters
        ----------
        model : ecole.scip.Model
            The model being solved.
        done : bool
            Passed through unchanged to the wrapped ``extract``.
        """
        delegate = self.branching_function
        observation = delegate.extract(model, done)
        return observation
class HybridBranch(ObservationFunction):
    """Observation function that randomly mixes strong-branch and pseudocost scores.

    On every ``extract`` call a Bernoulli sample with success probability
    ``eps`` is drawn.
    - On success the observation comes from
      ``ecole.observation.StrongBranchingScores``.
    - Otherwise it comes from ``ecole.observation.Pseudocosts``.

    The boolean outcome is returned alongside the observation so callers can
    tell which function produced it.
    """

    def __init__(self, eps, rng=None) -> None:
        """
        Parameters
        ----------
        eps : float
            Probability of using the strong-branch score; must lie in [0, 1].
        rng : numpy.random.Generator or numpy.random.RandomState, optional
            Source of randomness for the per-call draw. When ``None``
            (default), numpy's global random state is used, so
            ``np.random.seed`` still controls the draws as before.

        Raises
        ------
        ValueError
            If ``eps`` is not within [0, 1] (NaN included).
        """
        super().__init__()
        # Fail fast here instead of deep inside the sampling call in extract().
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0.0 <= eps <= 1.0:
            raise ValueError(f"eps must be in [0, 1], got {eps!r}")
        self.sb_function = ecole.observation.StrongBranchingScores()
        self.pb_function = ecole.observation.Pseudocosts()
        self.eps = eps  # probability of using strong branch score
        self.rng = rng  # None -> numpy's global random state

    def before_reset(self, model: ecole.scip.Model) -> None:
        """Forward the reset hook for *model* to both wrapped functions."""
        self.sb_function.before_reset(model)
        self.pb_function.before_reset(model)

    def extract(self, model: ecole.scip.Model, done: bool) -> tuple:
        """Return ``(observation, used_strong_branch)`` for *model*.

        Parameters
        ----------
        model : ecole.scip.Model
            The model being solved.
        done : bool
            Passed through unchanged to the selected wrapped ``extract``.

        Returns
        -------
        tuple
            The selected function's observation, and ``True`` if it came from
            ``StrongBranchingScores`` or ``False`` if it came from
            ``Pseudocosts``.
        """
        rng = np.random if self.rng is None else self.rng
        # random() is uniform on [0, 1), so this holds with probability eps.
        # It never holds for eps == 0 and always holds for eps == 1.
        use_strong_branch = bool(rng.random() < self.eps)
        if use_strong_branch:
            return self.sb_function.extract(model, done), True
        return self.pb_function.extract(model, done), False