# EECE571F-project / learn2branch_gasse / reward_functions.py
import ecole

class DualBoundImprovement(ecole.reward.RewardFunction):
    """Reward equal to the improvement of SCIP's dual bound since the last call.

    The dual bound is read through PySCIPOpt's ``getDualbound()``, which reports
    it in the original objective sense. For a minimization problem the bound
    tightens upward; for a maximization problem it tightens downward. The raw
    difference is therefore sign-adjusted by the objective sense, so that a
    tightening bound is rewarded positively for both kinds of problem.

    The first ``extract`` of an episode returns 0, because there is no previous
    bound to compare against. A call where either bound is still infinite (SCIP
    has not established a finite dual bound yet) also returns 0, instead of a
    meaningless jump of the order of SCIP's infinity value.
    """

    def before_reset(self, model):
        """Forget the bound remembered from the previous episode."""
        self.prev_dual_bound = None

    def extract(self, model, done):
        """Return the sense-adjusted dual bound improvement since the last call.

        Parameters
        ----------
        model :
            The Ecole model being solved; accessed via ``as_pyscipopt()``.
        done : bool
            Whether the episode has terminated (not used here).

        Returns
        -------
        float
            ``current - previous`` for minimization problems and
            ``previous - current`` for maximization problems. Returns 0 on the
            first call after a reset, or when either bound is infinite.
        """
        scip = model.as_pyscipopt()
        current_dual_bound = scip.getDualbound()
        previous_dual_bound = self.prev_dual_bound
        # Always remember the latest bound, even when this step yields 0.
        self.prev_dual_bound = current_dual_bound

        if previous_dual_bound is None:
            return 0.0

        # Before a finite bound is known, getDualbound() may return +/- SCIP
        # infinity. Differencing against it would produce a huge bogus reward.
        if scip.isInfinity(abs(previous_dual_bound)) or scip.isInfinity(
            abs(current_dual_bound)
        ):
            return 0.0

        improvement = current_dual_bound - previous_dual_bound
        # getDualbound() is in the original objective sense. For maximization
        # the bound tightens downward, so flip the sign to keep
        # "improvement" positive.
        if scip.getObjectiveSense() == "maximize":
            improvement = -improvement
        return improvement