Support the flux-dev model too. (#2395)

This commit is contained in:
Laurent Mazare 2024-08-04 11:16:24 +01:00 committed by GitHub
parent c0a559d427
commit 89eae41efd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 37 additions and 9 deletions

View File

@ -37,6 +37,15 @@ struct Args {
#[arg(long)]
decode_only: Option<String>,
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum Model {
Schnell,
Dev,
}
fn run(args: Args) -> Result<()> {
@ -50,6 +59,7 @@ fn run(args: Args) -> Result<()> {
width,
tracing,
decode_only,
model,
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@ -63,9 +73,13 @@ fn run(args: Args) -> Result<()> {
};
let api = hf_hub::api::sync::Api::new()?;
let bf_repo = api.repo(hf_hub::Repo::model(
"black-forest-labs/FLUX.1-schnell".to_string(),
));
let bf_repo = {
let name = match model {
Model::Dev => "black-forest-labs/FLUX.1-dev",
Model::Schnell => "black-forest-labs/FLUX.1-schnell",
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let device = candle_examples::device(cpu)?;
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
@ -132,16 +146,27 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = bf_repo.get("flux1-schnell.sft")?;
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
Model::Dev => bf_repo.get("flux1-dev.sft")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::model::Config::schnell();
let model = flux::model::Flux::new(&cfg, vb)?;
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;
println!("{state:?}");
let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
@ -166,7 +191,10 @@ fn run(args: Args) -> Result<()> {
let img = {
let model_file = bf_repo.get("ae.sft")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::autoencoder::Config::schnell();
let cfg = match model {
Model::Dev => flux::autoencoder::Config::dev(),
Model::Schnell => flux::autoencoder::Config::schnell(),
};
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
model.decode(&img)?
};