# EECE571F-project / reward_study / main.py
import ecole.environment
import torch
import random
import numpy as np
import ecole
import os
import matplotlib.pyplot as plt


def _record_episode(instance, env):
    """Solve ``instance`` in ``env``, always branching on the top-scored candidate.

    At every step the candidate in ``action_set`` with the largest
    observation value is selected (``argmax`` over the candidates' scores).

    Args:
        instance: Problem instance handed to ``env.reset``.
        env: Branching environment whose information function exposes the
            keys ``"dual_int"``, ``"n_nodes"`` and ``"time"``.

    Returns:
        A tuple ``(dual_int_arr, obj_value, nodes, times, is_optimal)`` where
        the three lists hold one entry for the reset plus one per branching
        step, ``obj_value`` is the objective value (or the primal bound when
        the solver status is not ``"optimal"``), and ``is_optimal`` tells
        which of the two it is.
    """
    observation, action_set, _, done, info = env.reset(instance)

    # The information functions report increments since the previous state,
    # so the values returned by ``reset`` cover everything that happened
    # before the first branching decision.  Keeping them makes the cumulative
    # sums true totals and guarantees the series are never empty, even when
    # the instance is solved without any branching.
    dual_int_arr = [info["dual_int"]]
    nodes = [info["n_nodes"]]
    times = [info["time"]]

    while not done:
        # Greedy policy: pick the candidate variable with the highest score.
        action = action_set[observation[action_set].argmax()]
        observation, action_set, _, done, info = env.step(action)
        dual_int_arr.append(info["dual_int"])
        nodes.append(info["n_nodes"])
        times.append(info["time"])

    # Get the optimal or best-found solution.
    scip_model = env.model.as_pyscipopt()
    is_optimal = scip_model.getStatus() == "optimal"
    if is_optimal:
        obj_value = scip_model.getObjVal()  # Optimal solution
    else:
        obj_value = scip_model.getPrimalbound()  # Best solution found within time limit

    return dual_int_arr, obj_value, nodes, times, is_optimal


def _plot_comparison(sb_series, pb_series, title=None):
    """Draw a three-panel figure comparing strong and pseudocost branching.

    Args:
        sb_series: Tuple ``(dual, cumulative_nodes, cumulative_time)`` for
            strong branching.
        pb_series: Same tuple for pseudocost branching.
        title: Optional figure-level title; omitted when ``None``.

    Returns:
        The created matplotlib figure.
    """
    panels = (
        ("Dual Integral", "reward dual integral"),
        ("Nodes Processed", "nodes"),
        ("Solving Time", "time"),
    )

    fig, axis = plt.subplots(len(panels), 1)
    for ax, (panel_title, label), sb_values, pb_values in zip(
        axis, panels, sb_series, pb_series
    ):
        ax.plot(sb_values, marker='^', linestyle='-', label=f'SB {label}', alpha=0.6)
        ax.plot(pb_values, marker='o', linestyle='--', label=f'PB {label}', alpha=0.7)
        ax.set_title(panel_title)
        ax.legend()
        ax.grid(True)

    if title is not None:
        fig.suptitle(title)
    fig.tight_layout()
    return fig


def main():
    """Compare strong branching and pseudocost branching on set-cover instances.

    Generates ``n_instances`` set-cover problems, solves each one with both
    branching rules, prints the total solving time and objective value, and
    shows one figure per instance with the dual integral, the cumulative
    number of nodes and the cumulative solving time.
    """
    seed = 0
    # For reproducibility
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    n_rows = 500
    n_cols = 1000
    density = 0.1
    n_instances = 1

    # Define problem instance
    gen = ecole.instance.SetCoverGenerator(
        n_rows=n_rows,
        n_cols=n_cols,
        density=density,
    )
    gen.seed(seed)

    # SCIP parameters used in paper
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 3600.0,  # Time limit in seconds
    }

    # Define information function
    information_function = {
        "n_nodes": ecole.reward.NNodes(),  # Total number of nodes processed since the previous state
        "time": ecole.reward.SolvingTime(),  # Number of seconds spent solving the instance since the previous state
        "primal_int": ecole.reward.PrimalIntegral(),
        "dual_int": ecole.reward.DualIntegral(),
        "primal_dual_int": ecole.reward.PrimalDualIntegral(),
    }

    # SCIP environments: identical except for the score used to pick the
    # branching variable.
    sb_env = ecole.environment.Branching(
        observation_function=ecole.observation.StrongBranchingScores(),
        reward_function=None,
        information_function=information_function,
        scip_params=scip_parameters,
    )
    sb_env.seed(seed)

    pb_env = ecole.environment.Branching(
        observation_function=ecole.observation.Pseudocosts(),
        reward_function=None,
        information_function=information_function,
        scip_params=scip_parameters,
    )
    pb_env.seed(seed)

    # Record and plot results, one figure per instance so that no instance's
    # measurements are silently dropped.
    for instance_i in range(n_instances):
        instance = next(gen)
        sb_dual, sb_obj, sb_nodes, sb_time, sb_is_optimal = _record_episode(instance, sb_env)
        pb_dual, pb_obj, pb_nodes, pb_time, pb_is_optimal = _record_episode(instance, pb_env)

        # Per-step increments -> running totals.
        sb_time = np.cumsum(sb_time)
        pb_time = np.cumsum(pb_time)
        print(f"Strong Branching - Time: {sb_time[-1]:.2f}s, Objective: {sb_obj}, Optimal: {sb_is_optimal}")
        print(f"Pseudocost Branching - Time: {pb_time[-1]:.2f}s, Objective: {pb_obj}, Optimal: {pb_is_optimal}")

        _plot_comparison(
            (sb_dual, np.cumsum(sb_nodes), sb_time),
            (pb_dual, np.cumsum(pb_nodes), pb_time),
            # Only label figures when there is more than one to tell apart.
            title=f"Instance {instance_i}" if n_instances > 1 else None,
        )

    # Display every figure created above.
    plt.show()


# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()