# Source: EECE571F-project / learn2branch_gasse / eval.py
import ecole.environment
import ecole.observation
import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import random
import numpy as np
import ecole
from model import GNNActor
import os
from torchrl.modules import MaskedCategorical

def to_tensor(obs, device):
    """Convert a bipartite-graph observation into a tuple of torch tensors on ``device``.

    ``obs`` must expose ``row_features``, ``edge_features.indices``,
    ``edge_features.values`` and ``variable_features`` numpy arrays (as the
    ``ecole.observation.NodeBipartite`` observation used in ``main`` does).

    Returns:
        ``(row_features, edge_indices, edge_values, variable_features)`` where
        the features and edge values are float32, the edge indices are int64,
        and the edge values are reshaped into a single column ``(-1, 1)``.
    """
    def _as_tensor(array, dtype):
        # Cast on the numpy side first so torch receives a supported dtype.
        return torch.from_numpy(array.astype(dtype))

    edges = obs.edge_features
    row_feats = _as_tensor(obs.row_features, np.float32).to(device)
    edge_index = _as_tensor(edges.indices, np.int64).to(device)
    edge_vals = _as_tensor(edges.values, np.float32).view(-1, 1).to(device)
    var_feats = _as_tensor(obs.variable_features, np.float32).to(device)
    return row_feats, edge_index, edge_vals, var_feats


def _select_action(logits, action_set, greedy):
    """Choose a branching candidate from ``action_set`` using the actor's ``logits``.

    ``logits`` is indexed directly by variable index, so it is assumed to be a
    1-D tensor with one score per variable (TODO confirm against GNNActor).
    In greedy mode the highest-scoring candidate is returned; otherwise an
    action is sampled from a categorical distribution masked to the candidates.

    Returns:
        The chosen variable index as a Python ``int``.
    """
    # Torch only accepts long/int/bool index tensors, while the action set may
    # be an unsigned numpy array. Convert once and reuse it for both paths
    # (the original sampling path indexed with the raw array).
    candidates = torch.as_tensor(action_set.astype(np.int64), device=logits.device)
    if greedy:
        return int(candidates[logits[candidates].argmax()])
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask[candidates] = True
    return int(MaskedCategorical(logits, mask=mask).sample())


def _percent_gain(reference, value):
    """Return the relative improvement ``(reference - value) / reference``.

    The metrics compared in ``evaluate`` are presumably lower-is-better, so a
    positive gain means the GNN beat the SCIP reference. Returns ``nan`` when
    ``reference`` is zero (for example, a statistic SCIP reports as 0) instead
    of raising ``ZeroDivisionError`` and aborting the remaining instances.
    """
    if reference == 0:
        return float("nan")
    return (reference - value) / reference


def evaluate(cfg: DictConfig, actor, gen, env, benchmark_env):
    """Compare the GNN branching policy against plain SCIP on test instances.

    For each of ``cfg.testing.n_instances`` instances drawn from ``gen``:
    solve it with ``env``, choosing every branching decision with ``actor``;
    solve it again with ``benchmark_env`` via a single ``step({})``; then print
    both sets of statistics and the percent gain of the GNN over SCIP.

    Args:
        cfg: Config providing ``testing.n_instances``, ``testing.greedy`` and
            ``device``.
        actor: Policy network called as ``actor(*to_tensor(obs, cfg.device))``.
        gen: Iterator of problem instances.
        env: Branching environment whose info dict has the keys
            ``n_nodes``, ``time``, ``primal``, ``dual`` and ``lp_iters``.
        benchmark_env: Configuring environment with the same info keys.
    """
    actor.eval()
    for testing_instance_i in range(cfg.testing.n_instances):
        instance = next(gen)
        observation, action_set, _, done, info = env.reset(instance)

        # Inference only: no autograd graph is needed for any step of the episode.
        with torch.no_grad():
            while not done:
                logits = actor(*to_tensor(observation, cfg.device))
                action = _select_action(logits, action_set, cfg.testing.greedy)
                observation, action_set, _, done, info = env.step(action)

        # The info of the last reset/step is used. ``main`` builds these entries
        # with ``.cumsum()``, so they are expected to be episode totals.
        n_nodes = info['n_nodes']
        solve_time = info['time']
        primal = info['primal']
        dual = info['dual']
        lp_iters = info['lp_iters']

        # Solve the same instance with the benchmark (default SCIP branching).
        benchmark_env.reset(instance)
        _, _, _, _, benchmark_info = benchmark_env.step({})

        print(f"Problem Instance: {int(testing_instance_i)}")
        print(f"SCIP dual integral: {benchmark_info['dual']:.4e}, SCIP primal integral: {benchmark_info['primal']:.4e}, SCIP iterations: {int(benchmark_info['lp_iters']):>5d}, SCIP nodes: {int(benchmark_info['n_nodes']):>5d}, SCIP time: {benchmark_info['time']:>7.2f}")
        print(f"GNN dual integral: {dual:.5e}, GNN primal integral: {primal:.5e}, GNN iterations: {int(lp_iters):>6d}, GNN nodes: {int(n_nodes):>6d}, GNN time: {solve_time:>7.2f}")

        PG_nodes = _percent_gain(benchmark_info['n_nodes'], n_nodes)
        PG_time = _percent_gain(benchmark_info['time'], solve_time)
        PG_dual = _percent_gain(benchmark_info['dual'], dual)
        PG_primal = _percent_gain(benchmark_info['primal'], primal)
        PG_iters = _percent_gain(benchmark_info['lp_iters'], lp_iters)

        print(f"percent gain: {PG_dual:>16.2%}, percent gain: {PG_primal:>16.2%}, percent gain: {PG_iters:>12.2%}, percent gain: {PG_nodes:>12.2%}, percent gain: {PG_time:>12.2%}")





