Support the flux-dev model too. (#2395)
This commit is contained in:
parent
c0a559d427
commit
89eae41efd
|
@ -37,6 +37,15 @@ struct Args {
|
|||
|
||||
#[arg(long)]
|
||||
decode_only: Option<String>,
|
||||
|
||||
#[arg(long, value_enum, default_value = "schnell")]
|
||||
model: Model,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
|
||||
enum Model {
|
||||
Schnell,
|
||||
Dev,
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<()> {
|
||||
|
@ -50,6 +59,7 @@ fn run(args: Args) -> Result<()> {
|
|||
width,
|
||||
tracing,
|
||||
decode_only,
|
||||
model,
|
||||
} = args;
|
||||
let width = width.unwrap_or(1360);
|
||||
let height = height.unwrap_or(768);
|
||||
|
@ -63,9 +73,13 @@ fn run(args: Args) -> Result<()> {
|
|||
};
|
||||
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let bf_repo = api.repo(hf_hub::Repo::model(
|
||||
"black-forest-labs/FLUX.1-schnell".to_string(),
|
||||
));
|
||||
let bf_repo = {
|
||||
let name = match model {
|
||||
Model::Dev => "black-forest-labs/FLUX.1-dev",
|
||||
Model::Schnell => "black-forest-labs/FLUX.1-schnell",
|
||||
};
|
||||
api.repo(hf_hub::Repo::model(name.to_string()))
|
||||
};
|
||||
let device = candle_examples::device(cpu)?;
|
||||
let dtype = device.bf16_default_to_f32();
|
||||
let img = match decode_only {
|
||||
|
@ -132,16 +146,27 @@ fn run(args: Args) -> Result<()> {
|
|||
};
|
||||
println!("CLIP\n{clip_emb}");
|
||||
let img = {
|
||||
let model_file = bf_repo.get("flux1-schnell.sft")?;
|
||||
let model_file = match model {
|
||||
Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
|
||||
Model::Dev => bf_repo.get("flux1-dev.sft")?,
|
||||
};
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let cfg = flux::model::Config::schnell();
|
||||
let model = flux::model::Flux::new(&cfg, vb)?;
|
||||
|
||||
let cfg = match model {
|
||||
Model::Dev => flux::model::Config::dev(),
|
||||
Model::Schnell => flux::model::Config::schnell(),
|
||||
};
|
||||
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
|
||||
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
|
||||
let timesteps = match model {
|
||||
Model::Dev => {
|
||||
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
|
||||
}
|
||||
Model::Schnell => flux::sampling::get_schedule(4, None),
|
||||
};
|
||||
let model = flux::model::Flux::new(&cfg, vb)?;
|
||||
|
||||
println!("{state:?}");
|
||||
let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell
|
||||
println!("{timesteps:?}");
|
||||
flux::sampling::denoise(
|
||||
&model,
|
||||
|
@ -166,7 +191,10 @@ fn run(args: Args) -> Result<()> {
|
|||
let img = {
|
||||
let model_file = bf_repo.get("ae.sft")?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let cfg = flux::autoencoder::Config::schnell();
|
||||
let cfg = match model {
|
||||
Model::Dev => flux::autoencoder::Config::dev(),
|
||||
Model::Schnell => flux::autoencoder::Config::schnell(),
|
||||
};
|
||||
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
|
||||
model.decode(&img)?
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue