# Training configurations for NFT fine-tuning of Stable Diffusion 3.5 Medium.
import imp
import os

# Load the shared defaults from base.py, which sits next to this file.
base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))


def get_config(name):
    # Resolve a named preset (e.g. "sd3_ocr") to its config-builder below.
    return globals()[name]()


def _get_config(
    base_model="sd3",
    n_gpus=1,
    gradient_step_per_epoch=1,
    dataset="pickscore",
    reward_fn={},
    name="",
):
    config = base.get_config()
    assert base_model in ["sd3"]
    assert dataset in ["pickscore", "ocr", "geneval"]
    config.base_model = base_model
    config.dataset = os.path.join(os.getcwd(), f"dataset/{dataset}")

    if base_model == "sd3":
        config.pretrained.model = "/data/phd/hf_models/stable-diffusion-3.5-medium"
        config.sample.num_steps = 10
        config.sample.eval_num_steps = 40
        config.sample.guidance_scale = 4.5
        config.resolution = 512
        config.train.beta = 0.0001
        config.sample.noise_level = 0.7
        bsz = 8

    config.sample.num_image_per_prompt = 8
    num_groups = 48

    # Search downward from the initial bsz for the largest per-GPU batch size
    # such that (a) the num_groups * num_image_per_prompt samples of an epoch
    # split evenly across n_gpus * bsz, (b) each global batch holds whole
    # prompt groups, and (c) the batches per epoch divide evenly into
    # gradient_step_per_epoch optimizer steps. E.g. with n_gpus=8 and bsz=8:
    # 48 * 8 = 384 samples over 8 * 8 = 64 per step gives 6 batches per epoch.
    while True:
        if bsz < 1:
            raise ValueError("Cannot find a proper batch size.")
        if (
            num_groups * config.sample.num_image_per_prompt % (n_gpus * bsz) == 0
            and bsz * n_gpus % config.sample.num_image_per_prompt == 0
        ):
            n_batch_per_epoch = num_groups * config.sample.num_image_per_prompt // (n_gpus * bsz)
            if n_batch_per_epoch % gradient_step_per_epoch == 0:
                config.sample.train_batch_size = bsz
                config.sample.num_batches_per_epoch = n_batch_per_epoch
                config.train.batch_size = config.sample.train_batch_size
                config.train.gradient_accumulation_steps = (
                    config.sample.num_batches_per_epoch // gradient_step_per_epoch
                )
                break
        bsz -= 1

    # The test sets contain 1018/2212/2048 samples for ocr/geneval/pickscore.
    # Pick the test batch size so that n_gpus * batch_size * num_batches stays
    # as close as possible to that total: when the sample count does not
    # divide evenly across cards, the distributed sampler pads the last batch
    # so every card sees the same number of samples, which affects gradient
    # synchronization.
    config.sample.test_batch_size = 14 if dataset == "geneval" else 16
    if n_gpus > 32:
        config.sample.test_batch_size = config.sample.test_batch_size // 2

    config.prompt_fn = "geneval" if dataset == "geneval" else "general_ocr"
    config.run_name = f"nft_{base_model}_{name}"
    config.save_dir = f"logs/nft/{base_model}/{name}"
    config.reward_fn = reward_fn
    config.decay_type = 1
    config.beta = 1.0
    config.train.adv_mode = "all"
    # Rollout sampling overrides; note this supersedes the guidance_scale of
    # 4.5 set above, so training rollouts run without classifier-free guidance.
    config.sample.guidance_scale = 1.0
    config.sample.deterministic = True
    config.sample.solver = "dpm2"
    return config


def sd3_ocr():
    reward_fn = {
        "ocr": 1.0,
    }
    config = _get_config(
        base_model="sd3",
        n_gpus=8,
        gradient_step_per_epoch=2,
        dataset="ocr",
        reward_fn=reward_fn,
        name="ocr",
    )
    config.beta = 0.1
    config.decay_type = 2
    return config


def sd3_geneval():
    reward_fn = {
        "geneval": 1.0,
    }
    config = _get_config(
        base_model="sd3",
        n_gpus=8,
        gradient_step_per_epoch=1,
        dataset="geneval",
        reward_fn=reward_fn,
        name="geneval",
    )
    return config


def sd3_pickscore():
    reward_fn = {
        "pickscore": 1.0,
    }
    config = _get_config(
        base_model="sd3",
        n_gpus=8,
        gradient_step_per_epoch=1,
        dataset="pickscore",
        reward_fn=reward_fn,
        name="pickscore",
    )
    return config


def sd3_hpsv2():
    reward_fn = {
        "hpsv2": 1.0,
    }
    config = _get_config(
        base_model="sd3",
        n_gpus=8,
        gradient_step_per_epoch=1,
        dataset="pickscore",
        reward_fn=reward_fn,
        name="hpsv2",
    )
    return config


def sd3_multi_reward():
    reward_fn = {
        "pickscore": 1.0,
        "hpsv2": 1.0,
        "clipscore": 1.0,
    }
    config = _get_config(
        base_model="sd3",
        n_gpus=1,
        gradient_step_per_epoch=1,
        dataset="pickscore",
        reward_fn=reward_fn,
        name="multi_reward",
    )
    config.sample.num_steps = 25
    config.beta = 0.1
    return config
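

# Minimal smoke-test sketch (a hedged example, assuming base.py and its
# dependencies are importable from this directory): running the file directly
# resolves each preset through get_config and prints the batch geometry that
# the search loop in _get_config derives.
if __name__ == "__main__":
    for cfg_name in [
        "sd3_ocr",
        "sd3_geneval",
        "sd3_pickscore",
        "sd3_hpsv2",
        "sd3_multi_reward",
    ]:
        cfg = get_config(cfg_name)
        print(
            f"{cfg_name}: train_batch_size={cfg.sample.train_batch_size}, "
            f"num_batches_per_epoch={cfg.sample.num_batches_per_epoch}, "
            f"grad_accum_steps={cfg.train.gradient_accumulation_steps}"
        )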