From 7aa4d93ec23adaff133bdb81e3d32b89001a8a08 Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Thu, 24 Apr 2025 11:20:58 -0300 Subject: [PATCH] fix: adjust timestep calculations for DDIM and TCD On img2img, the number of steps corresponds to the last precalculated sigma values, but the internal alphas_cumprod and compvis_sigmas were being computed over the entire step range. Also tweak the prev_timestep calculation on DDIM to better match the current timestep (like on TCD), to avoid inconsistencies due to rounding. --- denoiser.hpp | 18 ++++++++++++------ stable-diffusion.cpp | 11 +++++++++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/denoiser.hpp b/denoiser.hpp index 66799109..aeab3d64 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -474,6 +474,7 @@ static void sample_k_diffusion(sample_method_t method, ggml_context* work_ctx, ggml_tensor* x, std::vector sigmas, + int initial_step, std::shared_ptr rng, float eta) { size_t steps = sigmas.size() - 1; @@ -1060,10 +1061,14 @@ static void sample_k_diffusion(sample_method_t method, // x_t" // - pred_prev_sample -> "x_t-1" int timestep = - roundf(TIMESTEPS - - i * ((float)TIMESTEPS / steps)) - 1; + TIMESTEPS - 1 - + (int)roundf((initial_step + i) * + (TIMESTEPS / float(initial_step + steps))); // 1. get previous step value (=t-1) - int prev_timestep = timestep - TIMESTEPS / steps; + int prev_timestep = + TIMESTEPS - 1 - + (int)roundf((initial_step + i + 1) * + (TIMESTEPS / float(initial_step + steps))); // The sigma here is chosen to cause the // CompVisDenoiser to produce t = timestep float sigma = compvis_sigmas[timestep]; @@ -1236,12 +1241,13 @@ static void sample_k_diffusion(sample_method_t method, // Analytic form for TCD timesteps int timestep = TIMESTEPS - 1 - (TIMESTEPS / original_steps) * - (int)floor(i * ((float)original_steps / steps)); + (int)floor((initial_step + i) * + ((float)original_steps / (initial_step + steps))); // 1. get previous step value int prev_timestep = i >= steps - 1 ? 
0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * - (int)floor((i + 1) * - ((float)original_steps / steps)); + (int)floor((initial_step + i + 1) * + ((float)original_steps / (initial_step + steps))); // Here timestep_s is tau_n' in Algorithm 4. The _s // notation appears to be that from C. Lu, // "DPM-Solver: A Fast ODE Solver for Diffusion diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101..34141a22 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -798,6 +798,7 @@ class StableDiffusionGGML { float eta, sample_method_t method, const std::vector& sigmas, + int initial_step, int start_merge_step, SDCondition id_cond, std::vector skip_layers = {}, @@ -991,7 +992,7 @@ class StableDiffusionGGML { return denoised; }; - sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta); + sample_k_diffusion(method, denoise, work_ctx, x, sigmas, initial_step, rng, eta); x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); @@ -1202,6 +1203,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, int height, enum sample_method_t sample_method, const std::vector& sigmas, + int initial_step, int64_t seed, int batch_count, const sd_image_t* control_cond, @@ -1464,6 +1466,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, eta, sample_method, sigmas, + initial_step, start_merge_step, id_cond, skip_layers, @@ -1611,6 +1614,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, height, sample_method, sigmas, + 0, seed, batch_count, control_cond, @@ -1775,8 +1779,9 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, if (t_enc == sample_steps) t_enc--; LOG_INFO("target t_enc is %zu steps", t_enc); + int initial_step = sample_steps - t_enc - 1; std::vector sigma_sched; - sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end()); + sigma_sched.assign(sigmas.begin() + initial_step, sigmas.end()); sd_image_t* result_images = generate_image(sd_ctx, work_ctx, @@ -1791,6 +1796,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, height, sample_method, 
sigma_sched, + initial_step, seed, batch_count, control_cond, @@ -1903,6 +1909,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, 0.f, sample_method, sigmas, + 0, -1, SDCondition(NULL, NULL, NULL));