Fixes for the stable diffusion example. (#342)
* Fixes for the stable diffusion example. * Bugfix. * Another fix. * Fix for group-norm. * More fixes to get SD to work.
This commit is contained in:
parent
ab35684326
commit
89d3926c9b
|
@ -29,7 +29,7 @@ pub struct Config {
|
|||
embed_dim: usize, // aka config.hidden_size
|
||||
activation: Activation, // aka config.hidden_act
|
||||
intermediate_size: usize,
|
||||
max_position_embeddings: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
// The character to use for padding, use EOS when not set.
|
||||
pad_with: Option<String>,
|
||||
num_hidden_layers: usize,
|
||||
|
@ -90,7 +90,7 @@ impl ClipTextEmbeddings {
|
|||
vs.pp("position_embedding"),
|
||||
)?;
|
||||
let position_ids =
|
||||
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?;
|
||||
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
|
||||
Ok(ClipTextEmbeddings {
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
|
|
|
@ -49,7 +49,7 @@ impl Timesteps {
|
|||
let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
|
||||
let emb = exponent.exp()?;
|
||||
// emb = timesteps[:, None].float() * emb[None, :]
|
||||
let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?;
|
||||
let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
|
||||
let (cos, sin) = (emb.cos()?, emb.sin()?);
|
||||
let emb = if self.flip_sin_to_cos {
|
||||
Tensor::cat(&[&cos, &sin], D::Minus1)?
|
||||
|
|
|
@ -181,19 +181,29 @@ fn run(args: Args) -> Result<()> {
|
|||
let device = candle_examples::device(cpu)?;
|
||||
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let pad_id = match tokenizer.get_padding() {
|
||||
Some(padding) => padding.pad_id,
|
||||
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
|
||||
};
|
||||
println!("Running with prompt \"{prompt}\".");
|
||||
let tokens = tokenizer
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
tokens.push(pad_id)
|
||||
}
|
||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
|
||||
let uncond_tokens = tokenizer
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
uncond_tokens.push(pad_id)
|
||||
}
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the Clip transformer.");
|
||||
|
@ -202,6 +212,7 @@ fn run(args: Args) -> Result<()> {
|
|||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
|
||||
|
||||
println!("text-embeddings: {text_embeddings:?}");
|
||||
println!("Building the autoencoder.");
|
||||
let vae = sd_config.build_vae(&vae_weights, &device)?;
|
||||
println!("Building the unet.");
|
||||
|
|
|
@ -118,7 +118,7 @@ impl ResnetBlock2D {
|
|||
.forward(&nn::ops::silu(temb)?)?
|
||||
.unsqueeze(D::Minus1)?
|
||||
.unsqueeze(D::Minus1)?
|
||||
.add(&xs)?,
|
||||
.broadcast_add(&xs)?,
|
||||
_ => xs,
|
||||
};
|
||||
let xs = self
|
||||
|
|
|
@ -59,17 +59,21 @@ impl GroupNorm {
|
|||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let mut w_dims = vec![1; x_shape.len()];
|
||||
w_dims[1] = n_channels;
|
||||
let weight = self.weight.reshape(w_dims.clone())?;
|
||||
let bias = self.bias.reshape(w_dims)?;
|
||||
x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?
|
||||
.reshape(x_shape)
|
||||
.reshape(x_shape)?
|
||||
.broadcast_mul(&weight)?
|
||||
.broadcast_add(&bias)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn group_norm(
|
||||
num_channels: usize,
|
||||
num_groups: usize,
|
||||
num_channels: usize,
|
||||
eps: f64,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<GroupNorm> {
|
||||
|
|
|
@ -30,8 +30,8 @@ use test_utils::to_vec3_round;
|
|||
#[test]
|
||||
fn group_norm() -> Result<()> {
|
||||
let device = &Device::Cpu;
|
||||
let w = Tensor::new(&[1f32], device)?;
|
||||
let b = Tensor::new(&[0f32], device)?;
|
||||
let w = Tensor::from_vec(vec![1f32; 6], 6, device)?;
|
||||
let b = Tensor::from_vec(vec![0f32; 6], 6, device)?;
|
||||
let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?;
|
||||
let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?;
|
||||
|
||||
|
|
Loading…
Reference in New Issue