def _make_instance_generator(cfg):
    """Build the ecole instance generator for ``cfg.problem`` at ``cfg.testing.mode`` difficulty.

    Raises:
        ValueError: if ``cfg.problem`` is not ``"set_cover"`` or
            ``cfg.testing.mode`` is not one of easy/medium/hard.
    """
    if cfg.problem != "set_cover":
        # Previously an unsupported problem left ``gen`` unbound (NameError later on).
        raise ValueError(f"Unsupported problem {cfg.problem!r}; only 'set_cover' is implemented")

    mode = cfg.testing.mode
    if mode not in ("easy", "medium", "hard"):
        raise ValueError("Testing mode must be one of: easy, medium or hard")
    # Each difficulty is a sub-config (cfg.testing.easy / .medium / .hard) with n_rows and n_cols.
    difficulty = cfg.testing[mode]

    return ecole.instance.SetCoverGenerator(
        n_rows=difficulty.n_rows,
        n_cols=difficulty.n_cols,
        density=cfg.density,
    )


def _make_information_function():
    """Return a fresh dict of cumulative ecole statistics for one environment.

    Each environment gets its own instances: the ``.cumsum()`` wrappers
    accumulate values across steps, so they should not be shared.
    """
    return {
        "n_nodes": ecole.reward.NNodes().cumsum(),  # nodes processed since the previous state, summed into a running total
        "time": ecole.reward.SolvingTime().cumsum(),  # seconds spent solving since the previous state, summed into a running total
        "lp_iters": ecole.reward.LpIterations().cumsum(),
        "primal": ecole.reward.PrimalIntegral().cumsum(),
        "dual": ecole.reward.DualIntegral().cumsum(),
    }


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Set up test instances and environments, load the trained actor, and run ``evaluate``.

    Args:
        cfg: Hydra config loaded from ``conf/config.yaml``. Reads ``testing.*``,
            ``problem``, ``density``, ``device`` and ``model_path``.

    Raises:
        ValueError: for an unsupported ``cfg.problem`` or ``cfg.testing.mode``.
    """
    print(OmegaConf.to_yaml(cfg))

    # For reproducibility
    torch.manual_seed(cfg.testing.seed)
    random.seed(cfg.testing.seed)
    np.random.seed(cfg.testing.seed)

    # Define the problem generator & difficulty level.
    gen = _make_instance_generator(cfg)
    gen.seed(cfg.testing.seed)

    # SCIP parameters used in the paper; shared by both environments.
    scip_parameters = {
        "separating/maxrounds": 0,
        "presolving/maxrestarts": 0,
        "limits/time": 900,
    }

    # Environment in which the GNN makes every branching decision.
    env = ecole.environment.Branching(
        observation_function=ecole.observation.NodeBipartite(),
        information_function=_make_information_function(),
        scip_params=scip_parameters,
    )
    env.seed(cfg.testing.seed)

    # Benchmark solver (SCIP): one ``step({})`` solves the instance with default branching.
    benchmark_env = ecole.environment.Configuring(
        observation_function=None,
        reward_function=None,
        information_function=_make_information_function(),
        scip_params=scip_parameters,
    )
    benchmark_env.seed(cfg.testing.seed)

    # ``map_location`` only controls where the loaded tensors land. ``load_state_dict``
    # then copies them into the module's existing parameters, so the module itself
    # must live on ``cfg.device`` to match the inputs produced by ``to_tensor``.
    # Construct it at the same point as before, so that parameter initialisation
    # consumes the global torch RNG identically.
    actor = GNNActor().to(cfg.device)

    # Load the trained model. Print the working directory because a relative
    # ``model_path`` is resolved against it.
    print(os.getcwd())
    actor.load_state_dict(torch.load(cfg.model_path, map_location=cfg.device, weights_only=True))
    actor.eval()
    print("Loaded model at ", cfg.model_path)

    evaluate(cfg, actor, gen, env, benchmark_env)


    

if __name__ == "__main__":
    # Script entry point: the ``@hydra.main`` decorator on ``main`` supplies ``cfg``
    # (from conf/config.yaml plus any command-line overrides), so no argument is passed here.
    main()