import math
import torch
from diffusers.utils.torch_utils import randn_tensor
from typing import Optional, List
from dataclasses import dataclass
import torch.distributed as dist
import tqdm
from functools import partial
# Rebind the name `tqdm` from the imported module to a pre-configured
# progress-bar factory (`tqdm.tqdm` with `dynamic_ncols=True`). After this
# line `tqdm(...)` is called directly and the module object is no longer
# reachable under this name.
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)
# Modified from MixGRPO
def run_sampling(
    v_pred_fn,
    z,
    sigma_schedule,
    solver="flow",
    determistic=False,
    eta=0.7,
    pred_latents: Optional[torch.Tensor] = None,
):
    """Run the sampling loop over ``sigma_schedule`` with the selected solver.

    Args:
        v_pred_fn: Callable invoked once per step as
            ``v_pred_fn(z, sigma, pred_latents)``. Its output is cast to
            float32 and handed to the solver step (presumably a velocity
            prediction -- confirm against the caller).
        z: Initial latents. Its dtype is restored after every step and is the
            dtype of the returned latents.
        sigma_schedule: Indexable sequence of sigmas;
            ``len(sigma_schedule) - 1`` steps are taken. For the DPM solvers
            a NumPy array is converted to a tensor on ``z.device``.
        solver: One of ``"flow"``, ``"dance"``, ``"ddim"``, ``"dpm1"``,
            ``"dpm2"``.
        determistic: (sic) When true the step functions receive ``eta = 0``.
            Must be true for the DPM solvers.
        eta: Noise scale forwarded to the stochastic step functions.
        pred_latents: Forwarded unchanged to ``v_pred_fn``.

    Returns:
        ``(latents, all_latents, all_log_probs)``: the final latents, the
        list of latents after every step (starting with the input ``z``), and
        the list of per-step log-probabilities. The lists are not stacked.
        Entries of ``all_log_probs`` are ``None`` for the DPM solvers.
    """
    assert solver in ["flow", "dance", "ddim", "dpm1", "dpm2"], f"unsupported solver: {solver!r}"
    dtype = z.dtype
    is_dpm = "dpm" in solver

    if is_dpm:
        # Validate before the loop so no model forward pass is wasted.
        assert determistic, "DPM solvers only support deterministic sampling"
        # Convert once, up front: v_pred_fn then sees the same sigma type on
        # every step instead of a NumPy scalar at step 0 and a tensor after.
        import numpy as np

        if isinstance(sigma_schedule, np.ndarray):
            sigma_schedule = torch.from_numpy(sigma_schedule).to(z.device)
        order = int(solver[-1])
        dpm_state = DPMState(order=order)

    step_eta = 0 if determistic else eta
    all_latents = [z]
    all_log_probs = []

    for i in tqdm(
        range(len(sigma_schedule) - 1),
        desc="Sampling Progress",
        disable=not dist.is_initialized() or dist.get_rank() != 0,
    ):
        sigma = sigma_schedule[i]
        pred = v_pred_fn(z.to(dtype), sigma, pred_latents)

        if solver == "flow":
            z, _, log_prob = flow_grpo_step(
                model_output=pred.float(),
                latents=z.float(),
                eta=step_eta,
                sigmas=sigma_schedule,
                index=i,
                prev_sample=None,
            )
        elif solver == "dance":
            z, _, log_prob = dance_grpo_step(
                pred.float(), z.float(), step_eta, sigmas=sigma_schedule, index=i, prev_sample=None
            )
        elif solver == "ddim":
            z, _, log_prob = ddim_step(
                pred.float(), z.float(), step_eta, sigmas=sigma_schedule, index=i, prev_sample=None
            )
        elif is_dpm:
            z, _, log_prob = dpm_step(
                order,
                model_output=pred.float(),
                sample=z.float(),
                step_index=i,
                timesteps=sigma_schedule[:-1],
                sigmas=sigma_schedule,
                dpm_state=dpm_state,
            )
        else:
            # Explicit raise so the guard survives `python -O`.
            raise AssertionError(f"unreachable: unsupported solver {solver!r}")

        # The step functions work in float32; restore the caller's dtype.
        z = z.to(dtype)
        all_latents.append(z)
        all_log_probs.append(log_prob)

    latents = z.to(dtype)
    return latents, all_latents, all_log_probs
