import importlib.util
import os

# Load the shared defaults from base.py next to this file. importlib.util is
# the modern replacement for the deprecated imp.load_source (removed in 3.12).
_spec = importlib.util.spec_from_file_location(
    "base", os.path.join(os.path.dirname(__file__), "base.py")
)
base = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(base)


def get_config(name):
    """Look up a config-builder function in this module by name and call it."""
    return globals()[name]()


def _get_config(
    base_model="prefRestore",
    n_gpus=1,
    gradient_step_per_epoch=1,
    dataset="pickscore",
    reward_fn=None,
    name="",
):
    config = base.get_config()

    config.base_model = base_model
    config.dataset = os.path.join(os.getcwd(), f"dataset/{dataset}")
    config.pretrained.model = ""

    config.sample.num_steps = 10
    config.sample.eval_num_steps = 30
    config.sample.guidance_scale = 2.0
    config.resolution = 512
    config.train.beta = 0.0001
    config.sample.noise_level = 0.7

    bsz = 9
    config.sample.num_image_per_prompt = 24
    num_groups = 48

    # Search downward from bsz = 9 for the largest per-GPU batch size such
    # that (a) one epoch's samples split evenly across all GPUs, (b) a global
    # batch holds a whole number of per-prompt groups, and (c) the resulting
    # number of batches divides evenly into gradient_step_per_epoch steps.
    while True:
        if bsz < 1:
            raise RuntimeError("Cannot find a proper batch size.")
        if (
            num_groups * config.sample.num_image_per_prompt % (n_gpus * bsz) == 0
            and bsz * n_gpus % config.sample.num_image_per_prompt == 0
        ):
            n_batch_per_epoch = num_groups * config.sample.num_image_per_prompt // (
                n_gpus * bsz
            )
            if n_batch_per_epoch % gradient_step_per_epoch == 0:
                config.sample.train_batch_size = bsz
                config.sample.num_batches_per_epoch = n_batch_per_epoch
                config.train.batch_size = config.sample.train_batch_size
                config.train.gradient_accumulation_steps = (
                    config.sample.num_batches_per_epoch // gradient_step_per_epoch
                )
                break
        bsz -= 1

    # Special design: the test sets hold 1018 / 2212 / 2048 samples in total
    # for ocr / geneval / pickscore. We want gpu_num * bs * n to come as close
    # to that total as possible, because when the sample count does not divide
    # evenly across GPUs, the last batch is padded so that every GPU receives
    # the same number of samples, which affects gradient synchronization.
    config.sample.test_batch_size = bsz
    if n_gpus > 32:
        config.sample.test_batch_size = config.sample.test_batch_size // 2

    config.prompt_fn = "geneval" if dataset == "geneval" else "general_ocr"
    config.run_name = f"nft_{base_model}_{name}"
    config.save_dir = f"logs/nft/{base_model}/{name}"
    config.reward_fn = reward_fn if reward_fn is not None else {}

    config.decay_type = 1
    config.beta = 1.0
    config.train.adv_mode = "all"
    config.sample.guidance_scale = 1.0  # overrides the 2.0 set above
    config.sample.deterministic = True
    config.sample.solver = "dpm2"
    return config


def pref_restore_multi_reward():
    reward_fn = {
        "pickscore": 1.0,
        "hpsv2": 1.0,
        "clipscore": 1.0,
    }
    config = _get_config(
        base_model="prefRestore",
        n_gpus=8,
        gradient_step_per_epoch=1,
        dataset="restore_face",
        reward_fn=reward_fn,
        name="multi_reward",
    )
    config.run_name = "prefRestore_multi-reward_dosampleFalse"
    config.sample.num_steps = 30
    config.beta = 0.1
    return config


def pref_restore_multi_reward_ffhq():
    reward_fn = {
        "pickscore": 1.0,
        "hpsv2": 1.0,
        "clipscore": 1.0,
    }
    config = _get_config(
        base_model="prefRestore",
        n_gpus=8,
        gradient_step_per_epoch=1,
        dataset="restore_face_ffhq",
        reward_fn=reward_fn,
        name="multi_reward",
    )
    config.run_name = "prefRestore_multi-FFHQ-data"
    config.sample.num_steps = 30
    config.beta = 0.1
    return config
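

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the config API; the helper name is ours):
# the batch-size search in _get_config, re-implemented standalone so it can be
# sanity-checked in isolation. With the values used by the configs above
# (n_gpus=8, num_image_per_prompt=24, num_groups=48, gradient_step_per_epoch=1)
# it resolves to bsz=9: 48 * 24 = 1152 samples per epoch, 1152 % (8 * 9) == 0,
# 9 * 8 = 72 is a multiple of 24, and 1152 // 72 = 16 batches per epoch.
# ---------------------------------------------------------------------------
def _resolve_batch_size(
    n_gpus=8,
    num_image_per_prompt=24,
    num_groups=48,
    gradient_step_per_epoch=1,
    max_bsz=9,
):
    """Standalone copy of the search loop in _get_config, for inspection."""
    total = num_groups * num_image_per_prompt
    for bsz in range(max_bsz, 0, -1):
        if total % (n_gpus * bsz) == 0 and (bsz * n_gpus) % num_image_per_prompt == 0:
            n_batch_per_epoch = total // (n_gpus * bsz)
            if n_batch_per_epoch % gradient_step_per_epoch == 0:
                return bsz, n_batch_per_epoch
    raise RuntimeError("Cannot find a proper batch size.")


# _resolve_batch_size() -> (9, 16): per-GPU batch size 9 and 16 sampling
# batches per epoch, hence 16 gradient-accumulation steps when
# gradient_step_per_epoch == 1.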
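

# ---------------------------------------------------------------------------
# Illustrative sketch of the padding concern described in _get_config (the
# helper name and the standalone framing are ours): when the test-set size is
# not a multiple of n_gpus * test_batch_size, a distributed sampler pads the
# final batch so every GPU receives the same number of samples, and the padded
# duplicates then take part in synchronization. The example below uses the
# pickscore test-set size (2048) quoted in the comment above.
# ---------------------------------------------------------------------------
import math


def _eval_padding(test_set_size, n_gpus, test_batch_size):
    """Number of padded (duplicated) samples in a distributed evaluation."""
    per_step = n_gpus * test_batch_size
    n_steps = math.ceil(test_set_size / per_step)
    return n_steps * per_step - test_set_size


# _eval_padding(2048, 8, 9) -> 40: 29 steps of 72 samples cover 2088 slots,
# so 40 duplicates are added; a smaller remainder means less distortion.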