Fixes for the stable diffusion example. (#342)

* Fixes for the stable diffusion example.

* Bugfix.

* Another fix.

* Fix for group-norm.

* More fixes to get SD to work.
commit 89d3926c9b (parent ab35684326)
Author: Laurent Mazare
Date:   2023-08-08 15:57:09 +02:00 (committed by GitHub)
6 changed files with 27 additions and 12 deletions


@@ -29,7 +29,7 @@ pub struct Config {
     embed_dim: usize, // aka config.hidden_size
     activation: Activation, // aka config.hidden_act
     intermediate_size: usize,
-    max_position_embeddings: usize,
+    pub max_position_embeddings: usize,
     // The character to use for padding, use EOS when not set.
     pad_with: Option<String>,
     num_hidden_layers: usize,
@@ -90,7 +90,7 @@ impl ClipTextEmbeddings {
             vs.pp("position_embedding"),
         )?;
         let position_ids =
-            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?;
+            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
         Ok(ClipTextEmbeddings {
             token_embedding,
             position_embedding,
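
Note on the second hunk: the position ids change from shape (max_position_embeddings, 1) to (1, max_position_embeddings), a single row of indices that matches the (batch, seq_len) layout of the token ids fed into the text model. A minimal sketch of the resulting shape, assuming candle-core is imported as candle the way this workspace does:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // CLIP's text context length in the Stable Diffusion v1.x configs.
    let max_position_embeddings = 77usize;
    // unsqueeze(0) adds a leading batch dimension: shape (1, 77) instead of (77, 1).
    let position_ids =
        Tensor::arange(0u32, max_position_embeddings as u32, &dev)?.unsqueeze(0)?;
    assert_eq!(position_ids.dims(), &[1, 77]);
    Ok(())
}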


@@ -49,7 +49,7 @@ impl Timesteps {
         let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
         let emb = exponent.exp()?;
         // emb = timesteps[:, None].float() * emb[None, :]
-        let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?;
+        let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
         let (cos, sin) = (emb.cos()?, emb.sin()?);
         let emb = if self.flip_sin_to_cos {
             Tensor::cat(&[&cos, &sin], D::Minus1)?
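
This hunk swaps the plain * operator for broadcast_mul: candle's arithmetic operators expect operands of identical shape, while the broadcast_* methods expand singleton dimensions, so multiplying a (batch, 1) timestep column by a (1, half_dim) frequency row needs the broadcasting call. The .add to .broadcast_add change in the resnet block further down is the same fix for addition. A minimal sketch of the pattern, again assuming the crate is imported as candle:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Timesteps as a column, shape (4, 1); frequencies as a row, shape (1, 3).
    let t = Tensor::arange(0f32, 4f32, &dev)?.unsqueeze(1)?;
    let f = Tensor::arange(0f32, 3f32, &dev)?.unsqueeze(0)?;
    // broadcast_mul expands both singleton dimensions and yields shape (4, 3);
    // (&t * &f) would return a shape-mismatch error at runtime.
    let emb = t.broadcast_mul(&f)?;
    println!("emb shape: {:?}", emb.shape()); // [4, 3]
    Ok(())
}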


@@ -181,19 +181,29 @@ fn run(args: Args) -> Result<()> {
     let device = candle_examples::device(cpu)?;
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+    let pad_id = match tokenizer.get_padding() {
+        Some(padding) => padding.pad_id,
+        None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+    };
     println!("Running with prompt \"{prompt}\".");
-    let tokens = tokenizer
+    let mut tokens = tokenizer
         .encode(prompt, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    while tokens.len() < sd_config.clip.max_position_embeddings {
+        tokens.push(pad_id)
+    }
     let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
-    let uncond_tokens = tokenizer
+    let mut uncond_tokens = tokenizer
         .encode(uncond_prompt, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+        uncond_tokens.push(pad_id)
+    }
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
     println!("Building the Clip transformer.");
@@ -202,6 +212,7 @@ fn run(args: Args) -> Result<()> {
     let uncond_embeddings = text_model.forward(&uncond_tokens)?;
     let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
+    println!("text-embeddings: {text_embeddings:?}");
     println!("Building the autoencoder.");
     let vae = sd_config.build_vae(&vae_weights, &device)?;
     println!("Building the unet.");


@@ -118,7 +118,7 @@ impl ResnetBlock2D {
                 .forward(&nn::ops::silu(temb)?)?
                 .unsqueeze(D::Minus1)?
                 .unsqueeze(D::Minus1)?
-                .add(&xs)?,
+                .broadcast_add(&xs)?,
             _ => xs,
         };
         let xs = self


@@ -59,17 +59,21 @@ impl GroupNorm {
         let x = x.broadcast_sub(&mean_x)?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let mut w_dims = vec![1; x_shape.len()];
+        w_dims[1] = n_channels;
+        let weight = self.weight.reshape(w_dims.clone())?;
+        let bias = self.bias.reshape(w_dims)?;
         x_normed
             .to_dtype(x_dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?
-            .reshape(x_shape)
+            .reshape(x_shape)?
+            .broadcast_mul(&weight)?
+            .broadcast_add(&bias)
     }
 }

 pub fn group_norm(
-    num_channels: usize,
     num_groups: usize,
+    num_channels: usize,
     eps: f64,
     vb: crate::VarBuilder,
 ) -> Result<GroupNorm> {
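
Two things change in the group-norm fix. First, the learned weight and bias, each of shape (num_channels,), are reshaped to (1, num_channels, 1, ...) and applied only after the normalized tensor is reshaped back to its original (batch, channels, ...) layout; before, they were broadcast against the grouped view, whose last dimension is not the channel count, so the scale and shift landed on the wrong axis. Second, the argument order of the group_norm constructor becomes (num_groups, num_channels, ...), matching PyTorch's nn.GroupNorm ordering. A minimal sketch of the reshape-then-broadcast step, assuming candle-core is imported as candle:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (n, c, h, w) = (2usize, 6usize, 4usize, 4usize);
    let xs = Tensor::zeros((n, c, h, w), DType::F32, &dev)?;
    let weight = Tensor::ones(c, DType::F32, &dev)?; // per-channel scale, shape (6,)
    // Broadcasting a (6,) vector against (2, 6, 4, 4) would align it with the
    // last axis (size 4) and fail; reshaping it to (1, 6, 1, 1) lines it up
    // with the channel axis instead.
    let weight = weight.reshape(vec![1usize, c, 1, 1])?;
    let scaled = xs.broadcast_mul(&weight)?;
    assert_eq!(scaled.dims(), &[n, c, h, w]);
    Ok(())
}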


@@ -30,8 +30,8 @@ use test_utils::to_vec3_round;
 #[test]
 fn group_norm() -> Result<()> {
     let device = &Device::Cpu;
-    let w = Tensor::new(&[1f32], device)?;
-    let b = Tensor::new(&[0f32], device)?;
+    let w = Tensor::from_vec(vec![1f32; 6], 6, device)?;
+    let b = Tensor::from_vec(vec![0f32; 6], 6, device)?;
     let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?;
     let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?;