def flow_grpo_step(
    model_output: torch.Tensor,
    latents: torch.Tensor,
    eta: float,
    sigmas: torch.Tensor,
    index: int,
    prev_sample: torch.Tensor,
    generator: Optional[torch.Generator] = None,
):
    """Take one stochastic step from ``sigmas[index]`` to ``sigmas[index + 1]``.

    Args:
        model_output: Model prediction at the current step.
        latents: Current latents.
        eta: Multiplier applied to the per-step noise level.
        sigmas: Sigma schedule tensor; entries ``index``, ``index + 1`` and
            ``1`` are read.
        index: Current position in ``sigmas``.
        prev_sample: If given, the next sample to score; if ``None`` a new
            sample is drawn around the computed mean.
        generator: Optional RNG used when a new sample is drawn.

    Returns:
        ``(prev_sample, pred_original_sample, log_prob)`` where ``log_prob``
        is the Gaussian log-density of ``prev_sample`` averaged over every
        dimension except the first.

    Raises:
        ValueError: If both ``generator`` and ``prev_sample`` are supplied.
    """
    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    device = model_output.device
    sigma_cur = sigmas[index].to(device)
    sigma_next = sigmas[index + 1].to(device)
    step = sigma_next - sigma_cur  # negative when sigmas decrease

    x0_pred = latents - sigma_cur * model_output

    # At sigma == 1 the denominator (1 - sigma) would vanish, so sigmas[1] is
    # substituted there.
    sigma_cap = sigmas[1].item()
    denom_sigma = torch.where(sigma_cur == 1, sigma_cap, sigma_cur)
    noise_level = torch.sqrt(sigma_cur / (1 - denom_sigma)) * eta
    noise_var = noise_level**2

    mean = (
        latents * (1 + noise_var / (2 * sigma_cur) * step)
        + model_output * (1 + noise_var * (1 - sigma_cur) / (2 * sigma_cur)) * step
    )

    # Standard deviation of the transition.
    noise_scale = noise_level * torch.sqrt(-1 * step)

    if prev_sample is None:
        fresh_noise = randn_tensor(
            model_output.shape, generator=generator, device=device, dtype=model_output.dtype
        )
        prev_sample = mean + noise_scale * fresh_noise

    residual_sq = (prev_sample.detach() - mean) ** 2
    log_prob = (
        -residual_sq / (2 * (noise_scale**2))
        - torch.log(noise_scale)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # Average over everything but the leading (batch) dimension.
    reduce_dims = tuple(range(1, log_prob.ndim))
    log_prob = log_prob.mean(dim=reduce_dims)
    return prev_sample, x0_pred, log_prob
def dance_grpo_step(
    model_output: torch.Tensor,
    latents: torch.Tensor,
    eta: float,
    sigmas: torch.Tensor,
    index: int,
    prev_sample: torch.Tensor,
):
    """Take one stochastic step from ``sigmas[index]`` to ``sigmas[index + 1]``.

    Args:
        model_output: Model prediction at the current step.
        latents: Current latents.
        eta: Noise scale; the step standard deviation is
            ``eta * sqrt(sigmas[index] - sigmas[index + 1])``.
        sigmas: Sigma schedule; entries ``index`` and ``index + 1`` are read.
        index: Current position in ``sigmas``.
        prev_sample: If given, the next sample to score; if ``None`` a new
            sample is drawn around the computed mean.

    Returns:
        ``(prev_sample, pred_original_sample, log_prob)`` where ``log_prob``
        is the full Gaussian log-density (including the ``-log(std)`` and
        ``-log(sqrt(2*pi))`` terms) averaged over every dimension except the
        first. With ``eta == 0`` the density is degenerate and ``log_prob``
        is non-finite; the step itself still succeeds.
    """
    sigma = sigmas[index]
    dsigma = sigmas[index + 1] - sigma  # negative when sigmas decrease
    delta_t = sigma - sigmas[index + 1]  # positive step size
    std_dev_t = eta * math.sqrt(delta_t)

    pred_original_sample = latents - sigma * model_output

    # Deterministic update plus an eta-dependent drift correction.
    prev_sample_mean = latents + dsigma * model_output
    score_estimate = -(latents - pred_original_sample * (1 - sigma)) / sigma**2
    log_term = -0.5 * eta**2 * score_estimate
    prev_sample_mean = prev_sample_mean + log_term * dsigma

    if prev_sample is None:
        prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t

    # Log-density of prev_sample under N(prev_sample_mean, std_dev_t**2).
    # The normalisation terms used to sit on a separate line as a stand-alone
    # expression statement, so they were evaluated and thrown away; they are
    # now part of the result.
    squared_error = (prev_sample.detach().to(torch.float32) - prev_sample_mean.to(torch.float32)) ** 2
    # math.log(0.0) raises ValueError, which made deterministic sampling
    # (eta == 0) crash; use -inf for the degenerate case instead.
    log_std = math.log(std_dev_t) if std_dev_t > 0 else float("-inf")
    log_prob = -squared_error / (2 * (std_dev_t**2)) - log_std - 0.5 * math.log(2 * math.pi)

    # Average over everything but the leading (batch) dimension.
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample, pred_original_sample, log_prob
def ddim_step(
    model_output: torch.Tensor,
    latents: torch.Tensor,
    eta: float,
    sigmas: torch.Tensor,
    index: int,
    prev_sample: torch.Tensor,
):
    """Take one DDIM-style step and score the resulting transition.

    Args:
        model_output: Model prediction at the current step.
        latents: Current latents.
        eta: Noise scale forwarded to ``ddim_update``.
        sigmas: Sigma schedule tensor (it is cast to float64 for the update).
        index: Current position in ``sigmas``.
        prev_sample: If given, the next sample to score; if ``None`` the
            sample drawn by ``ddim_update`` is used.

    Returns:
        ``(prev_sample, x0_pred, log_prob)`` where ``x0_pred`` is the output
        of ``convert_model_output`` and ``log_prob`` is the Gaussian
        log-density of ``prev_sample`` averaged over every dimension except
        the first.
    """
    x0_pred = convert_model_output(model_output, latents, sigmas, step_index=index)
    sampled, prev_sample_mean, std_dev_t, dt_sqrt = ddim_update(
        x0_pred,
        sigmas.to(torch.float64),
        index,
        latents,
        eta=eta,
    )
    # Honour a caller-supplied sample, as flow_grpo_step and dance_grpo_step
    # do. Previously the argument was overwritten by the fresh draw, so the
    # log-probability of a given transition could not be evaluated.
    if prev_sample is None:
        prev_sample = sampled

    noise_scale = std_dev_t * dt_sqrt
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (noise_scale**2))
        - torch.log(noise_scale)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # Average over everything but the leading (batch) dimension.
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample, x0_pred, log_prob
@dataclass
class DPMState:
    """Mutable bookkeeping carried between successive ``dpm_step`` calls."""

    # Number of model outputs kept in the history window.
    order: int
    # History window, oldest entry first. Filled with ``None`` placeholders
    # when not supplied by the caller.
    model_outputs: Optional[List[torch.Tensor]] = None
    # Counter bumped by ``update_lower_order`` until it reaches ``order``.
    # Annotated so that @dataclass treats it as a real field (it is then part
    # of __init__, repr and eq) rather than a bare class attribute.
    lower_order_nums: int = 0

    def __post_init__(self):
        """Allocate an empty history window unless one was passed in."""
        if self.model_outputs is None:
            self.model_outputs = [None] * self.order

    def update(self, model_output: torch.Tensor):
        """Drop the oldest history entry and append ``model_output``."""
        history = self.model_outputs
        # Shift in place so existing references to the list stay valid.
        history[:-1] = history[1:]
        history[-1] = model_output

    def update_lower_order(self):
        """Increment ``lower_order_nums`` while it is below ``order``."""
        if self.lower_order_nums < self.order:
            self.lower_order_nums += 1
