# mvq/runner/build.py
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

from .vae import VAERunner
from .seq import SEQRunner

def get_runners(cfg, stage):
  """Construct the runner that matches ``stage``.

  Despite the plural name, a single runner instance is returned.

  Args:
    cfg: Configuration object, passed unchanged to the runner's constructor.
    stage: Stage selector. Only the exact value ``'vae'`` picks
      ``VAERunner``; every other value (including unrecognized ones)
      falls through to ``SEQRunner``.

  Returns:
    ``VAERunner(cfg)`` if ``stage == 'vae'``, otherwise ``SEQRunner(cfg)``.
  """
  # NOTE(review): unknown stage names silently select SEQRunner; consider
  # validating ``stage`` once the full set of supported stages is confirmed.
  runner_cls = VAERunner if stage == 'vae' else SEQRunner
  return runner_cls(cfg)