Utilize batches in Stable Diffusion (#2071)
* Utilize the batch support in Stable Diffusion that was already there but went unused. Also refactor out the `save_image` function.

* Clippy + cosmetic fixes.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
parent f135b7963d
commit 4d14777673
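In short, the latents now keep their leading batch dimension all the way through the VAE, and the new `save_image` helper slices one image per batch entry. A minimal, self-contained sketch of that slicing pattern, using a random stand-in tensor for the decoded VAE output (shapes and names here are illustrative only, not taken from the diff):

```rust
use candle::{DType, Device, IndexOp, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let bsize = 2usize;
    // Stand-in for the decoded VAE output: (bsize, channels, height, width).
    let images = Tensor::rand(0f32, 1f32, (bsize, 3, 64, 64), &dev)?;
    // Same post-processing as the example: clamp to [0, 1], scale to u8.
    let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
    for batch in 0..bsize {
        // `.i(batch)` selects one (3, 64, 64) image out of the batch.
        let image = images.i(batch)?;
        assert_eq!(image.dims(), &[3, 64, 64]);
    }
    Ok(())
}
```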
@@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
 - `--cpu`: use the cpu rather than the gpu (much slower).
 - `--height`, `--width`: set the height and width for the generated image.
 - `--n-steps`: the number of steps to be used in the diffusion process.
-- `--num-samples`: the number of samples to generate.
+- `--num-samples`: the number of samples to generate iteratively.
+- `--bsize`: the number of samples to generate simultaneously.
 - `--final-image`: the filename for the generated image(s).

 ### Using flash-attention
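A usage sketch of the two flags together (the `cargo run` invocation and any feature flags are assumptions based on the example's usual setup, not part of this diff): `cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch" --num-samples 2 --bsize 2` should yield four images, produced in two passes of two, since each of the `--num-samples` iterations decodes `--bsize` latents at once.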
@@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
+use stable_diffusion::vae::AutoEncoderKL;
 use tokenizers::Tokenizer;

 #[derive(Parser)]
@@ -64,9 +65,13 @@ struct Args {
     #[arg(long)]
     n_steps: Option<usize>,

-    /// The number of samples to generate.
+    /// The number of samples to generate iteratively.
     #[arg(long, default_value_t = 1)]
-    num_samples: i64,
+    num_samples: usize,

+    /// The number of samples to generate simultaneously.
+    #[arg(long, default_value_t = 1)]
+    bsize: usize,
+
     /// The name of the final image to generate.
     #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@@ -236,8 +241,8 @@ impl ModelFile {

 fn output_filename(
     basename: &str,
-    sample_idx: i64,
-    num_samples: i64,
+    sample_idx: usize,
+    num_samples: usize,
     timestep_idx: Option<usize>,
 ) -> String {
     let filename = if num_samples > 1 {
@@ -261,6 +266,33 @@ fn output_filename(
     }
 }

+#[allow(clippy::too_many_arguments)]
+fn save_image(
+    vae: &AutoEncoderKL,
+    latents: &Tensor,
+    vae_scale: f64,
+    bsize: usize,
+    idx: usize,
+    final_image: &str,
+    num_samples: usize,
+    timestep_ids: Option<usize>,
+) -> Result<()> {
+    let images = vae.decode(&(latents / vae_scale)?)?;
+    let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+    let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
+    for batch in 0..bsize {
+        let image = images.i(batch)?;
+        let image_filename = output_filename(
+            final_image,
+            (bsize * idx) + batch + 1,
+            batch + num_samples,
+            timestep_ids,
+        );
+        candle_examples::save_image(&image, image_filename)?;
+    }
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 fn text_embeddings(
     prompt: &str,
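A quick sanity check on the new indexing: with the hypothetical flags `--num-samples 2 --bsize 2`, `idx` ranges over 0..2 and `batch` over 0..2, so `(bsize * idx) + batch + 1` numbers the outputs 1 through 4 and no batch entry overwrites another. Passing `batch + num_samples` as the `num_samples` argument looks odd, but judging from the `if num_samples > 1` context line in the previous hunk, that argument only gates whether an index is embedded in the filename; `batch + num_samples` ensures every image after the first gets one, even when `--num-samples` is left at its default of 1.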
@@ -382,6 +414,7 @@ fn run(args: Args) -> Result<()> {
         final_image,
         sliced_attention_size,
         num_samples,
+        bsize,
         sd_version,
         clip_weights,
         vae_weights,
@@ -475,6 +508,7 @@ fn run(args: Args) -> Result<()> {
         .collect::<Result<Vec<_>>>()?;

     let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+    let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
     println!("{text_embeddings:?}");

     println!("Building the autoencoder.");
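The actual batching of the conditioning happens here: `Tensor::repeat` tiles the tensor along each dimension of the given shape, so `(bsize, 1, 1)` duplicates the text embeddings `bsize` times along the batch axis. A minimal sketch of that behaviour (the `[2, 77, 1024]` shape is a made-up example, not taken from the diff):

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // Stand-in for the concatenated (uncond + cond) text embeddings.
    let emb = Tensor::zeros((2, 77, 1024), DType::F32, &dev)?;
    let bsize = 4;
    // `repeat` tiles: the first dim grows by a factor of `bsize`,
    // the `1`s leave the remaining dims untouched.
    let batched = emb.repeat((bsize, 1, 1))?;
    assert_eq!(batched.dims(), &[2 * bsize, 77, 1024]);
    Ok(())
}
```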
@@ -496,7 +530,6 @@ fn run(args: Args) -> Result<()> {
     } else {
         0
     };
-    let bsize = 1;

     let vae_scale = match sd_version {
         StableDiffusionVersion::V1_5
@@ -560,12 +593,16 @@ fn run(args: Args) -> Result<()> {
             println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);

             if args.intermediary_images {
-                let image = vae.decode(&(&latents / vae_scale)?)?;
-                let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-                let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
-                let image_filename =
-                    output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
-                candle_examples::save_image(&image, image_filename)?
+                save_image(
+                    &vae,
+                    &latents,
+                    vae_scale,
+                    bsize,
+                    idx,
+                    &final_image,
+                    num_samples,
+                    Some(timestep_index + 1),
+                )?;
             }
         }
@@ -574,11 +611,16 @@ fn run(args: Args) -> Result<()> {
             idx + 1,
             num_samples
         );
-        let image = vae.decode(&(&latents / vae_scale)?)?;
-        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-        let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
-        let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
-        candle_examples::save_image(&image, image_filename)?
+        save_image(
+            &vae,
+            &latents,
+            vae_scale,
+            bsize,
+            idx,
+            &final_image,
+            num_samples,
+            None,
+        )?;
     }
     Ok(())
 }