def dpm_step(
    order,
    model_output: torch.Tensor,
    sample: torch.Tensor,
    step_index: int,
    timesteps: list,
    sigmas: torch.Tensor,
    dpm_state: DPMState = None,
) -> torch.Tensor:
    """Advance ``sample`` by one deterministic multistep-solver step.

    Args:
        order: Solver order; values 1 and 2 are handled.
        model_output: Raw model prediction for the current step.
        sample: Current latents.
        step_index: Position of the current step in ``sigmas``.
        timesteps: Only its length is used, to detect the final and
            second-to-last steps.
        sigmas: Sigma schedule tensor (cast to float64 for the update).
        dpm_state: History carried across calls; must not be ``None``. It is
            mutated here (new output pushed, step counter bumped).

    Returns:
        Despite the annotation, a 3-tuple ``(prev_sample, x0_pred, None)``.
        The trailing ``None`` occupies the log-probability slot of the
        stochastic step functions.
    """
    # The final step always falls back to the first-order branch, which helps
    # numerical stability when few steps are used.
    lower_order_final = step_index == len(timesteps) - 1
    lower_order_second = (step_index == len(timesteps) - 2) and len(timesteps) < 15

    x0_pred = convert_model_output(model_output, sample, sigmas, step_index=step_index)

    assert dpm_state is not None
    dpm_state.update(x0_pred)

    # Upcast to avoid precision issues when computing the next sample.
    sample = sample.to(torch.float32)
    sigmas_f64 = sigmas.to(torch.float64)

    take_first_order = order == 1 or dpm_state.lower_order_nums < 1 or lower_order_final
    take_second_order = order == 2 or dpm_state.lower_order_nums < 2 or lower_order_second

    if take_first_order:
        if step_index == 0 or lower_order_final:
            # First and final steps use the noise-free (eta = 0) DDIM update.
            prev_sample = ddim_update(x0_pred, sigmas_f64, step_index, sample, eta=0.0)[0]
        else:
            prev_sample = dpm_solver_first_order_update(x0_pred, sigmas_f64, step_index, sample)
    elif take_second_order:
        prev_sample = multistep_dpm_solver_second_order_update(
            dpm_state.model_outputs, sigmas_f64, step_index, sample
        )
    else:
        assert False

    dpm_state.update_lower_order()

    # Cast the result back to the dtype of the converted model output.
    return prev_sample.to(x0_pred.dtype), x0_pred, None
