"""Optuna-driven hyper-parameter sweep around the project's ``train`` routine."""
import os

import optuna

# NOTE(review): `yaml`, `get_model`, `get_dataset`, `parse_args` and
# `set_reproducibility` are not imported explicitly in this file; they are
# presumably re-exported by the wildcard imports below -- confirm.
from utils.tools import *
from utils.training_tools import *
from train import train
from utils.logger import Logger


def read_sweep(file_path):
    """Load the sweep configuration stored as YAML at *file_path*.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)


def _suggest_param(trial, param_name, param_spec):
    """Ask *trial* for one value of *param_name* according to *param_spec*.

    Raises:
        ValueError: If the spec matches none of the supported forms.
    """
    # `.get` because specs that only list `values` may carry no
    # `distribution` key at all.
    distribution = param_spec.get('distribution')
    if distribution == 'log_uniform_values':
        return trial.suggest_float(
            param_name, param_spec['min'], param_spec['max'], log=True
        )
    if distribution == 'int_uniform':
        return trial.suggest_int(param_name, param_spec['min'], param_spec['max'])
    if 'values' in param_spec:
        return trial.suggest_categorical(param_name, param_spec['values'])
    raise ValueError(
        f"Unsupported sweep specification for parameter '{param_name}': "
        f"{param_spec!r}"
    )


def objective(trial, train_args, sweep_args, train_data, val_data):
    """Run one training session with trial-suggested hyper-parameters.

    Args:
        trial: Optuna trial used to sample the hyper-parameters.
        train_args: Base training arguments; left unmodified, every trial
            works on its own shallow copy.
        sweep_args: Sweep configuration; reads ``parameters`` and
            ``metric['name']``.
        train_data: Training dataset forwarded to ``train``.
        val_data: Validation dataset forwarded to ``train``.

    Returns:
        The value returned by ``train`` (named ``best_loss`` here), which
        Optuna uses as the objective value.

    Raises:
        ValueError: If a parameter specification is not supported.
    """
    params = {
        param_name: _suggest_param(trial, param_name, param_spec)
        for param_name, param_spec in sweep_args['parameters'].items()
    }

    # Shallow copy so trial-specific overrides do not leak into the
    # caller's dict or into later trials.
    trial_args = dict(train_args)
    trial_args.update(params)
    trial_args['SESSIONAME'] = f'Trial{trial.number + 1}'

    logger = Logger(trial_args)
    try:
        model = get_model(trial_args)
        best_loss = train(model, train_data, val_data, trial_args, logger)
        logger.add_hyperparams(params, {sweep_args['metric']['name']: best_loss})
    finally:
        # Close the logger even when model creation or training fails.
        logger.close()
    return best_loss


def main(train_args, sweep_args, n_trials=100):
    """Create (or resume) the Optuna study and optimise ``objective``.

    Args:
        train_args: Base training arguments; ``PROJECTNAME`` names the
            SQLite file and ``SESSIONAME`` names the study.
        sweep_args: Sweep configuration; ``metric['goal']`` selects the
            optimisation direction.
        n_trials: Number of trials to run (defaults to 100).
    """
    train_dataset, val_dataset = get_dataset(train_args)

    # Anything other than an explicit 'maximize' goal means minimisation.
    direction = 'maximize' if sweep_args['metric']['goal'] == 'maximize' else 'minimize'

    # The SQLite storage file lives in this directory.
    os.makedirs('outputs/studies', exist_ok=True)
    study = optuna.create_study(
        # Passed by keyword: `create_study` takes keyword-only arguments in
        # Optuna 3.x, so a positional storage URL raises TypeError there.
        storage=f"sqlite:///outputs/studies/{train_args['PROJECTNAME']}.db",
        direction=direction,
        study_name=f'{train_args["SESSIONAME"]}',
        load_if_exists=True,
    )
    study.optimize(
        lambda trial: objective(
            trial, train_args, sweep_args, train_dataset, val_dataset
        ),
        n_trials=n_trials,
    )


if __name__ == '__main__':
    train_args = parse_args()
    set_reproducibility(train_args['SEED'])
    sweep_args = read_sweep('./config/sweep_config.yaml')
    main(train_args, sweep_args)