diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index b1913541..4518b0de 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -16,6 +16,7 @@ candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
 candle-nn = { path = "../candle-nn", version = "0.3.0" }
 candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
 candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
+candle-onnx = { path = "../candle-onnx", version = "0.3.0" }
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
 image = { workspace = true }
diff --git a/candle-examples/examples/squeezenet-onnx/README.md b/candle-examples/examples/squeezenet-onnx/README.md
new file mode 100644
index 00000000..fd705fb6
--- /dev/null
+++ b/candle-examples/examples/squeezenet-onnx/README.md
@@ -0,0 +1,10 @@
+## Using ONNX models in Candle
+
+This example demonstrates how to run ONNX based models in Candle; the model
+used here is a small SqueezeNet variant.
+
+You can run the example with the following command:
+
+```bash
+cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
+```
diff --git a/candle-examples/examples/squeezenet-onnx/main.rs b/candle-examples/examples/squeezenet-onnx/main.rs
new file mode 100644
index 00000000..90a38bf0
--- /dev/null
+++ b/candle-examples/examples/squeezenet-onnx/main.rs
@@ -0,0 +1,57 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle::{IndexOp, D};
+use clap::Parser;
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    image: String,
+
+    #[arg(long)]
+    model: Option<String>,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    let image = candle_examples::imagenet::load_image224(args.image)?;
+
+    println!("loaded image {image:?}");
+
+    let model = match args.model {
+        Some(model) => std::path::PathBuf::from(model),
+        None => hf_hub::api::sync::Api::new()?
+            .model("lmz/candle-onnx".into())
+            .get("squeezenet1.1-7.onnx")?,
+    };
+
+    let model = candle_onnx::read_file(model)?;
+    let graph = model.graph.as_ref().unwrap();
+    let mut inputs = std::collections::HashMap::new();
+    inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
+    let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
+    let logits = outputs.remove(&graph.output[0].name).unwrap();
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+        .i(0)?
+        .to_vec1::<f32>()?;
+
+    // Sort the predictions and take the top 5
+    let mut top: Vec<_> = prs.iter().enumerate().collect();
+    top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
+    let top = top.into_iter().take(5).collect::<Vec<_>>();
+
+    // Print the top predictions
+    for &(i, p) in &top {
+        println!(
+            "{:50}: {:.2}%",
+            candle_examples::imagenet::CLASSES[i],
+            p * 100.0
+        );
+    }
+
+    Ok(())
+}
diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-onnx/examples/onnx_basics.rs
index 2c52e68e..43940596 100644
--- a/candle-onnx/examples/onnx_basics.rs
+++ b/candle-onnx/examples/onnx_basics.rs
@@ -35,33 +35,34 @@ pub fn main() -> Result<()> {
         }
         Command::SimpleEval { file } => {
             let model = candle_onnx::read_file(file)?;
-            let inputs = model
-                .graph
-                .as_ref()
-                .unwrap()
-                .input
-                .iter()
-                .map(|input| {
-                    use candle_onnx::onnx::tensor_proto::DataType;
+            let graph = model.graph.as_ref().unwrap();
+            let constants: std::collections::HashSet<_> =
+                graph.initializer.iter().map(|i| i.name.as_str()).collect();
+            let mut inputs = std::collections::HashMap::new();
+            for input in graph.input.iter() {
+                use candle_onnx::onnx::tensor_proto::DataType;
+                if constants.contains(input.name.as_str()) {
+                    continue;
+                }
 
-                    let type_ = input.r#type.as_ref().expect("no type for input");
-                    let type_ = type_.value.as_ref().expect("no type.value for input");
-                    let value = match type_ {
-                        candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
-                            let dt = match DataType::try_from(tt.elem_type) {
-                                Ok(dt) => match candle_onnx::dtype(dt) {
-                                    Some(dt) => dt,
-                                    None => {
-                                        anyhow::bail!(
-                                            "unsupported 'value' data-type {dt:?} for {}",
-                                            input.name
-                                        )
-                                    }
-                                },
-                                type_ => anyhow::bail!("unsupported input type {type_:?}"),
-                            };
-                            let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
-                            let dims = shape
+                let type_ = input.r#type.as_ref().expect("no type for input");
+                let type_ = type_.value.as_ref().expect("no type.value for input");
+                let value = match type_ {
+                    candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
+                        let dt = match DataType::try_from(tt.elem_type) {
+                            Ok(dt) => match candle_onnx::dtype(dt) {
+                                Some(dt) => dt,
+                                None => {
+                                    anyhow::bail!(
+                                        "unsupported 'value' data-type {dt:?} for {}",
+                                        input.name
+                                    )
+                                }
+                            },
+                            type_ => anyhow::bail!("unsupported input type {type_:?}"),
+                        };
+                        let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
+                        let dims = shape
                     .dim
                     .iter()
                     .map(|dim| match dim.value.as_ref().expect("no dim value") {
@@ -69,16 +70,16 @@ pub fn main() -> Result<()> {
                         candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
                     })
                     .collect::<Result<Vec<usize>>>()?;
-                            Tensor::zeros(dims, dt, &Device::Cpu)?
-                        }
-                        type_ => anyhow::bail!("unsupported input type {type_:?}"),
-                    };
-                    Ok::<_, anyhow::Error>((input.name.clone(), value))
-                })
-                .collect::<Result<_>>()?;
+                        Tensor::zeros(dims, dt, &Device::Cpu)?
+                    }
+                    type_ => anyhow::bail!("unsupported input type {type_:?}"),
+                };
+                println!("input {}: {value:?}", input.name);
+                inputs.insert(input.name.clone(), value);
+            }
             let outputs = candle_onnx::simple_eval(&model, inputs)?;
             for (name, value) in outputs.iter() {
-                println!("{name}: {value:?}")
+                println!("output {name}: {value:?}")
             }
         }
     }
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 4d44bd8e..c1c98101 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -382,7 +382,7 @@ pub fn simple_eval(
                 Some([p]) => *p as usize,
                 Some([p1, p2, p3, p4]) => {
                     if p1 != p2 || p1 != p3 || p1 != p4 {
-                        bail!("pads to be the same {pads:?} {}", node.name)
+                        bail!("pads have to be the same {pads:?} {}", node.name)
                     }
                     *p1 as usize
                 }
@@ -396,7 +396,7 @@ pub fn simple_eval(
                 Some([p1, p2]) => {
                     if p1 != p2 {
                         bail!(
-                            "strides to be the same on both axis {pads:?} {}",
+                            "strides have to be the same on both axis {pads:?} {}",
                             node.name
                         )
                     }
@@ -412,7 +412,7 @@ pub fn simple_eval(
                 Some([p1, p2]) => {
                     if p1 != p2 {
                         bail!(
-                            "dilations to be the same on both axis {pads:?} {}",
+                            "dilations have to be the same on both axis {pads:?} {}",
                             node.name
                         )
                     }
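For reference, the evaluation flow shared by both new call sites (the squeezenet example and the updated `SimpleEval` command) boils down to the sketch below. It is a minimal illustration based only on the API exercised in the diff (`candle_onnx::read_file`, `candle_onnx::simple_eval`); the helper name `run_onnx` and the hard-coded 1x3x224x224 input shape are assumptions for illustration, not part of the change.

```rust
use candle::{DType, Device, Tensor};

// Minimal sketch: load an ONNX file, feed a dummy tensor for the first graph
// input, and print every output returned by simple_eval.
fn run_onnx(path: &str) -> anyhow::Result<()> {
    let model = candle_onnx::read_file(path)?;
    let graph = model.graph.as_ref().unwrap();

    // Assumed NCHW image input; a real caller would load and preprocess an image.
    let dummy = Tensor::zeros((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
    let mut inputs = std::collections::HashMap::new();
    inputs.insert(graph.input[0].name.to_string(), dummy);

    // simple_eval returns a map from output name to tensor.
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    for (name, value) in outputs.iter() {
        println!("{name}: {value:?}");
    }
    Ok(())
}
```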