Utilize batches in Stable Diffusion (#2071)

* Utilize the batch support that was already present in Stable Diffusion but went unused.

Also factor out the `save_image` function.

* Clippy + cosmetic fixes.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
NorilskMajor authored on 2024-04-16 13:49:04 +09:00, committed by GitHub
commit 4d14777673 (parent f135b7963d)
2 changed files with 60 additions and 17 deletions

candle-examples/examples/stable-diffusion/README.md

@@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
 - `--cpu`: use the cpu rather than the gpu (much slower).
 - `--height`, `--width`: set the height and width for the generated image.
 - `--n-steps`: the number of steps to be used in the diffusion process.
-- `--num-samples`: the number of samples to generate.
+- `--num-samples`: the number of samples to generate iteratively.
+- `--bsize`: the number of samples to generate simultaneously.
 - `--final-image`: the filename for the generated image(s).

 ### Using flash-attention
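Taken together, the two flags multiply: the example runs `--num-samples` outer iterations and decodes `--bsize` latents per iteration, for `num_samples * bsize` images in total. A sketch of an invocation (the prompt text is illustrative):

    cargo run --example stable-diffusion --release -- \
        --prompt "a rusty robot holding a fire torch" \
        --num-samples 2 --bsize 4

With the default `--final-image` name this should write eight files, `sd_final.1.png` through `sd_final.8.png`.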

candle-examples/examples/stable-diffusion/main.rs

@@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
+use stable_diffusion::vae::AutoEncoderKL;
 use tokenizers::Tokenizer;

 #[derive(Parser)]
@@ -64,9 +65,13 @@ struct Args {
     #[arg(long)]
     n_steps: Option<usize>,

-    /// The number of samples to generate.
+    /// The number of samples to generate iteratively.
     #[arg(long, default_value_t = 1)]
-    num_samples: i64,
+    num_samples: usize,
+
+    /// The number of samples to generate simultaneously.
+    #[arg(long, default_value_t = 1)]
+    bsize: usize,

     /// The name of the final image to generate.
     #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@@ -236,8 +241,8 @@ impl ModelFile {

 fn output_filename(
     basename: &str,
-    sample_idx: i64,
-    num_samples: i64,
+    sample_idx: usize,
+    num_samples: usize,
     timestep_idx: Option<usize>,
 ) -> String {
     let filename = if num_samples > 1 {
@@ -261,6 +266,33 @@ fn output_filename(
     }
 }

+#[allow(clippy::too_many_arguments)]
+fn save_image(
+    vae: &AutoEncoderKL,
+    latents: &Tensor,
+    vae_scale: f64,
+    bsize: usize,
+    idx: usize,
+    final_image: &str,
+    num_samples: usize,
+    timestep_ids: Option<usize>,
+) -> Result<()> {
+    let images = vae.decode(&(latents / vae_scale)?)?;
+    let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+    let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
+    for batch in 0..bsize {
+        let image = images.i(batch)?;
+        let image_filename = output_filename(
+            final_image,
+            (bsize * idx) + batch + 1,
+            batch + num_samples,
+            timestep_ids,
+        );
+        candle_examples::save_image(&image, image_filename)?;
+    }
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 fn text_embeddings(
     prompt: &str,
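The flat index that `save_image` passes to `output_filename` combines the outer iteration `idx` with the batch slot, giving each image a 1-based position across all rounds. A minimal standalone sketch of that arithmetic (illustrative only, not part of the diff):

    fn main() {
        let (num_samples, bsize) = (2usize, 4usize); // outer rounds x images per round
        for idx in 0..num_samples {
            for batch in 0..bsize {
                // 1-based index across every round: runs 1..=num_samples * bsize.
                let sample_idx = bsize * idx + batch + 1;
                println!("round {idx}, slot {batch} -> sd_final.{sample_idx}.png");
            }
        }
    }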
@@ -382,6 +414,7 @@ fn run(args: Args) -> Result<()> {
         final_image,
         sliced_attention_size,
         num_samples,
+        bsize,
         sd_version,
         clip_weights,
         vae_weights,
@@ -475,6 +508,7 @@ fn run(args: Args) -> Result<()> {
         .collect::<Result<Vec<_>>>()?;

     let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+    let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
     println!("{text_embeddings:?}");

     println!("Building the autoencoder.");
@@ -496,7 +530,6 @@ fn run(args: Args) -> Result<()> {
     } else {
         0
     };
-    let bsize = 1;

     let vae_scale = match sd_version {
         StableDiffusionVersion::V1_5
@@ -560,12 +593,16 @@ fn run(args: Args) -> Result<()> {
             println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);

             if args.intermediary_images {
-                let image = vae.decode(&(&latents / vae_scale)?)?;
-                let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-                let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
-                let image_filename =
-                    output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
-                candle_examples::save_image(&image, image_filename)?
+                save_image(
+                    &vae,
+                    &latents,
+                    vae_scale,
+                    bsize,
+                    idx,
+                    &final_image,
+                    num_samples,
+                    Some(timestep_index + 1),
+                )?;
             }
         }
@@ -574,11 +611,16 @@ fn run(args: Args) -> Result<()> {
             idx + 1,
             num_samples
         );
-        let image = vae.decode(&(&latents / vae_scale)?)?;
-        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-        let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
-        let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
-        candle_examples::save_image(&image, image_filename)?
+        save_image(
+            &vae,
+            &latents,
+            vae_scale,
+            bsize,
+            idx,
+            &final_image,
+            num_samples,
+            None,
+        )?;
     }
     Ok(())
 }