Fix the W clip embeddings. (#887)
* Fix the W clip embeddings. * Add the specialized ddpm scheduler.
This commit is contained in:
parent
7dd8e12472
commit
5082954c52
|
@ -193,9 +193,9 @@ fn encode_prompt(
|
|||
println!("Building the clip transformer.");
|
||||
let text_model =
|
||||
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
|
||||
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
|
||||
let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len)?;
|
||||
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
|
||||
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
|
||||
let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
|
||||
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
|
||||
Ok(text_embeddings)
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
use candle::{Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DDPMWSchedulerConfig {
|
||||
scaler: f64,
|
||||
s: f64,
|
||||
}
|
||||
|
||||
impl Default for DDPMWSchedulerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
scaler: 1f64,
|
||||
s: 0.008f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DDPMWScheduler {
|
||||
init_alpha_cumprod: f64,
|
||||
init_noise_sigma: f64,
|
||||
timesteps: Vec<f64>,
|
||||
pub config: DDPMWSchedulerConfig,
|
||||
}
|
||||
|
||||
impl DDPMWScheduler {
|
||||
pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> {
|
||||
let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI)
|
||||
.cos()
|
||||
.powi(2);
|
||||
let timesteps = (0..=inference_steps)
|
||||
.map(|i| 1. - i as f64 / inference_steps as f64)
|
||||
.collect::<Vec<_>>();
|
||||
Ok(Self {
|
||||
init_alpha_cumprod,
|
||||
init_noise_sigma: 1.0,
|
||||
timesteps,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
fn alpha_cumprod(&self, t: f64) -> f64 {
|
||||
let scaler = self.config.scaler;
|
||||
let s = self.config.s;
|
||||
let t = if scaler > 1. {
|
||||
1. - (1. - t).powf(scaler)
|
||||
} else if scaler < 1. {
|
||||
t.powf(scaler)
|
||||
} else {
|
||||
t
|
||||
};
|
||||
let alpha_cumprod =
|
||||
((t + s) / (1. + s) * std::f64::consts::PI * 0.5).powi(2) / self.init_alpha_cumprod;
|
||||
alpha_cumprod.clamp(0.0001, 0.9999)
|
||||
}
|
||||
|
||||
fn previous_timestep(&self, ts: f64) -> f64 {
|
||||
let index = self
|
||||
.timesteps
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, v)| (idx, (v - ts).abs()))
|
||||
.min_by(|x, y| x.1.total_cmp(&y.1))
|
||||
.unwrap()
|
||||
.0;
|
||||
self.timesteps[index + 1]
|
||||
}
|
||||
|
||||
/// Ensures interchangeability with schedulers that need to scale the denoising model input
|
||||
/// depending on the current timestep.
|
||||
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
|
||||
sample
|
||||
}
|
||||
|
||||
pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> {
|
||||
let prev_t = self.previous_timestep(ts);
|
||||
|
||||
let alpha_cumprod = self.alpha_cumprod(ts);
|
||||
let alpha_cumprod_prev = self.alpha_cumprod(prev_t);
|
||||
let alpha = alpha_cumprod / alpha_cumprod_prev;
|
||||
|
||||
let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?;
|
||||
let mu = (mu * (1. / alpha).sqrt())?;
|
||||
|
||||
let std_noise = mu.randn_like(0., 1.)?;
|
||||
let std =
|
||||
std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt();
|
||||
if prev_t == 0. {
|
||||
Ok(mu)
|
||||
} else {
|
||||
mu + std
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_noise_sigma(&self) -> f64 {
|
||||
self.init_noise_sigma
|
||||
}
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
pub mod common;
|
||||
pub mod ddpm;
|
||||
pub mod diffnext;
|
||||
pub mod paella_vq;
|
||||
pub mod prior;
|
||||
|
|
Loading…
Reference in New Issue