# EECE571F-project / ml4co-competition / submissions / dual / agents / rl.py
import torch
import ecole as ec
import numpy as np
import os
import torch.distributions as D
from submissions.dual.model import GNNPolicy


class ObservationFunction(ec.observation.NodeBipartite):
    """Node-bipartite observation function for the dual agent.

    A thin subclass of ``ecole.observation.NodeBipartite``. It accepts a
    ``problem`` argument to match the expected constructor signature, but does
    not use it, and it turns seeding into a no-op.
    """

    def __init__(self, problem):
        # ``problem`` is accepted only for signature compatibility; the base
        # class is constructed with its defaults.
        super().__init__()

    def seed(self, seed):
        # Seeding is deliberately a no-op for this observation function.
        return None


class Policy:
    """Greedy branching policy backed by a pre-trained ``GNNPolicy`` actor.

    The constructor loads actor weights for ``problem`` from a checkpoint whose
    file name is selected by ``reward``. Calling the instance scores the
    candidate variables and returns the highest-scoring one.
    """

    # Checkpoint file-name suffix per reward type; any other reward value
    # falls back to the un-suffixed checkpoint.
    _CHECKPOINT_SUFFIXES = {'primal': '-primal', 'nodes': '-nodes'}

    def __init__(self, problem, reward):
        """Load the pre-trained actor for ``problem`` onto the CPU.

        Args:
            problem: sub-directory name under ``trained_models`` that holds the
                checkpoints for this problem.
            reward: ``'primal'`` or ``'nodes'`` selects the matching
                checkpoint; any other value loads the default checkpoint.
        """
        self.rng = np.random.RandomState()

        if reward in self._CHECKPOINT_SUFFIXES:
            print(f"loading {reward}")
        params_path = self._params_path(problem, reward)

        self.device = torch.device('cpu')
        self.policy = GNNPolicy().to(self.device)
        # map_location remaps tensors that were saved on another device
        # (e.g. from a CUDA training run) onto self.device; without it a
        # GPU-saved checkpoint cannot be loaded on a CPU-only machine.
        state_dict = torch.load(params_path, map_location=self.device, weights_only=True)
        self.policy.load_state_dict(state_dict)
        self.policy.eval()

    @classmethod
    def _params_path(cls, problem, reward):
        """Return the checkpoint path for ``problem`` / ``reward``.

        The historical cwd-relative path (valid when running from the
        repository root) is preferred. If it does not exist, the same file is
        looked up in a ``trained_models`` directory next to this module. If
        neither exists, the cwd-relative path is returned so that the load
        error names the expected location.
        """
        suffix = cls._CHECKPOINT_SUFFIXES.get(reward, '')
        filename = f'actor-off_policy-pre_train{suffix}.pt'
        cwd_path = f'submissions/dual/agents/trained_models/{problem}/{filename}'
        if os.path.exists(cwd_path):
            return cwd_path
        module_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'trained_models', problem, filename)
        if os.path.exists(module_path):
            return module_path
        return cwd_path

    def seed(self, seed):
        """Re-create ``self.rng`` from ``seed``."""
        self.rng = np.random.RandomState(seed)

    def __call__(self, action_set, observation):
        """Pick the candidate with the highest actor logit.

        Args:
            action_set: indices of the candidate variables; they are used to
                index the actor's output.
            observation: observation exposing ``variable_features``,
                ``row_features`` and ``edge_features`` (with ``indices`` and
                ``values``).

        Returns:
            ``(action, log_prob)``. ``action`` is the chosen element of
            ``action_set`` as a 0-dim long tensor. ``log_prob`` is the float
            log-probability of that choice under a categorical distribution
            over the candidate logits.
        """
        # Drop variable-feature columns 13 and 14 (the original note says they
        # carry incumbent information, which the actor must not see).
        variable_features = np.delete(observation.variable_features, [13, 14], axis=1)

        constraint_features = torch.as_tensor(
            observation.row_features, dtype=torch.float32, device=self.device)
        edge_index = torch.as_tensor(
            observation.edge_features.indices.astype(np.int64), device=self.device)
        edge_attr = torch.as_tensor(
            np.expand_dims(observation.edge_features.values, axis=-1),
            dtype=torch.float32, device=self.device)
        variable_features = torch.as_tensor(
            variable_features, dtype=torch.float32, device=self.device)
        action_set = torch.as_tensor(
            np.asarray(action_set, dtype=np.int64), device=self.device)

        with torch.no_grad():
            logits = self.policy(constraint_features, edge_index, edge_attr, variable_features)
            # Presumably one logit per variable; keep only the candidates.
            logits = logits[action_set]
            action_idx = logits.argmax()
            action = action_set[action_idx]

            log_prob = D.Categorical(logits=logits).log_prob(action_idx)

        return action, log_prob.item()