import numpy as np class StaticFns: @staticmethod def termination_fn(obs, act, next_obs): assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 done = np.array([False]).repeat(len(obs)) done = done[:,None] return done