# EECE571F-project / RL_ecole / random_agent.py
import numpy as np
import torch

class RandomAgent:
    """Baseline agent that selects an action uniformly at random.

    It provides a ``before_reset`` hook (a no-op here) and ``action_select``,
    which returns one element of the given action set together with its index.
    """

    def __init__(self, name='random', rng=None):
        """Create the agent.

        Args:
            name: Label stored on ``self.name`` to identify this agent.
            rng: Optional NumPy random source exposing ``choice(n)``, e.g.
                ``np.random.default_rng(seed)`` or ``np.random.RandomState(seed)``.
                Supplying one makes this agent's choices reproducible
                independently of NumPy's global state. When ``None`` (default),
                the global ``np.random`` state is used.
        """
        self.name = name
        self.rng = rng

    def before_reset(self, model):
        """No-op hook: a random agent keeps no state, so ``model`` is ignored."""
        pass

    def action_select(self, action_set, **kwargs):
        """Pick one action uniformly at random from ``action_set``.

        Args:
            action_set: Non-empty indexable sequence of candidate actions
                (list, tuple, or 1-D NumPy array).
            **kwargs: Accepted and ignored.

        Returns:
            Tuple ``(action, action_idx)`` with ``action == action_set[action_idx]``
            and ``action_idx`` a Python ``int``.

        Raises:
            ValueError: If ``action_set`` is empty.
        """
        n_actions = len(action_set)
        if n_actions == 0:
            raise ValueError("action_set is empty; there is no action to select")
        # getattr keeps instances that never ran this __init__ (e.g. unpickled
        # from an older version) working with the global RNG.
        rng = getattr(self, 'rng', None)
        if rng is None:
            rng = np.random
        # choice(n) samples uniformly from range(n) without materialising an
        # index list.
        action_idx = int(rng.choice(n_actions))
        return action_set[action_idx], action_idx
    

class GraphAgent(torch.nn.Module):
    """Placeholder ``torch.nn.Module`` for a graph-based agent.

    The network architecture is not written yet. Construction arguments are
    passed straight through to ``torch.nn.Module``. Calling ``forward`` always
    raises ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # TODO: define the network layers.

    def forward(self, constraint_features, edge_indices, edge_features, variable_features):
        """Not implemented yet.

        Intended to consume constraint, edge and variable feature inputs; their
        expected shapes are not defined anywhere yet (TODO).

        Raises:
            NotImplementedError: Always, until the model is implemented.
        """
        # TODO: implement the forward pass.
        raise NotImplementedError