import numpy as np
import torch


class RandomAgent:
    """Baseline agent that selects an action uniformly at random.

    Args:
        name: Identifier for this agent, stored as ``self.name``.
        rng: Optional random source exposing a ``choice`` method (e.g. a
            ``numpy.random.Generator`` or ``numpy.random.RandomState``) for
            reproducible runs. When ``None`` (the default), numpy's global
            random state is used.
    """

    def __init__(self, name='random', rng=None):
        self.name = name
        # Stored as-is (not replaced by the np.random module) so the agent
        # stays picklable when no custom rng is supplied.
        self.rng = rng

    def before_reset(self, model):
        """Hook invoked before a reset; this agent keeps no state, so it is a no-op.

        Args:
            model: Ignored by this agent.
        """
        pass

    def action_select(self, action_set, **kwargs):
        """Pick one element of ``action_set`` uniformly at random.

        Args:
            action_set: Indexable sequence of candidate actions (must support
                ``len()`` and integer indexing).
            **kwargs: Ignored; presumably accepted for interface compatibility
                with other agents.

        Returns:
            A tuple ``(action, index)`` with ``action == action_set[index]``.

        Raises:
            ValueError: If ``action_set`` is empty.
        """
        n_actions = len(action_set)
        if n_actions == 0:
            raise ValueError("action_set is empty; cannot select an action")
        rng = np.random if self.rng is None else self.rng
        # choice(n) samples from range(n) directly; no index list is built.
        action_idx = rng.choice(n_actions)
        return action_set[action_idx], action_idx


class GraphAgent(torch.nn.Module):
    """Placeholder for a learned graph-based agent (not implemented yet)."""

    def __init__(self, *args, **kwargs) -> None:
        # All arguments are forwarded unchanged to torch.nn.Module.
        super().__init__(*args, **kwargs)
        # TODO: define the network layers.

    def forward(self, constraint_features, edge_indices, edge_features,
                variable_features):
        """Run the model on a graph observation.

        NOTE(review): the argument names suggest a bipartite
        constraint/variable graph given as node features plus edge indices and
        edge features. Confirm the expected shapes and dtypes once implemented.

        Raises:
            NotImplementedError: Always; the forward pass is not written yet.
        """
        # TODO: implement the forward pass.
        raise NotImplementedError("GraphAgent.forward is not implemented yet")