# EECE571F-project / learn2branch_gasse / observation_functions.py
# Custom Ecole observation functions for strong-branching and hybrid
# (strong-branching / pseudocost) branching scores.
import ecole
import ecole.observation
from ecole.typing import ObservationFunction
import numpy as np

class StrongBranch(ObservationFunction):
    """Observation function that always reports strong-branching scores.

    Thin adapter around ``ecole.observation.StrongBranchingScores``: every
    hook is forwarded unchanged to the wrapped Ecole observation function,
    which is kept in the ``branching_function`` attribute.
    """

    def __init__(self) -> None:
        super().__init__()
        scorer = ecole.observation.StrongBranchingScores()
        self.branching_function = scorer

    def before_reset(self, model: ecole.scip.Model) -> None:
        """Forward the episode-reset hook to the wrapped scorer."""
        scorer = self.branching_function
        scorer.before_reset(model)

    def extract(self, model: ecole.scip.Model, done: bool) -> ecole.typing.Observation:
        """Return exactly what the wrapped ``StrongBranchingScores.extract`` returns."""
        scorer = self.branching_function
        observation = scorer.extract(model, done)
        return observation
    
class HybridBranch(ObservationFunction):
    """Observation function that randomly mixes strong-branching and pseudocost scores.

    On every ``extract`` call a single Bernoulli(``eps``) draw decides which
    wrapped Ecole observation function is queried: with probability ``eps``
    the ``StrongBranchingScores`` observation is returned, otherwise the
    ``Pseudocosts`` observation.

    Args:
        eps: Probability in [0, 1] of using the strong-branching scores.
        rng: Optional random source exposing a ``random()`` method returning a
            float in [0, 1) (e.g. ``numpy.random.default_rng(seed)``). When
            omitted, the global ``numpy.random`` state is used, matching the
            original behavior.

    Raises:
        ValueError: If ``eps`` is not within [0, 1] (including NaN).
    """

    def __init__(self, eps, *, rng=None) -> None:
        super().__init__()
        eps = float(eps)
        # Fail fast on an invalid probability instead of at the first extract() call.
        if not 0.0 <= eps <= 1.0:
            raise ValueError(f"eps must be a probability in [0, 1], got {eps!r}")
        self.sb_function = ecole.observation.StrongBranchingScores()
        self.pb_function = ecole.observation.Pseudocosts()
        self.eps = eps  # probability of using strong branch score
        # None -> global numpy.random state (the module exposes random()).
        self.rng = np.random if rng is None else rng

    def before_reset(self, model: ecole.scip.Model) -> None:
        """Forward the episode-reset hook to both wrapped observation functions."""
        self.sb_function.before_reset(model)
        self.pb_function.before_reset(model)

    def extract(self, model: ecole.scip.Model, done: bool):
        """Return ``(observation, used_strong_branching)``.

        ``observation`` is the result of the strong-branching scorer's
        ``extract`` when ``used_strong_branching`` is True, otherwise the
        result of the pseudocost scorer's ``extract``.
        """
        # One uniform draw u in [0, 1); strong branching is chosen when
        # u >= 1 - eps. This is the inverse-CDF rule np.random.choice applies
        # to probabilities [1 - eps, eps], so it avoids building arrays on
        # every call. It should keep seeded global-state runs on the same
        # decisions, up to float rounding.
        use_strong_branch = bool(self.rng.random() >= 1.0 - self.eps)
        if use_strong_branch:
            return self.sb_function.extract(model, done), True
        return self.pb_function.extract(model, done), False