# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
# Add project root directory to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from collections import defaultdict
import datetime
from concurrent import futures
import time
import json
from absl import app, flags
import logging
from diffusers import StableDiffusion3Pipeline
import numpy as np
import flow_grpo.rewards
from flow_grpo.stat_tracking import PerPromptStatTracker
from flow_grpo.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob
from flow_grpo.diffusers_patch.train_dreambooth_lora_sd3 import encode_prompt
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import wandb
from functools import partial
import tqdm
import tempfile
from PIL import Image
from peft import LoraConfig, get_peft_model, PeftModel
import random
from torch.utils.data import Dataset, DataLoader, Sampler
from flow_grpo.ema import EMAModuleWrapper
from ml_collections import config_flags
from torch.cuda.amp import GradScaler, autocast as torch_autocast
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.")
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
def setup_distributed(rank, local_rank, world_size):
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
def cleanup_distributed():
dist.destroy_process_group()
def is_main_process(rank):
return rank == 0
def set_seed(seed: int, rank: int = 0):
random.seed(seed + rank)
np.random.seed(seed + rank)
torch.manual_seed(seed + rank)
torch.cuda.manual_seed_all(seed + rank)
class TextPromptDataset(Dataset):
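    # Plain-text prompt dataset: one prompt per line in {dataset}/{split}.txt.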
def __init__(self, dataset, split="train"):
self.file_path = os.path.join(dataset, f"{split}.txt")
with open(self.file_path, "r") as f:
self.prompts = [line.strip() for line in f.readlines()]
def __len__(self):
return len(self.prompts)
def __getitem__(self, idx):
return {"prompt": self.prompts[idx], "metadata": {}}
@staticmethod
def collate_fn(examples):
prompts = [example["prompt"] for example in examples]
metadatas = [example["metadata"] for example in examples]
return prompts, metadatas
class GenevalPromptDataset(Dataset):
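    # GenEval-style prompt dataset: one JSON object per line in {dataset}/{split}_metadata.jsonl,
    # exposing its "prompt" field alongside the full metadata dict.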
def __init__(self, dataset, split="train"):
self.file_path = os.path.join(dataset, f"{split}_metadata.jsonl")
with open(self.file_path, "r", encoding="utf-8") as f:
self.metadatas = [json.loads(line) for line in f]
self.prompts = [item["prompt"] for item in self.metadatas]
def __len__(self):
return len(self.prompts)
def __getitem__(self, idx):
return {"prompt": self.prompts[idx], "metadata": self.metadatas[idx]}
@staticmethod
def collate_fn(examples):
prompts = [example["prompt"] for example in examples]
metadatas = [example["metadata"] for example in examples]
return prompts, metadatas
class DistributedKRepeatSampler(Sampler):
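    # Infinite sampler that draws m = num_replicas * batch_size / k unique prompts per iteration,
    # repeats each prompt k times, shuffles, and hands every rank its own batch_size-sized slice,
    # so rewards can later be grouped per prompt (see PerPromptStatTracker).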
def __init__(self, dataset, batch_size, k, num_replicas, rank, seed=0):
self.dataset = dataset
self.batch_size = batch_size
        self.k = k  # Repeat each sample k times (e.g., 24)
        self.num_replicas = num_replicas  # Number of GPUs (e.g., 8)
self.rank = rank
self.seed = seed
self.total_samples = self.num_replicas * self.batch_size # 8 * 9 = 72
        assert (
            self.total_samples % self.k == 0
        ), f"num_replicas * batch_size must be divisible by k (k={k}, num_replicas={num_replicas}, batch_size={batch_size})"
self.m = self.total_samples // self.k # 72/24=3
self.epoch = 0
def __iter__(self):
while True:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g)[: self.m].tolist() # 1. Randomly select m unique prompts
repeated_indices = [idx for idx in indices for _ in range(self.k)] # 2. Repeat each prompt k times (RL specific)
# 3. Shuffle indices after repetition
shuffled_indices = torch.randperm(len(repeated_indices), generator=g).tolist()
shuffled_samples = [repeated_indices[i] for i in shuffled_indices]
# 4. Distribute to each GPU
per_card_samples = []
for i in range(self.num_replicas):
start = i * self.batch_size
end = start + self.batch_size
per_card_samples.append(shuffled_samples[start:end])
yield per_card_samples[self.rank] # 5. Return data for current GPU
def set_epoch(self, epoch):
self.epoch = epoch
def gather_tensor_to_all(tensor, world_size):
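    # All-gather `tensor` from every rank and return the concatenation on CPU.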
gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
dist.all_gather(gathered_tensors, tensor)
return torch.cat(gathered_tensors, dim=0).cpu()
def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, max_sequence_length)
prompt_embeds = prompt_embeds.to(device)
pooled_prompt_embeds = pooled_prompt_embeds.to(device)
return prompt_embeds, pooled_prompt_embeds
def return_decay(step, decay_type):
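    # Piecewise schedule for the decay used when blending the "old" (sampling) adapter toward the
    # trained "default" adapter at the end of each epoch: 0 for the first `flat` steps, then rising
    # by `uprate` per step and capped at `uphold`.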
if decay_type == 0:
flat = 0
uprate = 0.0
uphold = 0.0
elif decay_type == 1:
flat = 0
uprate = 0.001
uphold = 0.5
elif decay_type == 2:
flat = 75
uprate = 0.0075
uphold = 0.999
else:
        raise ValueError(f"Unknown decay_type: {decay_type}")
if step < flat:
return 0.0
else:
decay = (step - flat) * uprate
return min(decay, uphold)
def calculate_zero_std_ratio(prompts, gathered_rewards):
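    # Group rewards by prompt and return the fraction of prompt groups whose reward std is zero
    # (such groups provide no relative-advantage signal) together with the mean per-prompt std.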
prompt_array = np.array(prompts)
unique_prompts, inverse_indices, counts = np.unique(prompt_array, return_inverse=True, return_counts=True)
grouped_rewards = gathered_rewards["avg"][np.argsort(inverse_indices), 0]
split_indices = np.cumsum(counts)[:-1]
reward_groups = np.split(grouped_rewards, split_indices)
prompt_std_devs = np.array([np.std(group) for group in reward_groups])
zero_std_count = np.count_nonzero(prompt_std_devs == 0)
zero_std_ratio = zero_std_count / len(prompt_std_devs)
return zero_std_ratio, prompt_std_devs.mean()
def eval_fn(
pipeline,
test_dataloader,
text_encoders,
tokenizers,
config,
device,
rank,
world_size,
global_step,
reward_fn,
executor,
mixed_precision_dtype,
ema,
transformer_trainable_parameters,
):
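    # Run deterministic sampling over the test set, compute rewards asynchronously, gather them across
    # ranks, and log example images plus mean rewards to wandb on the main process. EMA weights are
    # temporarily swapped in when config.train.ema is enabled.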
if config.train.ema and ema is not None:
ema.copy_ema_to(transformer_trainable_parameters, store_temp=True)
pipeline.transformer.eval()
neg_prompt_embed, neg_pooled_prompt_embed = compute_text_embeddings(
[""], text_encoders, tokenizers, max_sequence_length=128, device=device
)
sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.test_batch_size, 1, 1)
sample_neg_pooled_prompt_embeds = neg_pooled_prompt_embed.repeat(config.sample.test_batch_size, 1)
all_rewards = defaultdict(list)
test_sampler = (
DistributedSampler(test_dataloader.dataset, num_replicas=world_size, rank=rank, shuffle=False)
if world_size > 1
else None
)
eval_loader = DataLoader(
test_dataloader.dataset,
batch_size=config.sample.test_batch_size, # This is per-GPU batch size
sampler=test_sampler,
collate_fn=test_dataloader.collate_fn,
num_workers=test_dataloader.num_workers,
)
for test_batch in tqdm(
eval_loader,
desc="Eval: ",
disable=not is_main_process(rank),
position=0,
):
prompts, prompt_metadata = test_batch
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
prompts, text_encoders, tokenizers, max_sequence_length=128, device=device
)
current_batch_size = len(prompt_embeds)
if current_batch_size < len(sample_neg_prompt_embeds): # Handle last batch
current_sample_neg_prompt_embeds = sample_neg_prompt_embeds[:current_batch_size]
current_sample_neg_pooled_prompt_embeds = sample_neg_pooled_prompt_embeds[:current_batch_size]
else:
current_sample_neg_prompt_embeds = sample_neg_prompt_embeds
current_sample_neg_pooled_prompt_embeds = sample_neg_pooled_prompt_embeds
with torch_autocast(enabled=(config.mixed_precision in ["fp16", "bf16"]), dtype=mixed_precision_dtype):
with torch.no_grad():
images, _, _ = pipeline_with_logprob(
pipeline,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_prompt_embeds=current_sample_neg_prompt_embeds,
negative_pooled_prompt_embeds=current_sample_neg_pooled_prompt_embeds,
num_inference_steps=config.sample.eval_num_steps,
guidance_scale=config.sample.guidance_scale,
output_type="pt",
height=config.resolution,
width=config.resolution,
noise_level=config.sample.noise_level,
deterministic=True,
solver="flow",
model_type="sd3",
)
rewards_future = executor.submit(reward_fn, images, prompts, prompt_metadata, only_strict=False)
time.sleep(0)
rewards, reward_metadata = rewards_future.result()
for key, value in rewards.items():
rewards_tensor = torch.as_tensor(value, device=device).float()
gathered_value = gather_tensor_to_all(rewards_tensor, world_size)
all_rewards[key].append(gathered_value.numpy())
if is_main_process(rank):
final_rewards = {key: np.concatenate(value_list) for key, value_list in all_rewards.items()}
images_to_log = images.cpu()
prompts_to_log = prompts
with tempfile.TemporaryDirectory() as tmpdir:
num_samples_to_log = min(15, len(images_to_log))
for idx in range(num_samples_to_log):
image = images_to_log[idx].float()
pil = Image.fromarray((image.numpy().transpose(1, 2, 0) * 255).astype(np.uint8))
pil = pil.resize((config.resolution, config.resolution))
pil.save(os.path.join(tmpdir, f"{idx}.jpg"))
sampled_prompts_log = [prompts_to_log[i] for i in range(num_samples_to_log)]
sampled_rewards_log = [{k: final_rewards[k][i] for k in final_rewards} for i in range(num_samples_to_log)]
wandb.log(
{
"eval_images": [
wandb.Image(
os.path.join(tmpdir, f"{idx}.jpg"),
caption=f"{prompt:.1000} | "
+ " | ".join(f"{k}: {v:.2f}" for k, v in reward.items() if v != -10),
)
for idx, (prompt, reward) in enumerate(zip(sampled_prompts_log, sampled_rewards_log))
],
**{f"eval_reward_{key}": np.mean(value[value != -10]) for key, value in final_rewards.items()},
},
step=global_step,
)
if config.train.ema and ema is not None:
ema.copy_temp_to(transformer_trainable_parameters)
if world_size > 1:
dist.barrier()
def save_ckpt(
save_dir, transformer_ddp, global_step, rank, ema, transformer_trainable_parameters, config, optimizer, scaler
):
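    # Save LoRA/PEFT weights plus optimizer (and GradScaler) state under
    # {save_dir}/checkpoints/checkpoint-{global_step}; EMA weights are swapped in for saving when enabled.
    # Only the main process writes to disk.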
if is_main_process(rank):
save_root = os.path.join(save_dir, "checkpoints", f"checkpoint-{global_step}")
save_root_lora = os.path.join(save_root, "lora")
os.makedirs(save_root_lora, exist_ok=True)
model_to_save = transformer_ddp.module
if config.train.ema and ema is not None:
ema.copy_ema_to(transformer_trainable_parameters, store_temp=True)
model_to_save.save_pretrained(save_root_lora) # For LoRA/PEFT models
torch.save(optimizer.state_dict(), os.path.join(save_root, "optimizer.pt"))
if scaler is not None:
torch.save(scaler.state_dict(), os.path.join(save_root, "scaler.pt"))
if config.train.ema and ema is not None:
ema.copy_temp_to(transformer_trainable_parameters)
logger.info(f"Saved checkpoint to {save_root}")
def main(_):
config = FLAGS.config
# --- Distributed Setup ---
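    # RANK / WORLD_SIZE / LOCAL_RANK are provided by the launcher (e.g., torchrun).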
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
setup_distributed(rank, local_rank, world_size)
device = torch.device(f"cuda:{local_rank}")
unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
if not config.run_name:
config.run_name = unique_id
else:
config.run_name += "_" + unique_id
# --- WandB Init (only on main process) ---
if is_main_process(rank):
log_dir = os.path.join(config.logdir, config.run_name)
os.makedirs(log_dir, exist_ok=True)
wandb.init(project="flow-grpo", name=config.run_name, config=config.to_dict(), dir=log_dir)
logger.info(f"\n{config}")
set_seed(config.seed, rank) # Pass rank for different seeds per process
# --- Mixed Precision Setup ---
mixed_precision_dtype = None
if config.mixed_precision == "fp16":
mixed_precision_dtype = torch.float16
elif config.mixed_precision == "bf16":
mixed_precision_dtype = torch.bfloat16
enable_amp = mixed_precision_dtype is not None
scaler = GradScaler(enabled=enable_amp)
# --- Load pipeline and models ---
pipeline = StableDiffusion3Pipeline.from_pretrained(config.pretrained.model)
pipeline.vae.requires_grad_(False)
pipeline.text_encoder.requires_grad_(False)
pipeline.text_encoder_2.requires_grad_(False)
pipeline.text_encoder_3.requires_grad_(False)
pipeline.transformer.requires_grad_(not config.use_lora)
text_encoders = [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.text_encoder_3]
tokenizers = [pipeline.tokenizer, pipeline.tokenizer_2, pipeline.tokenizer_3]
pipeline.safety_checker = None
pipeline.set_progress_bar_config(
position=1,
disable=not is_main_process(rank),
leave=False,
desc="Timestep",
dynamic_ncols=True,
)
text_encoder_dtype = mixed_precision_dtype if enable_amp else torch.float32
pipeline.vae.to(device, dtype=torch.float32) # VAE usually fp32
pipeline.text_encoder.to(device, dtype=text_encoder_dtype)
pipeline.text_encoder_2.to(device, dtype=text_encoder_dtype)
pipeline.text_encoder_3.to(device, dtype=text_encoder_dtype)
transformer = pipeline.transformer.to(device)
if config.use_lora:
target_modules = [
"attn.add_k_proj",
"attn.add_q_proj",
"attn.add_v_proj",
"attn.to_add_out",
"attn.to_k",
"attn.to_out.0",
"attn.to_q",
"attn.to_v",
]
transformer_lora_config = LoraConfig(
r=32, lora_alpha=64, init_lora_weights="gaussian", target_modules=target_modules
)
if config.train.lora_path:
transformer = PeftModel.from_pretrained(transformer, config.train.lora_path)
transformer.set_adapter("default")
else:
transformer = get_peft_model(transformer, transformer_lora_config)
transformer.add_adapter("old", transformer_lora_config)
transformer.set_adapter("default") # default is the new strategy for training, old strategy is used for sampling
transformer_ddp = DDP(transformer, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)
transformer_ddp.module.set_adapter("default")
transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, transformer_ddp.module.parameters()))
transformer_ddp.module.set_adapter("old")
old_transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, transformer_ddp.module.parameters()))
transformer_ddp.module.set_adapter("default")
if config.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# --- Optimizer ---
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
transformer_trainable_parameters, # Use params from original model for optimizer
lr=config.train.learning_rate,
betas=(config.train.adam_beta1, config.train.adam_beta2),
weight_decay=config.train.adam_weight_decay,
eps=config.train.adam_epsilon,
)
# --- Datasets and Dataloaders ---
if config.prompt_fn == "general_ocr":
train_dataset = TextPromptDataset(config.dataset, "train")
test_dataset = TextPromptDataset(config.dataset, "test")
elif config.prompt_fn == "geneval":
train_dataset = GenevalPromptDataset(config.dataset, "train")
test_dataset = GenevalPromptDataset(config.dataset, "test")
else:
        raise NotImplementedError(f"Unsupported prompt_fn: {config.prompt_fn}")
train_sampler = DistributedKRepeatSampler(
dataset=train_dataset,
batch_size=config.sample.train_batch_size, # This is per-GPU batch size
k=config.sample.num_image_per_prompt,
num_replicas=world_size,
rank=rank,
seed=config.seed,
)
train_dataloader = DataLoader(
train_dataset, batch_sampler=train_sampler, num_workers=0, collate_fn=train_dataset.collate_fn, pin_memory=True
)
test_sampler = (
DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False) if world_size > 1 else None
)
test_dataloader = DataLoader(
test_dataset,
        batch_size=config.sample.test_batch_size,  # Per-GPU batch size (e.g., 9)
sampler=test_sampler, # Use distributed sampler for eval
collate_fn=test_dataset.collate_fn,
num_workers=0,
pin_memory=True,
)
# --- Prompt Embeddings ---
neg_prompt_embed, neg_pooled_prompt_embed = compute_text_embeddings(
[""], text_encoders, tokenizers, max_sequence_length=128, device=device
)
sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.train_batch_size, 1, 1)
train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1)
sample_neg_pooled_prompt_embeds = neg_pooled_prompt_embed.repeat(config.sample.train_batch_size, 1)
train_neg_pooled_prompt_embeds = neg_pooled_prompt_embed.repeat(config.train.batch_size, 1)
if config.sample.num_image_per_prompt == 1:
config.per_prompt_stat_tracking = False
if config.per_prompt_stat_tracking:
stat_tracker = PerPromptStatTracker(config.sample.global_std)
else:
        assert False, "This script requires per_prompt_stat_tracking (num_image_per_prompt must be > 1)"
executor = futures.ThreadPoolExecutor(max_workers=8) # Async reward computation
# Train!
samples_per_epoch = config.sample.train_batch_size * world_size * config.sample.num_batches_per_epoch
total_train_batch_size = config.train.batch_size * world_size * config.train.gradient_accumulation_steps
logger.info("***** Running training *****")
    # Example log output:
    #   ***** Running training *****
    #   Total number of samples per epoch = 1152
    #   Total train batch size (w. parallel, distributed & accumulation) = 1152
    #   Number of gradient updates per inner epoch = 1
    #   Num Epochs = 100000
    #   Number of inner epochs = 1
    #   Sample batch size per device = 9
logger.info(f" Num Epochs = {config.num_epochs}") # 100000
logger.info(f" Sample batch size per device = {config.sample.train_batch_size}") # 9
logger.info(f" Train batch size per device = {config.train.batch_size}") # 9
logger.info(f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}") # 1
logger.info("")
logger.info(f" Total number of samples per epoch = {samples_per_epoch}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}")
logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}")
    reward_fn = flow_grpo.rewards.multi_score(device, config.reward_fn)  # reward function for training samples
    eval_reward_fn = flow_grpo.rewards.multi_score(device, config.reward_fn)  # reward function for evaluation
# --- Resume from checkpoint ---
first_epoch = 0
global_step = 0
if config.resume_from:
logger.info(f"Resuming from {config.resume_from}")
# Assuming checkpoint dir contains lora, optimizer.pt, scaler.pt
lora_path = os.path.join(config.resume_from, "lora")
if os.path.exists(lora_path): # Check if it's a PEFT model save
transformer_ddp.module.load_adapter(lora_path, adapter_name="default", is_trainable=True)
transformer_ddp.module.load_adapter(lora_path, adapter_name="old", is_trainable=False)
else: # Try loading full state dict if it's not a PEFT save structure
model_ckpt_path = os.path.join(config.resume_from, "transformer_model.pt") # Or specific name
if os.path.exists(model_ckpt_path):
transformer_ddp.module.load_state_dict(torch.load(model_ckpt_path, map_location=device))
opt_path = os.path.join(config.resume_from, "optimizer.pt")
if os.path.exists(opt_path):
optimizer.load_state_dict(torch.load(opt_path, map_location=device))
scaler_path = os.path.join(config.resume_from, "scaler.pt")
if os.path.exists(scaler_path) and enable_amp:
scaler.load_state_dict(torch.load(scaler_path, map_location=device))
# Extract epoch and step from checkpoint name, e.g., "checkpoint-1000" -> global_step = 1000
try:
global_step = int(os.path.basename(config.resume_from).split("-")[-1])
logger.info(f"Resumed global_step to {global_step}. Epoch estimation might be needed.")
except ValueError:
logger.warning(
f"Could not parse global_step from checkpoint name: {config.resume_from}. Starting global_step from 0."
)
global_step = 0
ema = None
if config.train.ema:
ema = EMAModuleWrapper(transformer_trainable_parameters, decay=0.9, update_step_interval=1, device=device)
num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction) # 25
logger.info("***** Running training *****")
train_iter = iter(train_dataloader)
optimizer.zero_grad()
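    # Initialize the "old" (sampling) adapter weights from the trained "default" adapter before the first epoch.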
for src_param, tgt_param in zip(
transformer_trainable_parameters, old_transformer_trainable_parameters, strict=True
):
tgt_param.data.copy_(src_param.detach().data)
assert src_param is not tgt_param
for epoch in range(first_epoch, config.num_epochs):
if hasattr(train_sampler, "set_epoch"):
train_sampler.set_epoch(epoch)
# SAMPLING
pipeline.transformer.eval()
samples_data_list = []
for i in tqdm(
range(config.sample.num_batches_per_epoch),
desc=f"Epoch {epoch}: sampling",
disable=not is_main_process(rank),
position=0,
):
transformer_ddp.module.set_adapter("default")
if hasattr(train_sampler, "set_epoch") and isinstance(train_sampler, DistributedKRepeatSampler):
train_sampler.set_epoch(epoch * config.sample.num_batches_per_epoch + i)
prompts, prompt_metadata = next(train_iter)
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
prompts, text_encoders, tokenizers, max_sequence_length=128, device=device
) # prompt_embeds: [bsz per gpu, token len, 4096], pooled_prompt_embeds: [bsz, 2048]
prompt_ids = tokenizers[0](
prompts, padding="max_length", max_length=256, truncation=True, return_tensors="pt"
).input_ids.to(device) # prompt_ids: [bsz, 256] will pad 49407 later
# if i == 0 and epoch % config.eval_freq == 0 and not config.debug: # Evaluate once every 10 epochs
# eval_fn(
# pipeline,
# test_dataloader,
# text_encoders,
# tokenizers,
# config,
# device,
# rank,
# world_size,
# global_step,
# eval_reward_fn,
# executor,
# mixed_precision_dtype,
# ema,
# transformer_trainable_parameters,
# )
if i == 0 and epoch % config.save_freq == 0 and is_main_process(rank) and not config.debug:
save_ckpt(
config.save_dir,
transformer_ddp,
global_step,
rank,
ema,
transformer_trainable_parameters,
config,
optimizer,
scaler,
)
transformer_ddp.module.set_adapter("old")
            # pipeline_with_logprob follows the standard Stable Diffusion 3 / Flux generation loop but also returns
            # the intermediate latents of every denoising step (and their log-probabilities, unused here).
            with torch_autocast(enabled=enable_amp, dtype=mixed_precision_dtype):
                with torch.no_grad():
                    images, latents, _ = pipeline_with_logprob(  # images: final outputs; latents: intermediate denoising latents
pipeline, # (self)
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_prompt_embeds=sample_neg_prompt_embeds[: len(prompts)],
negative_pooled_prompt_embeds=sample_neg_pooled_prompt_embeds[: len(prompts)],
num_inference_steps=config.sample.num_steps, # 40
guidance_scale=config.sample.guidance_scale, # 1.0
output_type="pt",
height=config.resolution,
width=config.resolution,
noise_level=config.sample.noise_level, # 0.7
deterministic=config.sample.deterministic, # True
solver=config.sample.solver, # "dpm2"
model_type="sd3",
)
transformer_ddp.module.set_adapter("default")
latents = torch.stack(latents, dim=1) # torch.Size([bsz, 26, 16, 64, 64]) images: torch.Size([bsz, 3, 512, 512])
timesteps = pipeline.scheduler.timesteps.repeat(len(prompts), 1).to(device)
rewards_future = executor.submit(reward_fn, images, prompts, prompt_metadata, only_strict=True)
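            # time.sleep(0) yields control so reward computation can start on the worker thread while sampling continues.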
time.sleep(0)
samples_data_list.append(
{
"prompt_ids": prompt_ids,
"prompt_embeds": prompt_embeds,
"pooled_prompt_embeds": pooled_prompt_embeds,
"timesteps": timesteps,
"next_timesteps": torch.concatenate([timesteps[:, 1:], torch.zeros_like(timesteps[:, :1])], dim=1),
"latents_clean": latents[:, -1],
"rewards_future": rewards_future, # Store future - dictionary containing 4 keys: {'clipscore':, 'pickscore':, 'hpsv2': , 'avg':}
}
)
for sample_item in tqdm(
samples_data_list, desc="Waiting for rewards", disable=not is_main_process(rank), position=0
):
rewards, reward_metadata = sample_item["rewards_future"].result()
sample_item["rewards"] = {k: torch.as_tensor(v, device=device).float() for k, v in rewards.items()}
del sample_item["rewards_future"]
# Collate samples
collated_samples = {
k: (
torch.cat([s[k] for s in samples_data_list], dim=0)
if not isinstance(samples_data_list[0][k], dict)
else {sk: torch.cat([s[k][sk] for s in samples_data_list], dim=0) for sk in samples_data_list[0][k]}
)
for k in samples_data_list[0].keys()
}
# Logging images (main process)
if epoch % 10 == 0 and is_main_process(rank):
images_to_log = images.cpu() # from last sampling batch on this rank
prompts_to_log = prompts # from last sampling batch on this rank
rewards_to_log = collated_samples["rewards"]["avg"][-len(images_to_log) :].cpu()
with tempfile.TemporaryDirectory() as tmpdir:
num_to_log = min(15, len(images_to_log))
for idx in range(num_to_log): # log first N
img_data = images_to_log[idx]
pil = Image.fromarray((img_data.numpy().transpose(1, 2, 0) * 255).astype(np.uint8))
pil = pil.resize((config.resolution, config.resolution))
pil.save(os.path.join(tmpdir, f"{idx}.jpg"))
wandb.log(
{
"images": [
wandb.Image(
os.path.join(tmpdir, f"{idx}.jpg"),
caption=f"{prompts_to_log[idx]:.100} | avg: {rewards_to_log[idx]:.2f}",
)
for idx in range(num_to_log)
],
},
step=global_step,
)
collated_samples["rewards"]["avg"] = (
collated_samples["rewards"]["avg"].unsqueeze(1).repeat(1, num_train_timesteps)
)
# Gather rewards across processes
gathered_rewards_dict = {}
for key, value_tensor in collated_samples["rewards"].items():
gathered_rewards_dict[key] = gather_tensor_to_all(value_tensor, world_size).numpy()
if is_main_process(rank): # logging
wandb.log(
{
"epoch": epoch,
**{
f"reward_{k}": v.mean()
for k, v in gathered_rewards_dict.items()
if "_strict_accuracy" not in k and "_accuracy" not in k
},
},
step=global_step,
)
if config.per_prompt_stat_tracking:
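            # Gather and decode prompt ids from all ranks so rewards can be grouped by prompt string.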
prompt_ids_all = gather_tensor_to_all(collated_samples["prompt_ids"], world_size)
prompts_all_decoded = pipeline.tokenizer.batch_decode(
prompt_ids_all.cpu().numpy(), skip_special_tokens=True
)
# Stat tracker update expects numpy arrays for rewards
advantages = stat_tracker.update(prompts_all_decoded, gathered_rewards_dict["avg"])
if is_main_process(rank):
group_size, trained_prompt_num = stat_tracker.get_stats()
zero_std_ratio, reward_std_mean = calculate_zero_std_ratio(prompts_all_decoded, gathered_rewards_dict)
wandb.log(
{
"group_size": group_size,
"trained_prompt_num": trained_prompt_num,
"zero_std_ratio": zero_std_ratio,
"reward_std_mean": reward_std_mean,
"mean_reward_100": stat_tracker.get_mean_of_top_rewards(100),
"mean_reward_75": stat_tracker.get_mean_of_top_rewards(75),
"mean_reward_50": stat_tracker.get_mean_of_top_rewards(50),
"mean_reward_25": stat_tracker.get_mean_of_top_rewards(25),
"mean_reward_10": stat_tracker.get_mean_of_top_rewards(10),
},
step=global_step,
)
stat_tracker.clear()
else:
avg_rewards_all = gathered_rewards_dict["avg"]
advantages = (avg_rewards_all - avg_rewards_all.mean()) / (avg_rewards_all.std() + 1e-4)
# Distribute advantages back to processes
samples_per_gpu = collated_samples["timesteps"].shape[0]
if advantages.ndim == 1:
advantages = advantages[:, None]
if advantages.shape[0] == world_size * samples_per_gpu:
collated_samples["advantages"] = torch.from_numpy(
advantages.reshape(world_size, samples_per_gpu, -1)[rank]
).to(device)
else:
assert False
if is_main_process(rank):
logger.info(f"Advantages mean: {collated_samples['advantages'].abs().mean().item()}")
del collated_samples["rewards"]
del collated_samples["prompt_ids"]
num_batches = config.sample.num_batches_per_epoch * config.sample.train_batch_size // config.train.batch_size
filtered_samples = collated_samples
total_batch_size_filtered, num_timesteps_filtered = filtered_samples["timesteps"].shape
# TRAINING
transformer_ddp.train() # Sets DDP model and its submodules to train mode.
# Total number of backward passes before an optimizer step
effective_grad_accum_steps = config.train.gradient_accumulation_steps * num_train_timesteps
current_accumulated_steps = 0 # Counter for backward passes
gradient_update_times = 0
for inner_epoch in range(config.train.num_inner_epochs): # 1
perm = torch.randperm(total_batch_size_filtered, device=device)
shuffled_filtered_samples = {k: v[perm] for k, v in filtered_samples.items()}
perms_time = torch.stack(
[torch.randperm(num_timesteps_filtered, device=device) for _ in range(total_batch_size_filtered)]
)
for key in ["timesteps", "next_timesteps"]:
shuffled_filtered_samples[key] = shuffled_filtered_samples[key][
torch.arange(total_batch_size_filtered, device=device)[:, None], perms_time
]
training_batch_size = total_batch_size_filtered // num_batches
samples_batched_list = []
for k_batch in range(num_batches):
batch_dict = {}
start = k_batch * training_batch_size
end = (k_batch + 1) * training_batch_size
for key, val_tensor in shuffled_filtered_samples.items():
batch_dict[key] = val_tensor[start:end]
samples_batched_list.append(batch_dict)
info_accumulated = defaultdict(list) # For accumulating stats over one grad acc cycle
for i, train_sample_batch in tqdm(
list(enumerate(samples_batched_list)),
desc=f"Epoch {epoch}.{inner_epoch}: training",
position=0,
disable=not is_main_process(rank),
):
current_micro_batch_size = len(train_sample_batch["prompt_embeds"])
if config.sample.guidance_scale > 1.0:
embeds = torch.cat(
[train_neg_prompt_embeds[:current_micro_batch_size], train_sample_batch["prompt_embeds"]]
)
pooled_embeds = torch.cat(
[
train_neg_pooled_prompt_embeds[:current_micro_batch_size],
train_sample_batch["pooled_prompt_embeds"],
]
)
else:
embeds = train_sample_batch["prompt_embeds"]
pooled_embeds = train_sample_batch["pooled_prompt_embeds"]
# Loop over timesteps for this micro-batch
for j_idx, j_timestep_orig_idx in tqdm(
enumerate(range(num_train_timesteps)),
desc="Timestep",
position=1,
leave=False,
disable=not is_main_process(rank),
):
assert j_idx == j_timestep_orig_idx
x0 = train_sample_batch["latents_clean"]
t = train_sample_batch["timesteps"][:, j_idx] / 1000.0
t_expanded = t.view(-1, *([1] * (len(x0.shape) - 1)))
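                    # Re-noise the clean latent along the rectified-flow path: x_t = (1 - t) * x0 + t * noise.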
noise = torch.randn_like(x0.float())
xt = (1 - t_expanded) * x0 + t_expanded * noise
with torch_autocast(enabled=enable_amp, dtype=mixed_precision_dtype):
transformer_ddp.module.set_adapter("old")
with torch.no_grad():
# prediction v
old_prediction = transformer_ddp(
hidden_states=xt,
timestep=train_sample_batch["timesteps"][:, j_idx],
encoder_hidden_states=embeds,
pooled_projections=pooled_embeds,
return_dict=False,
)[0].detach()
transformer_ddp.module.set_adapter("default")
# prediction v
forward_prediction = transformer_ddp(
hidden_states=xt,
timestep=train_sample_batch["timesteps"][:, j_idx],
encoder_hidden_states=embeds,
pooled_projections=pooled_embeds,
return_dict=False,
)[0]
with torch.no_grad(): # Reference model part
# For LoRA, disable adapter.
if config.use_lora:
with transformer_ddp.module.disable_adapter():
ref_forward_prediction = transformer_ddp(
hidden_states=xt,
timestep=train_sample_batch["timesteps"][:, j_idx],
encoder_hidden_states=embeds,
pooled_projections=pooled_embeds,
return_dict=False,
)[0]
transformer_ddp.module.set_adapter("default")
else: # Full model - this requires a frozen copy of the model
assert False
loss_terms = {}
# Policy Gradient Loss
advantages_clip = torch.clamp(
train_sample_batch["advantages"][:, j_idx],
-config.train.adv_clip_max,
config.train.adv_clip_max,
)
if hasattr(config.train, "adv_mode"):
if config.train.adv_mode == "positive_only":
advantages_clip = torch.clamp(advantages_clip, 0, config.train.adv_clip_max)
elif config.train.adv_mode == "negative_only":
advantages_clip = torch.clamp(advantages_clip, -config.train.adv_clip_max, 0)
elif config.train.adv_mode == "one_only":
advantages_clip = torch.where(
advantages_clip > 0, torch.ones_like(advantages_clip), torch.zeros_like(advantages_clip)
)
elif config.train.adv_mode == "binary":
advantages_clip = torch.sign(advantages_clip)
                        # Map clipped advantages from [-adv_clip_max, adv_clip_max] to an interpolation weight r in [0, 1]
normalized_advantages_clip = (advantages_clip / config.train.adv_clip_max) / 2.0 + 0.5
r = torch.clamp(normalized_advantages_clip, 0, 1)
loss_terms["x0_norm"] = torch.mean(x0**2).detach()
loss_terms["x0_norm_max"] = torch.max(x0**2).detach()
loss_terms["old_deviate"] = torch.mean((forward_prediction - old_prediction) ** 2).detach()
loss_terms["old_deviate_max"] = torch.max((forward_prediction - old_prediction) ** 2).detach()
positive_prediction = config.beta * forward_prediction + (1 - config.beta) * old_prediction.detach()
implicit_negative_prediction = (
1.0 + config.beta
) * old_prediction.detach() - config.beta * forward_prediction
                        # Adaptive weighting: normalize the squared error by its per-sample mean absolute deviation (computed without gradients)
x0_prediction = xt - t_expanded * positive_prediction
with torch.no_grad():
weight_factor = (
torch.abs(x0_prediction.double() - x0.double())
.mean(dim=tuple(range(1, x0.ndim)), keepdim=True)
.clip(min=0.00001)
)
positive_loss = ((x0_prediction - x0) ** 2 / weight_factor).mean(dim=tuple(range(1, x0.ndim)))
negative_x0_prediction = xt - t_expanded * implicit_negative_prediction
with torch.no_grad():
negative_weight_factor = (
torch.abs(negative_x0_prediction.double() - x0.double())
.mean(dim=tuple(range(1, x0.ndim)), keepdim=True)
.clip(min=0.00001)
)
negative_loss = ((negative_x0_prediction - x0) ** 2 / negative_weight_factor).mean(
dim=tuple(range(1, x0.ndim))
)
ori_policy_loss = r * positive_loss / config.beta + (1.0 - r) * negative_loss / config.beta
policy_loss = (ori_policy_loss * config.train.adv_clip_max).mean()
loss = policy_loss
loss_terms["policy_loss"] = policy_loss.detach()
loss_terms["unweighted_policy_loss"] = ori_policy_loss.mean().detach()
kl_div_loss = ((forward_prediction - ref_forward_prediction) ** 2).mean(
dim=tuple(range(1, x0.ndim))
)
loss += config.train.beta * torch.mean(kl_div_loss)
                        loss_terms["kl_div_loss"] = torch.mean(kl_div_loss).detach()
loss_terms["kl_div"] = torch.mean(
((forward_prediction - ref_forward_prediction) ** 2).mean(dim=tuple(range(1, x0.ndim)))
).detach()
loss_terms["old_kl_div"] = torch.mean(
((old_prediction - ref_forward_prediction) ** 2).mean(dim=tuple(range(1, x0.ndim)))
).detach()
loss_terms["total_loss"] = loss.detach()
# Scale loss for gradient accumulation and DDP (DDP averages grads, so no need to divide by world_size here)
scaled_loss = loss / effective_grad_accum_steps
if mixed_precision_dtype == torch.float16:
scaler.scale(scaled_loss).backward() # one accumulation
else:
scaled_loss.backward()
current_accumulated_steps += 1
for k_info, v_info in loss_terms.items():
info_accumulated[k_info].append(v_info)
if current_accumulated_steps % effective_grad_accum_steps == 0:
if mixed_precision_dtype == torch.float16:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(transformer_ddp.module.parameters(), config.train.max_grad_norm)
if mixed_precision_dtype == torch.float16:
scaler.step(optimizer)
else:
optimizer.step()
gradient_update_times += 1
if mixed_precision_dtype == torch.float16:
scaler.update()
optimizer.zero_grad()
log_info = {k: torch.mean(torch.stack(v_list)).item() for k, v_list in info_accumulated.items()}
info_tensor = torch.tensor([log_info[k] for k in sorted(log_info.keys())], device=device)
dist.all_reduce(info_tensor, op=dist.ReduceOp.AVG)
reduced_log_info = {k: info_tensor[ki].item() for ki, k in enumerate(sorted(log_info.keys()))}
if is_main_process(rank):
wandb.log(
{
"step": global_step,
"gradient_update_times": gradient_update_times,
"epoch": epoch,
"inner_epoch": inner_epoch,
**reduced_log_info,
}
)
global_step += 1 # gradient step
info_accumulated = defaultdict(list) # Reset for next accumulation cycle
if (
config.train.ema
and ema is not None
and (current_accumulated_steps % effective_grad_accum_steps == 0)
):
ema.step(transformer_trainable_parameters, global_step)
if world_size > 1:
dist.barrier()
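        # Blend the "old" (sampling) adapter toward the trained "default" adapter, keeping a `decay` fraction of the old weights.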
with torch.no_grad():
decay = return_decay(global_step, config.decay_type)
for src_param, tgt_param in zip(
transformer_trainable_parameters, old_transformer_trainable_parameters, strict=True
):
tgt_param.data.copy_(tgt_param.detach().data * decay + src_param.detach().clone().data * (1.0 - decay))
if is_main_process(rank):
wandb.finish()
cleanup_distributed()
if __name__ == "__main__":
app.run(main)