# MIA-GCL / GCA / simple_param / sp.py
from typing import Optional
import os.path as osp

import json
import yaml

import nni


class SimpleParam:
    """Resolve a parameter dict from NNI, a local file, or the stored defaults.

    Values obtained from the selected source are layered on top of
    ``default``, so the source wins whenever both define the same key.
    """

    def __init__(self, local_dir: str = 'param', default: Optional[dict] = None):
        """Remember the lookup directory and the baseline parameters.

        Args:
            local_dir: Directory that the file name of a ``local:FILE_NAME``
                source is joined to.
            default: Baseline parameters; an empty dict is used when omitted.
        """
        if default is None:
            default = dict()

        self.local_dir = local_dir
        self.default = default

    def __call__(self, source: str, preprocess: str = 'none'):
        """Return the parameters for ``source`` merged over the defaults.

        Args:
            source: ``'nni'`` (merge the result of
                ``nni.get_next_parameter()``), ``'local:FILE_NAME'`` (merge a
                ``.json``/``.yaml``/``.yml`` file found under ``local_dir``),
                or ``'default'`` (defaults only).
            preprocess: When ``'nni'``, keys loaded from a local file are
                passed through ``preprocess_nni`` before merging. Any other
                value leaves the loaded dict untouched. Only consulted for
                local sources.

        Returns:
            A new dict; mutating it does not affect ``self.default``.

        Raises:
            AssertionError: A local source is not exactly ``local:FILE_NAME``.
            Exception: The local file has an unsupported extension, or
                ``source`` is not one of the recognised forms.
        """
        if source == 'nni':
            return {**self.default, **nni.get_next_parameter()}

        if source.startswith('local'):
            ts = source.split(':')
            if len(ts) != 2:
                # Raised explicitly rather than via `assert` so the check is
                # still enforced under `python -O`; the exception type is the
                # one the assert statement used to produce.
                raise AssertionError(
                    'local parameter file should be specified in a form of `local:FILE_NAME`'
                )
            path = osp.join(self.local_dir, ts[-1])
            if path.endswith('.json'):
                loaded = parse_json(path)
            elif path.endswith(('.yaml', '.yml')):
                loaded = parse_yaml(path)
            else:
                raise Exception('Invalid file name. Should end with .yaml or .json.')

            if preprocess == 'nni':
                loaded = preprocess_nni(loaded)

            return {**self.default, **loaded}

        if source == 'default':
            # Hand out a copy, like the other branches do, so callers cannot
            # accidentally modify the stored defaults through the result.
            return dict(self.default)

        raise Exception('invalid source')


def preprocess_nni(params: dict):
    """Return a copy of ``params`` with slash-separated keys shortened.

    A key made of three ``/``-separated segments is replaced by its middle
    segment; a key containing no ``/`` is kept unchanged. Values are carried
    over as-is.

    Raises:
        Exception: A key splits into any other number of segments.
    """
    renamed = {}
    for name, value in params.items():
        segments = name.split('/')
        count = len(segments)
        if count == 3:
            short = segments[1]
        elif count == 1:
            short = name
        else:
            raise Exception('Unexpected param name ' + name)
        renamed[short] = value
    return renamed


def parse_yaml(path: str):
    """Read the file at ``path`` and return its parsed YAML content.

    The file is opened in text mode and closed before parsing starts.
    """
    with open(path) as f:
        content = f.read()
    # NOTE(security): yaml.Loader is PyYAML's full loader and can construct
    # arbitrary Python objects from tagged input. Only use this on trusted
    # parameter files; consider yaml.SafeLoader if files may come from elsewhere.
    return yaml.load(content, Loader=yaml.Loader)


def parse_json(path: str):
    """Read the file at ``path`` and return its decoded JSON content.

    The file is opened in text mode and is closed on return, including when
    decoding fails.
    """
    with open(path) as f:
        return json.load(f)