def convert_model_output(
    model_output,
    sample,
    sigmas,
    step_index,
) -> torch.Tensor:
    """Return ``sample - sigmas[step_index] * model_output``.

    Callers in this file use the result as the prediction of the clean
    sample (``x0``).
    """
    return sample - sigmas[step_index] * model_output
def ddim_update(
    model_output: torch.Tensor,
    sigmas,
    step_index,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    eta: float = 1.0,
) -> torch.Tensor:
    """Compute one DDIM-style update from ``sigmas[step_index]`` to the next sigma.

    Args:
        model_output: Prediction used as the clean-sample estimate in the
            formulas below.
        sigmas: Sigma schedule; entries ``step_index`` and ``step_index + 1``
            are read.
        step_index: Current position in ``sigmas``.
        sample: Current latents.
        noise: Noise to inject; drawn with ``torch.randn_like`` when ``None``.
        eta: Scale of the injected noise.

    Returns:
        Despite the annotation, a 4-tuple
        ``(x_t, prev_mean, std_dev_t, dt_sqrt)``; the injected noise has
        standard deviation ``std_dev_t * dt_sqrt``.
    """
    sigma_next = sigmas[step_index + 1]
    sigma_cur = sigmas[step_index]

    std_dev_t = eta * sigma_next
    ratio = sigma_next**2 * (1 - sigma_cur) ** 2 / (sigma_cur**2 * (1 - sigma_next) ** 2)
    dt_sqrt = torch.sqrt(1.0 - ratio)
    noise_std = std_dev_t * dt_sqrt

    # Noise component implied by the current sample and the clean estimate.
    implied_noise = (sample - (1 - sigma_cur) * model_output) / sigma_cur

    if noise is None:
        noise = torch.randn_like(model_output)

    prev_mean = (1 - sigma_next) * model_output + torch.sqrt(sigma_next**2 - noise_std**2) * implied_noise
    return prev_mean + noise_std * noise, prev_mean, std_dev_t, dt_sqrt
