# EECE571F-project / learn2branch_self / agents / dual.py
import ecole as ec
import numpy as np
import torch
from torchrl.modules import MaskedCategorical
from torch.distributions import Categorical
from .model import GNNActor
from .observations import SearchTree

class ObservationFunction:
    """Observation function that exposes a per-episode ``SearchTree``.

    A fresh ``SearchTree`` is built from the solver model in
    ``before_reset``; ``extract`` then refreshes it with the current model
    and hands the same object back as the observation.
    """

    def __init__(self, problem):
        """Store the benchmark name; called once for each problem benchmark.

        Args:
            problem: Identifier of the benchmark, kept so problem-specific
                observations can be devised.
        """
        self.problem = problem  # to devise problem-specific observations
        # Defined up front so ``extract`` is safe to call before the first
        # ``before_reset`` (it then simply returns None).
        self.search_tree = None

    # NOTE: a ``seed(self, seed)`` hook forwarding to
    # ``self.search_tree.manual_seed(seed)`` was sketched here but is disabled.

    def before_reset(self, model):
        """Start a new episode by rebuilding the search tree from ``model``."""
        self.search_tree = SearchTree(model)

    def extract(self, model, done):
        """Return the updated search tree, or None when the episode is over.

        Args:
            model: Current solver model, forwarded to
                ``SearchTree.update_tree``.
            done: True once the episode has finished.

        Returns:
            None if ``done`` is true or no tree has been created yet;
            otherwise the stored tree object.
        """
        if done:
            return None

        tree = self.search_tree
        # Truthiness (not ``is not None``) is kept from the original code: a
        # tree that evaluates falsy is returned without being updated.
        # NOTE(review): confirm whether SearchTree defines __bool__/__len__.
        if tree:
            tree.update_tree(model)
        return tree


class Policy:
    """Stochastic branching policy driven by a ``GNNActor``.

    On each call the observation is converted to tensors, scored by the
    actor, and one index is sampled from the softmax over the logits of the
    entries listed in ``action_set``.
    """

    # Columns of ``observation.variable_features`` removed before the actor
    # sees them.
    # NOTE(review): presumably the actor was trained without these two
    # features; confirm against the training pipeline.
    _DROPPED_VARIABLE_COLUMNS = (13, 14)

    def __init__(self, problem):
        """Create the actor; called once for each problem benchmark.

        Args:
            problem: Identifier of the benchmark, kept so problem-specific
                policies can be devised.
        """
        self.rng = np.random.RandomState()
        # Generator used for action sampling. None means "fall back to
        # torch's global RNG" until ``seed`` is called.
        self.torch_rng = None
        self.problem = problem  # to devise problem-specific policies
        self.actor = GNNActor()

    def seed(self, seed):
        """Make subsequent action sampling reproducible.

        Called before each episode. Reseeds both the NumPy generator and the
        torch generator that ``__call__`` actually samples with.

        Args:
            seed: Integer seed, or None to return to unseeded sampling.
        """
        self.rng = np.random.RandomState(seed)
        if seed is None:
            self.torch_rng = None
        else:
            self.torch_rng = torch.Generator().manual_seed(int(seed))

    def load_weights(self, path_to_weights):
        """Load actor parameters from ``path_to_weights`` and switch to eval mode.

        Args:
            path_to_weights: Path to a state dict saved with ``torch.save``.
        """
        # map_location="cpu" lets checkpoints saved on a GPU load on hosts
        # without CUDA; load_state_dict copies into the existing parameters.
        state_dict = torch.load(
            path_to_weights, map_location="cpu", weights_only=True
        )
        self.actor.load_state_dict(state_dict)
        self.actor.eval()

    def __call__(self, action_set, observation):
        """Sample one action from the actor's masked distribution.

        Args:
            action_set: Indices (into the actor's logits) of the actions
                that are currently allowed.
            observation: Object providing ``row_features``,
                ``edge_features.indices``, ``edge_features.values`` and
                ``variable_features`` as NumPy arrays.

        Returns:
            The sampled index as a Python int; always a member of
            ``action_set``.

        Raises:
            ValueError: If ``action_set`` is empty.
        """
        variable_features = np.delete(
            observation.variable_features,
            list(self._DROPPED_VARIABLE_COLUMNS),
            axis=1,
        )

        edge_features = observation.edge_features
        observation_tuple = (
            torch.from_numpy(observation.row_features.astype(np.float32)),
            torch.from_numpy(edge_features.indices.astype(np.int64)),
            torch.from_numpy(edge_features.values.astype(np.float32)).view(-1, 1),
            torch.from_numpy(variable_features.astype(np.float32)),
        )

        # int64 is the index dtype torch accepts everywhere, whatever integer
        # dtype (or plain list) the caller supplied.
        valid_actions = torch.from_numpy(np.asarray(action_set, dtype=np.int64))
        if valid_actions.numel() == 0:
            raise ValueError("action_set is empty; there is no action to sample")

        # Only the sampled index leaves this method, so no autograd graph is
        # needed for the forward pass.
        with torch.no_grad():
            logits = self.actor(*observation_tuple)

            # Valid-action mask over the logits.
            mask = torch.zeros_like(logits, dtype=torch.bool)
            mask[valid_actions] = True

            # Masked entries get probability zero; when every entry is valid
            # this reduces to a plain softmax over the logits.
            masked_logits = logits.masked_fill(~mask, float("-inf"))
            probabilities = torch.softmax(masked_logits, dim=-1)

            # Sampling through an explicit generator is what makes ``seed``
            # effective.
            action = torch.multinomial(
                probabilities, 1, generator=self.torch_rng
            ).item()

        return action