"""Ecole observation functions that expose branching scores.

``StrongBranch`` wraps ecole's strong-branching scores. ``HybridBranch``
randomly chooses, on every extraction, between strong-branching scores and
pseudocost scores.
"""
from typing import Tuple

import ecole
import ecole.observation
import numpy as np
from ecole.typing import ObservationFunction


class StrongBranch(ObservationFunction):
    """Observation function returning ecole strong-branching scores.

    Thin delegating wrapper around ``ecole.observation.StrongBranchingScores``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.branching_function = ecole.observation.StrongBranchingScores()

    def before_reset(self, model: ecole.scip.Model) -> None:
        """Forward the reset hook to the wrapped strong-branching function."""
        self.branching_function.before_reset(model)

    def extract(self, model: ecole.scip.Model, done: bool) -> ecole.typing.Observation:
        """Return the strong-branching scores extracted from *model*."""
        return self.branching_function.extract(model, done)


class HybridBranch(ObservationFunction):
    """Observation function mixing strong-branching and pseudocost scores.

    On each call to :meth:`extract`, strong-branching scores are returned with
    probability ``eps`` and pseudocost scores otherwise. The observation is a
    ``(scores, used_strong_branching)`` pair so callers can tell which scorer
    produced it.

    Args:
        eps: Probability in ``[0, 1]`` of using strong-branching scores.
        rng: Optional source of randomness exposing ``random()`` (for example a
            ``numpy.random.Generator`` or ``numpy.random.RandomState``). When
            ``None`` (the default) the global ``numpy.random`` state is used,
            so ``np.random.seed`` keeps working as before.

    Raises:
        ValueError: If ``eps`` is not a probability in ``[0, 1]``.
    """

    def __init__(self, eps: float, rng=None) -> None:
        super().__init__()
        eps = float(eps)
        # Fail fast on a bad probability instead of deep inside extract();
        # the chained comparison is also False for NaN, so NaN is rejected.
        if not 0.0 <= eps <= 1.0:
            raise ValueError(f"eps must be a probability in [0, 1], got {eps!r}")
        self.sb_function = ecole.observation.StrongBranchingScores()
        self.pb_function = ecole.observation.Pseudocosts()
        self.eps = eps  # probability of using strong branch score
        # Stored as given (possibly None) and resolved lazily in extract(), so
        # the numpy module object itself is never kept as an attribute.
        self.rng = rng

    def before_reset(self, model: ecole.scip.Model) -> None:
        """Forward the reset hook to both wrapped scoring functions."""
        self.sb_function.before_reset(model)
        self.pb_function.before_reset(model)

    def extract(
        self, model: ecole.scip.Model, done: bool
    ) -> Tuple[ecole.typing.Observation, bool]:
        """Return ``(scores, True)`` from strong branching with probability
        ``eps``, otherwise ``(scores, False)`` from pseudocosts."""
        rng = np.random if self.rng is None else self.rng
        # random() is uniform on [0, 1): True with probability eps,
        # never for eps == 0 and always for eps == 1.
        if rng.random() < self.eps:
            return self.sb_function.extract(model, done), True
        return self.pb_function.extract(model, done), False