import numpy as np
import torch
class RandomAgent:
    """Baseline agent that selects an action uniformly at random.

    Draws come from NumPy's global random state, so ``np.random.seed`` makes
    the sequence of choices reproducible.
    """

    def __init__(self, name='random'):
        """Create the agent.

        Args:
            name: Label identifying this agent, stored as ``self.name``.
        """
        self.name = name

    def before_reset(self, model):
        """No-op hook: this agent keeps no state that would need resetting.

        Args:
            model: Ignored. Presumably the model/environment about to be
                reset -- TODO confirm against the caller.
        """

    def action_select(self, action_set, **kwargs):
        """Pick one element of ``action_set`` uniformly at random.

        Args:
            action_set: Candidate actions; must support ``len()`` and integer
                indexing.
            **kwargs: Ignored. Presumably accepted so callers can pass the
                same keyword arguments they pass to other agents -- verify
                against callers.

        Returns:
            A ``(action, index)`` tuple with ``action == action_set[index]``.

        Raises:
            ValueError: If ``action_set`` is empty.
        """
        n_actions = len(action_set)
        # Compare the length explicitly: truthiness is ambiguous (raises) for
        # NumPy arrays / tensors holding more than one element.
        if n_actions == 0:
            raise ValueError('action_set is empty; there is no action to select')
        # Sample an index uniformly from [0, n_actions) directly instead of
        # materialising a list of every index just to choose from it.
        action_idx = int(np.random.randint(n_actions))
        return action_set[action_idx], action_idx
class GraphAgent(torch.nn.Module):
    """Placeholder for a learned, graph-based agent (not implemented yet).

    Construction only initialises the underlying ``torch.nn.Module``; no
    layers are defined and ``forward`` always raises ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Hand every argument straight through to torch.nn.Module.
        super().__init__(*args, **kwargs)
        # TODO: define the layers of the graph network.

    def forward(self, constraint_features, edge_indices, edge_features, variable_features):
        """Run the model on one observation -- not implemented yet.

        The argument names suggest a bipartite constraint/variable graph
        (constraint node features, edge index pairs, edge features, variable
        node features); this is an assumption to confirm when implementing.

        Raises:
            NotImplementedError: Always, until the forward pass is written.
        """
        # TODO: implement the forward pass.
        raise NotImplementedError