# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

from .vae import VAERunner
from .seq import SEQRunner


def get_runners(cfg, stage):
    """Construct the runner that matches ``stage``.

    Args:
        cfg: Configuration object, forwarded unchanged to the selected
            runner's constructor.
        stage: Stage name. ``'vae'`` selects ``VAERunner``; every other
            value selects ``SEQRunner``.

    Returns:
        A new ``VAERunner`` or ``SEQRunner`` instance built from ``cfg``.
    """
    # NOTE(review): any stage other than 'vae' (including typos) falls
    # through to SEQRunner instead of raising; confirm this is intended.
    runner_cls = VAERunner if stage == 'vae' else SEQRunner
    return runner_cls(cfg)