# flow_grpo/diffusers_patch/solver.py
import math
from dataclasses import dataclass
from functools import partial
from typing import List, Optional

import numpy as np
import torch
import torch.distributed as dist
import tqdm
from diffusers.utils.torch_utils import randn_tensor

tqdm = partial(tqdm.tqdm, dynamic_ncols=True)


# Modified from MixGRPO
def run_sampling(
    v_pred_fn,
    z,
    sigma_schedule,
    solver="flow",
    deterministic=False,
    eta=0.7,
    pred_latents: Optional[torch.Tensor] = None,
):
    assert solver in ["flow", "dance", "ddim", "dpm1", "dpm2"]
    dtype = z.dtype
    all_latents = [z]
    all_log_probs = []

    if "dpm" in solver:
        order = int(solver[-1])
        dpm_state = DPMState(order=order)
        # The DPM solvers index sigma_schedule as a tensor, so convert it once up front.
        if isinstance(sigma_schedule, np.ndarray):
            sigma_schedule = torch.from_numpy(sigma_schedule).to(z.device)
    for i in tqdm(
        range(len(sigma_schedule) - 1),
        desc="Sampling Progress",
        # Show the progress bar only on rank 0 of an initialized process group.
        disable=not dist.is_initialized() or dist.get_rank() != 0,
    ):
        sigma = sigma_schedule[i]
        
        pred = v_pred_fn(z.to(dtype), sigma, pred_latents)
        if solver == "flow":
            z, pred_original, log_prob = flow_grpo_step(
                model_output=pred.float(),
                latents=z.float(),
                eta=eta if not deterministic else 0,
                sigmas=sigma_schedule,
                index=i,
                prev_sample=None,
            )
        elif solver == "dance":
            z, pred_original, log_prob = dance_grpo_step(
                pred.float(), z.float(), eta if not deterministic else 0, sigmas=sigma_schedule, index=i, prev_sample=None
            )
        elif solver == "ddim":
            z, pred_original, log_prob = ddim_step(
                pred.float(), z.float(), eta if not deterministic else 0, sigmas=sigma_schedule, index=i, prev_sample=None
            )
        elif "dpm" in solver:
            # DPM-Solver updates are deterministic; no per-step log-prob is produced.
            assert deterministic
            z, pred_original, log_prob = dpm_step(
                order,
                model_output=pred.float(),
                sample=z.float(),
                step_index=i,
                timesteps=sigma_schedule[:-1],
                sigmas=sigma_schedule,
                dpm_state=dpm_state,
            )
        else:
            raise ValueError(f"Unknown solver: {solver}")
        z = z.to(dtype)  # e.g. torch.Size([B, 32, 16, 16])
        all_latents.append(z)
        all_log_probs.append(log_prob)  # None for the deterministic dpm solvers
        
    latents = z.to(dtype)
    # all_latents: list of (num_steps + 1) latent tensors;
    # all_log_probs: list of num_steps per-step log-prob tensors (None entries for the dpm solvers).
    return latents, all_latents, all_log_probs
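

# Illustrative usage sketch added for clarity (not part of the original module). It
# shows the expected calling convention of run_sampling with a toy velocity predictor
# and a linear sigma schedule; `toy_v_pred_fn` and this helper are hypothetical names
# introduced here only for illustration.
def _example_run_sampling():
    def toy_v_pred_fn(z, sigma, pred_latents=None):
        # A real v_pred_fn would query the diffusion model; a toy velocity suffices here.
        return z

    z0 = torch.randn(2, 32, 16, 16)
    sigma_schedule = torch.linspace(1.0, 0.0, 11)  # 10 sampling steps
    latents, all_latents, all_log_probs = run_sampling(
        toy_v_pred_fn, z0, sigma_schedule, solver="flow", eta=0.7
    )
    assert len(all_latents) == len(sigma_schedule)        # initial latent + one per step
    assert len(all_log_probs) == len(sigma_schedule) - 1  # one log-prob per step
    return latents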


def flow_grpo_step(
    model_output: torch.Tensor,
    latents: torch.Tensor,
    eta: float,
    sigmas: torch.Tensor,
    index: int,
    prev_sample: torch.Tensor,
    generator: Optional[torch.Generator] = None,
):
    device = model_output.device
    sigma = sigmas[index].to(device)
    sigma_prev = sigmas[index + 1].to(device)
    sigma_max = sigmas[1].item()
    dt = sigma_prev - sigma  # negative: sigmas decrease along the schedule

    pred_original_sample = latents - sigma * model_output

    # Substitute the next sigma at sigma == 1 to avoid division by zero in the noise scale.
    std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * eta

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )

    prev_sample_mean = (
        latents * (1 + std_dev_t**2 / (2 * sigma) * dt)
        + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
    )

    if prev_sample is None:
        variance_noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
        prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise

    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
        - torch.log(std_dev_t * torch.sqrt(-1 * dt))
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )

    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample, pred_original_sample, log_prob
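

# Sanity-check sketch added for clarity (an assumption-labelled addition, not original
# code): with eta=0 the stochastic flow-GRPO step collapses to the deterministic Euler
# update z + dt * v, since std_dev_t becomes 0. The returned log_prob is degenerate
# (NaN) in that limit because the step variance is zero, so it is ignored here.
def _check_flow_grpo_step_deterministic_limit():
    torch.manual_seed(0)
    sigmas = torch.linspace(1.0, 0.0, 11)
    latents = torch.randn(2, 32, 16, 16)
    model_output = torch.randn_like(latents)
    index = 3
    prev_sample, _, _ = flow_grpo_step(
        model_output=model_output,
        latents=latents,
        eta=0.0,
        sigmas=sigmas,
        index=index,
        prev_sample=None,
    )
    dt = sigmas[index + 1] - sigmas[index]
    euler_step = latents + dt * model_output
    assert torch.allclose(prev_sample, euler_step, atol=1e-6)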


def dance_grpo_step(
    model_output: torch.Tensor,
    latents: torch.Tensor,
    eta: float,
    sigmas: torch.Tensor,
    index: int,
    prev_sample: torch.Tensor,
):
    sigma = sigmas[index]
    dsigma = sigmas[index + 1] - sigma  # negative step (sigmas decrease)
    prev_sample_mean = latents + dsigma * model_output

    pred_original_sample = latents - sigma * model_output

    delta_t = sigma - sigmas[index + 1]  # positive step size
    std_dev_t = eta * math.sqrt(delta_t)

    score_estimate = -(latents - pred_original_sample * (1 - sigma)) / sigma**2
    log_term = -0.5 * eta**2 * score_estimate
    prev_sample_mean = prev_sample_mean + log_term * dsigma

    if prev_sample is None:
        prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t

    # Log-prob of prev_sample under N(prev_sample_mean, std_dev_t^2); assumes eta > 0.
    log_prob = (
        -((prev_sample.detach().to(torch.float32) - prev_sample_mean.to(torch.float32)) ** 2)
        / (2 * (std_dev_t**2))
        - math.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )

    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample, pred_original_sample, log_prob


def ddim_step(
    model_output: torch.Tensor,
    latents: torch.Tensor,
    eta: float,
    sigmas: torch.Tensor,
    index: int,
    prev_sample: torch.Tensor,
):
    model_output = convert_model_output(model_output, latents, sigmas, step_index=index)
    prev_sample, prev_sample_mean, std_dev_t, dt_sqrt = ddim_update(
        model_output,
        sigmas.to(torch.float64),
        index,
        latents,
        eta=eta,
    )

    # Compute log_prob
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * dt_sqrt) ** 2))
        - torch.log(std_dev_t * dt_sqrt)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )

    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample, model_output, log_prob


@dataclass
class DPMState:
    order: int
    model_outputs: Optional[List[torch.Tensor]] = None
    lower_order_nums: int = 0

    def __post_init__(self):
        # Rolling buffer of the most recent `order` model outputs.
        self.model_outputs = [None] * self.order

    def update(self, model_output: torch.Tensor):
        # Shift the buffer left and store the newest model output last.
        for i in range(self.order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

    def update_lower_order(self):
        if self.lower_order_nums < self.order:
            self.lower_order_nums += 1
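

# Illustrative sketch added for clarity (not part of the original file): DPMState keeps
# a rolling window of the most recent model outputs for the multistep DPM solver.
def _example_dpm_state_buffer():
    state = DPMState(order=2)
    first = torch.zeros(1)
    second = torch.ones(1)
    state.update(first)   # buffer becomes [None, first]
    state.update(second)  # buffer becomes [first, second]
    assert state.model_outputs[-1] is second
    assert state.model_outputs[-2] is first
    state.update_lower_order()  # lower_order_nums: 0 -> 1
    assert state.lower_order_nums == 1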


def dpm_step(
    order,
    model_output: torch.Tensor,
    sample: torch.Tensor,
    step_index: int,
    timesteps: list,
    sigmas: torch.Tensor,
    dpm_state: DPMState = None,
):

    # Improve numerical stability for small number of steps
    lower_order_final = step_index == len(timesteps) - 1
    lower_order_second = (step_index == len(timesteps) - 2) and len(timesteps) < 15

    model_output = convert_model_output(model_output, sample, sigmas, step_index=step_index)

    assert dpm_state is not None
    dpm_state.update(model_output)

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    if order == 1 or dpm_state.lower_order_nums < 1 or lower_order_final:
        if step_index == 0 or lower_order_final:
            prev_sample, _, _, _ = ddim_update(
                model_output,
                sigmas.to(torch.float64),
                step_index,
                sample,
                eta=0.0,
            )
        else:
            prev_sample = dpm_solver_first_order_update(
                model_output,
                sigmas.to(torch.float64),
                step_index,
                sample,
            )
    elif order == 2 or dpm_state.lower_order_nums < 2 or lower_order_second:
        prev_sample = multistep_dpm_solver_second_order_update(
            dpm_state.model_outputs,
            sigmas.to(torch.float64),
            step_index,
            sample,
        )
    else:
        raise ValueError(f"Unsupported DPM solver order: {order}")

    dpm_state.update_lower_order()

    # Cast sample back to expected dtype
    prev_sample = prev_sample.to(model_output.dtype)

    return prev_sample, model_output, None


def convert_model_output(
    model_output,
    sample,
    sigmas,
    step_index,
) -> torch.Tensor:
    sigma_t = sigmas[step_index]
    x0_pred = sample - sigma_t * model_output

    return x0_pred


def ddim_update(
    model_output: torch.Tensor,
    sigmas,
    step_index,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    eta: float = 1.0,
):

    t, s = sigmas[step_index + 1], sigmas[step_index]

    std_dev_t = eta * t
    dt_sqrt = torch.sqrt(1.0 - t**2 * (1 - s) ** 2 / (s**2 * (1 - t) ** 2))
    rho_t = std_dev_t * dt_sqrt
    # Recover the noise estimate from x_s = (1 - s) * x0 + s * eps.
    noise_pred = (sample - (1 - s) * model_output) / s
    if noise is None:
        noise = torch.randn_like(model_output)
    prev_mean = (1 - t) * model_output + torch.sqrt(t**2 - rho_t**2) * noise_pred
    x_t = prev_mean + rho_t * noise

    return x_t, prev_mean, std_dev_t, dt_sqrt
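

# Verification sketch added for clarity (an assumption-labelled addition, not original
# code): with eta=0 the update is the deterministic DDIM-style move in the
# rectified-flow parameterization x_sigma = (1 - sigma) * x0 + sigma * eps, so a clean
# sample at level s maps exactly to the corresponding sample at level t.
def _check_ddim_update_deterministic():
    torch.manual_seed(0)
    sigmas = torch.tensor([0.8, 0.5, 0.2], dtype=torch.float64)
    x0 = torch.randn(2, 32, 16, 16)
    eps = torch.randn_like(x0)
    s, t = sigmas[0], sigmas[1]
    sample_s = (1 - s) * x0 + s * eps
    x_t, prev_mean, _, _ = ddim_update(x0, sigmas, step_index=0, sample=sample_s, eta=0.0)
    expected = (1 - t) * x0 + t * eps
    assert torch.allclose(x_t, expected, atol=1e-6)
    assert torch.allclose(x_t, prev_mean)  # no noise is added when eta == 0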


def dpm_solver_first_order_update(
    model_output: torch.Tensor,
    sigmas,
    step_index,
    sample: torch.Tensor = None,
) -> torch.Tensor:

    sigma_t, sigma_s = sigmas[step_index + 1], sigmas[step_index]
    alpha_t, sigma_t = _sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = _sigma_to_alpha_sigma_t(sigma_s)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

    h = lambda_t - lambda_s
    x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output

    return x_t


def multistep_dpm_solver_second_order_update(
    model_output_list: List[torch.Tensor],
    sigmas,
    step_index,
    sample: torch.Tensor = None,
) -> torch.Tensor:

    sigma_t, sigma_s0, sigma_s1 = (
        sigmas[step_index + 1],
        sigmas[step_index],
        sigmas[step_index - 1],
    )

    alpha_t, sigma_t = _sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = _sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = _sigma_to_alpha_sigma_t(sigma_s1)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

    m0, m1 = model_output_list[-1], model_output_list[-2]

    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)

    x_t = (
        (sigma_t / sigma_s0) * sample
        - (alpha_t * (torch.exp(-h) - 1.0)) * D0
        - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
    )

    return x_t
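

# Consistency sketch added for clarity (not part of the original file): when the two
# cached model outputs are identical, the second-order correction D1 vanishes and the
# multistep update reduces to the first-order DPM-Solver update.
def _check_second_order_reduces_to_first_order():
    torch.manual_seed(0)
    sigmas = torch.tensor([0.8, 0.5, 0.2], dtype=torch.float64)
    sample = torch.randn(2, 32, 16, 16, dtype=torch.float64)
    x0_pred = torch.randn_like(sample)
    second = multistep_dpm_solver_second_order_update(
        [x0_pred, x0_pred], sigmas, step_index=1, sample=sample
    )
    first = dpm_solver_first_order_update(x0_pred, sigmas, step_index=1, sample=sample)
    assert torch.allclose(second, first, atol=1e-6)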


def _sigma_to_alpha_sigma_t(sigma):
    # Rectified-flow parameterization: x_sigma = (1 - sigma) * x0 + sigma * eps.
    alpha_t = 1 - sigma
    sigma_t = sigma
    return alpha_t, sigma_t