def dpm_solver_first_order_update(
    model_output: torch.Tensor,
    sigmas,
    step_index,
    sample: torch.Tensor = None,
) -> torch.Tensor:
    """Apply a first-order solver update from ``sigmas[step_index]`` to the next sigma.

    Args:
        model_output: Prediction used as the clean-sample estimate.
        sigmas: Sigma schedule; entries ``step_index`` and ``step_index + 1``
            are read.
        step_index: Current position in ``sigmas``.
        sample: Current latents.

    Returns:
        The latents at ``sigmas[step_index + 1]``.
    """
    alpha_next, sigma_next = _sigma_to_alpha_sigma_t(sigmas[step_index + 1])
    alpha_cur, sigma_cur = _sigma_to_alpha_sigma_t(sigmas[step_index])

    # lambda = log(alpha) - log(sigma); h is its increment over the step.
    lam_next = torch.log(alpha_next) - torch.log(sigma_next)
    lam_cur = torch.log(alpha_cur) - torch.log(sigma_cur)
    h = lam_next - lam_cur

    sample_scale = sigma_next / sigma_cur
    estimate_scale = alpha_next * (torch.exp(-h) - 1.0)
    return sample_scale * sample - estimate_scale * model_output
def multistep_dpm_solver_second_order_update(
    model_output_list: List[torch.Tensor],
    sigmas,
    step_index,
    sample: torch.Tensor = None,
) -> torch.Tensor:
    """Apply a second-order multistep update using the two latest model outputs.

    Args:
        model_output_list: History of converted model outputs; the last two
            entries are used.
        sigmas: Sigma schedule; entries ``step_index - 1``, ``step_index``
            and ``step_index + 1`` are read.
        step_index: Current position in ``sigmas``.
        sample: Current latents.

    Returns:
        The latents at ``sigmas[step_index + 1]``.
    """
    alpha_next, sigma_next = _sigma_to_alpha_sigma_t(sigmas[step_index + 1])
    alpha_cur, sigma_cur = _sigma_to_alpha_sigma_t(sigmas[step_index])
    alpha_prev, sigma_prev = _sigma_to_alpha_sigma_t(sigmas[step_index - 1])

    # lambda = log(alpha) - log(sigma) at the three points involved.
    lam_next = torch.log(alpha_next) - torch.log(sigma_next)
    lam_cur = torch.log(alpha_cur) - torch.log(sigma_cur)
    lam_prev = torch.log(alpha_prev) - torch.log(sigma_prev)

    newest = model_output_list[-1]
    previous = model_output_list[-2]

    h = lam_next - lam_cur
    h_0 = lam_cur - lam_prev
    r0 = h_0 / h

    # Zeroth- and first-order difference terms of the model-output history.
    D0 = newest
    D1 = (1.0 / r0) * (newest - previous)

    estimate_scale = alpha_next * (torch.exp(-h) - 1.0)
    return (sigma_next / sigma_cur) * sample - estimate_scale * D0 - 0.5 * estimate_scale * D1
def _sigma_to_alpha_sigma_t(sigma):
alpha_t = 1 - sigma
sigma_t = sigma
return alpha_t, sigma_t