From 4d14777673c51b66535d6d716991038a86e3448c Mon Sep 17 00:00:00 2001
From: NorilskMajor <112411678+NorilskMajor@users.noreply.github.com>
Date: Tue, 16 Apr 2024 13:49:04 +0900
Subject: [PATCH] Utilize batches in Stable Diffusion (#2071)

* Utilize batches in Stable Diffusion that were already there, but unutilized. Also refactor out the `save_image` function.

* Clippy + cosmetic fixes.

---------

Co-authored-by: laurent
---
 .../examples/stable-diffusion/README.md       |  3 +-
 .../examples/stable-diffusion/main.rs         | 74 +++++++++++++++----
 2 files changed, 60 insertions(+), 17 deletions(-)

diff --git a/candle-examples/examples/stable-diffusion/README.md b/candle-examples/examples/stable-diffusion/README.md
index 1abaf330..bd79d012 100644
--- a/candle-examples/examples/stable-diffusion/README.md
+++ b/candle-examples/examples/stable-diffusion/README.md
@@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
 - `--cpu`: use the cpu rather than the gpu (much slower).
 - `--height`, `--width`: set the height and width for the generated image.
 - `--n-steps`: the number of steps to be used in the diffusion process.
-- `--num-samples`: the number of samples to generate.
+- `--num-samples`: the number of samples to generate iteratively.
+- `--bsize`: the number of samples to generate simultaneously.
 - `--final-image`: the filename for the generated image(s).
 
 ### Using flash-attention
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 0e39902b..d424444b 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
+use stable_diffusion::vae::AutoEncoderKL;
 use tokenizers::Tokenizer;
 
 #[derive(Parser)]
@@ -64,9 +65,13 @@ struct Args {
     #[arg(long)]
     n_steps: Option<usize>,
 
-    /// The number of samples to generate.
+    /// The number of samples to generate iteratively.
     #[arg(long, default_value_t = 1)]
-    num_samples: i64,
+    num_samples: usize,
+
+    /// The number of samples to generate simultaneously.
+    #[arg(long, default_value_t = 1)]
+    bsize: usize,
 
     /// The name of the final image to generate.
     #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@@ -236,8 +241,8 @@ impl ModelFile {
 
 fn output_filename(
     basename: &str,
-    sample_idx: i64,
-    num_samples: i64,
+    sample_idx: usize,
+    num_samples: usize,
     timestep_idx: Option<usize>,
 ) -> String {
     let filename = if num_samples > 1 {
@@ -261,6 +266,33 @@
     }
 }
 
+#[allow(clippy::too_many_arguments)]
+fn save_image(
+    vae: &AutoEncoderKL,
+    latents: &Tensor,
+    vae_scale: f64,
+    bsize: usize,
+    idx: usize,
+    final_image: &str,
+    num_samples: usize,
+    timestep_ids: Option<usize>,
+) -> Result<()> {
+    let images = vae.decode(&(latents / vae_scale)?)?;
+    let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+    let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
+    for batch in 0..bsize {
+        let image = images.i(batch)?;
+        let image_filename = output_filename(
+            final_image,
+            (bsize * idx) + batch + 1,
+            batch + num_samples,
+            timestep_ids,
+        );
+        candle_examples::save_image(&image, image_filename)?;
+    }
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 fn text_embeddings(
     prompt: &str,
@@ -382,6 +414,7 @@ fn run(args: Args) -> Result<()> {
         final_image,
         sliced_attention_size,
         num_samples,
+        bsize,
         sd_version,
         clip_weights,
         vae_weights,
@@ -475,6 +508,7 @@
         .collect::<Result<Vec<_>>>()?;
 
     let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+    let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
     println!("{text_embeddings:?}");
 
     println!("Building the autoencoder.");
@@ -496,7 +530,6 @@
     } else {
         0
     };
-    let bsize = 1;
 
     let vae_scale = match sd_version {
         StableDiffusionVersion::V1_5
@@ -560,12 +593,16 @@
             println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
 
             if args.intermediary_images {
-                let image = vae.decode(&(&latents / vae_scale)?)?;
-                let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-                let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
-                let image_filename =
-                    output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
-                candle_examples::save_image(&image, image_filename)?
+                save_image(
+                    &vae,
+                    &latents,
+                    vae_scale,
+                    bsize,
+                    idx,
+                    &final_image,
+                    num_samples,
+                    Some(timestep_index + 1),
+                )?;
             }
         }
 
@@ -574,11 +611,16 @@
             idx + 1,
             num_samples
         );
-        let image = vae.decode(&(&latents / vae_scale)?)?;
-        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-        let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
-        let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
-        candle_examples::save_image(&image, image_filename)?
+        save_image(
+            &vae,
+            &latents,
+            vae_scale,
+            bsize,
+            idx,
+            &final_image,
+            num_samples,
+            None,
+        )?;
     }
     Ok(())
 }
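
A note on how the two flags compose: `--num-samples` sets the number of sequential generation passes, while `--bsize` sets how many latents go through the UNet and VAE together within each pass, so the filename arithmetic `(bsize * idx) + batch + 1` suggests a run writes `num_samples * bsize` images in total.

The sketch below, which is not part of the patch, is a minimal self-contained illustration of the two batched operations the change relies on: repeating the text embeddings along the batch dimension and indexing each image back out of the decoded batch. It assumes only the `candle-core` crate (the example itself imports it under the alias `candle`), and the shapes (77 tokens, 768 hidden units, 4x64x64 latents) are placeholders.

```rust
// Editor's sketch, not part of the patch: the batched-tensor mechanics
// behind `--bsize`, using only candle-core with illustrative shapes.
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let bsize = 3usize;

    // Stand-in for the CLIP text embeddings, shape (1, tokens, hidden_size).
    let text_embeddings = Tensor::randn(0f32, 1f32, (1, 77, 768), &device)?;

    // As in the patch, repeat along the batch dimension so a single UNet
    // forward pass conditions all `bsize` samples at once.
    let batched = text_embeddings.repeat((bsize, 1, 1))?;
    assert_eq!(batched.dims(), &[bsize, 77, 768]);

    // Stand-in for the latents: one per sample in the batch.
    let latents = Tensor::randn(0f32, 1f32, (bsize, 4, 64, 64), &device)?;

    // `save_image` walks the decoded batch the same way: `i(batch)` selects
    // the batch-th entry and drops the leading batch dimension.
    for batch in 0..bsize {
        let image = latents.i(batch)?;
        assert_eq!(image.dims(), &[4, 64, 64]);
    }
    Ok(())
}
```

For example, an invocation along the lines of `cargo run --example stable-diffusion --release -- --prompt "a rusty robot" --num-samples 2 --bsize 2` (flags per the README hunk above; the prompt is illustrative) would then produce four images while running only two sampling loops.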