From f365a075e551dd50f7def29ecc2d8cba100c4625 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 5 Nov 2023 16:57:26 +0100 Subject: [PATCH] Add more models to the onnx example. (#1273) * Add more models to the onnx example. * Input validation. * Input validation. * Bugfix. * Implement clip. * BatchNorm support. * Get the efficientnet onnx to work. --- .../{squeezenet-onnx => onnx}/README.md | 0 .../{squeezenet-onnx => onnx}/main.rs | 37 +++- candle-onnx/src/eval.rs | 167 ++++++++++++++++-- 3 files changed, 181 insertions(+), 23 deletions(-) rename candle-examples/examples/{squeezenet-onnx => onnx}/README.md (100%) rename candle-examples/examples/{squeezenet-onnx => onnx}/main.rs (55%) diff --git a/candle-examples/examples/squeezenet-onnx/README.md b/candle-examples/examples/onnx/README.md similarity index 100% rename from candle-examples/examples/squeezenet-onnx/README.md rename to candle-examples/examples/onnx/README.md diff --git a/candle-examples/examples/squeezenet-onnx/main.rs b/candle-examples/examples/onnx/main.rs similarity index 55% rename from candle-examples/examples/squeezenet-onnx/main.rs rename to candle-examples/examples/onnx/main.rs index 90a38bf0..d3b0f8f8 100644 --- a/candle-examples/examples/squeezenet-onnx/main.rs +++ b/candle-examples/examples/onnx/main.rs @@ -5,7 +5,13 @@ extern crate intel_mkl_src; extern crate accelerate_src; use candle::{IndexOp, D}; -use clap::Parser; +use clap::{Parser, ValueEnum}; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + SqueezeNet, + EfficientNet, +} #[derive(Parser)] struct Args { @@ -14,19 +20,32 @@ struct Args { #[arg(long)] model: Option, + + /// The model to be used. + #[arg(value_enum, long, default_value_t = Which::SqueezeNet)] + which: Which, } pub fn main() -> anyhow::Result<()> { let args = Args::parse(); let image = candle_examples::imagenet::load_image224(args.image)?; + let image = match args.which { + Which::SqueezeNet => image, + Which::EfficientNet => image.permute((1, 2, 0))?, + }; println!("loaded image {image:?}"); let model = match args.model { Some(model) => std::path::PathBuf::from(model), - None => hf_hub::api::sync::Api::new()? - .model("lmz/candle-onnx".into()) - .get("squeezenet1.1-7.onnx")?, + None => match args.which { + Which::SqueezeNet => hf_hub::api::sync::Api::new()? + .model("lmz/candle-onnx".into()) + .get("squeezenet1.1-7.onnx")?, + Which::EfficientNet => hf_hub::api::sync::Api::new()? + .model("onnx/EfficientNet-Lite4".into()) + .get("efficientnet-lite4-11.onnx")?, + }, }; let model = candle_onnx::read_file(model)?; @@ -34,10 +53,12 @@ pub fn main() -> anyhow::Result<()> { let mut inputs = std::collections::HashMap::new(); inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?); let mut outputs = candle_onnx::simple_eval(&model, inputs)?; - let logits = outputs.remove(&graph.output[0].name).unwrap(); - let prs = candle_nn::ops::softmax(&logits, D::Minus1)? - .i(0)? - .to_vec1::()?; + let output = outputs.remove(&graph.output[0].name).unwrap(); + let prs = match args.which { + Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?, + Which::EfficientNet => output, + }; + let prs = prs.i(0)?.to_vec1::()?; // Sort the predictions and take the top 5 let mut top: Vec<_> = prs.iter().enumerate().collect(); diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index c1c98101..54fae6c1 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -30,6 +30,13 @@ impl Attr for i64 { } } +impl Attr for f32 { + const TYPE: AttributeType = AttributeType::Float; + fn get(attr: &onnx::AttributeProto) -> Result<&Self> { + Ok(&attr.f) + } +} + impl Attr for [i64] { const TYPE: AttributeType = AttributeType::Ints; fn get(attr: &onnx::AttributeProto) -> Result<&Self> { @@ -134,12 +141,66 @@ pub fn simple_eval( None => bail!("no graph defined in proto"), Some(graph) => graph, }; - // TODO: validate the inputs. let mut values = inputs; for t in graph.initializer.iter() { let tensor = get_tensor(t, t.name.as_str())?; values.insert(t.name.to_string(), tensor); } + for input in graph.input.iter() { + let input_type = match &input.r#type { + Some(input_type) => input_type, + None => continue, + }; + let input_type = match &input_type.value { + Some(input_type) => input_type, + None => continue, + }; + let tensor_type = match input_type { + onnx::type_proto::Value::TensorType(tt) => tt, + _ => continue, + }; + + let tensor = match values.get(&input.name) { + None => bail!("missing input {}", input.name), + Some(tensor) => tensor, + }; + let dt = match DataType::try_from(tensor_type.elem_type) { + Ok(dt) => match dtype(dt) { + Some(dt) => dt, + None => { + bail!("unsupported 'value' data-type {dt:?} for {}", input.name) + } + }, + type_ => bail!("unsupported input type {type_:?}"), + }; + let shape = match &tensor_type.shape { + None => continue, + Some(shape) => shape + .dim + .iter() + .map(|dim| match dim.value.as_ref().expect("no dim value") { + onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize), + onnx::tensor_shape_proto::dimension::Value::DimParam(_) => { + bail!("DimParam is unsupported for input {}", input.name) + } + }) + .collect::>>()?, + }; + if dt != tensor.dtype() { + bail!( + "unexpected dtype for {}, got {:?}, expected {dt:?}", + input.name, + tensor.dtype() + ) + } + if shape.as_slice() != tensor.dims() { + bail!( + "unexpected shape for {}, got {:?}, expected {shape:?}", + input.name, + tensor.dims() + ) + } + } // The nodes are topologically sorted so we can just process them in order. for node in graph.node.iter() { let get = |input_name: &str| match values.get(input_name) { @@ -328,6 +389,79 @@ pub fn simple_eval( }; values.insert(node.output[0].clone(), ys); } + "BatchNormalization" => { + let training_mode = get_attr_opt::(node, "training_mode")?; + if training_mode.copied().unwrap_or(0) != 0 { + bail!("training mode is not supported for BatchNorm") + } + let eps = get_attr_opt::(node, "epsilon")? + .copied() + .unwrap_or(1e-5); + let xs = get(&node.input[0])?; + let weight = get(&node.input[1])?; + let bias = get(&node.input[2])?; + let running_mean = get(&node.input[3])?; + let running_var = get(&node.input[4])?; + let target_shape: Vec = xs + .dims() + .iter() + .enumerate() + .map(|(idx, v)| if idx == 1 { *v } else { 1 }) + .collect(); + let target_shape = target_shape.as_slice(); + let xs = xs + .broadcast_sub(&running_mean.reshape(target_shape)?)? + .broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?; + let weight = weight.reshape(target_shape)?; + let bias = bias.reshape(target_shape)?; + let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?; + values.insert(node.output[0].clone(), xs); + } + "Squeeze" => { + let xs = get(&node.input[0])?; + let mut axes = if node.input.len() <= 1 { + // contract all the dimensions with size 1 except the batch dim. + xs.dims() + .iter() + .enumerate() + .flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None }) + .collect() + } else { + get(&node.input[1])? + .to_vec1::()? + .iter() + .map(|&i| { + if i < 0 { + (xs.rank() as i64 + i) as usize + } else { + i as usize + } + }) + .collect::>() + }; + axes.sort(); + let mut xs = xs.clone(); + for &axis in axes.iter().rev() { + xs = xs.squeeze(axis)? + } + values.insert(node.output[0].clone(), xs); + } + "Clip" => { + let xs = get(&node.input[0])?; + let xs = if node.input.len() >= 2 { + let mins = get(&node.input[1])?; + xs.broadcast_maximum(mins)? + } else { + xs.clone() + }; + let xs = if node.input.len() >= 3 { + let maxs = get(&node.input[2])?; + xs.broadcast_minimum(maxs)? + } else { + xs.clone() + }; + values.insert(node.output[0].clone(), xs); + } "Conv" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv let dilations = get_attr_opt::<[i64]>(node, "dilations")?; @@ -344,17 +478,15 @@ pub fn simple_eval( let ws = get(&node.input[1])?; let ys = match ws.rank() { 3 => { - let pads = match pads { - None => 0, - Some([p]) => *p as usize, + let (pads, xs) = match pads { + None => (0, xs.clone()), + Some([p]) => (*p as usize, xs.clone()), Some([p1, p2]) => { if p1 != p2 { - bail!( - "left and right pad ({p1} <> {p2}) have to be the same {}", - node.name - ) + (0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?) + } else { + (*p1 as usize, xs.clone()) } - *p1 as usize } Some(pads) => { bail!("more pads than expected in conv1d {pads:?} {}", node.name) @@ -377,14 +509,19 @@ pub fn simple_eval( xs.conv1d(ws, pads, strides, dilations, groups as usize)? } 4 => { - let pads = match pads { - None => 0, - Some([p]) => *p as usize, - Some([p1, p2, p3, p4]) => { + let (pads, xs) = match pads { + None => (0, xs.clone()), + Some([p]) => (*p as usize, xs.clone()), + Some(&[p1, p2, p3, p4]) => { + let p1 = p1 as usize; + let p2 = p2 as usize; + let p3 = p3 as usize; + let p4 = p4 as usize; if p1 != p2 || p1 != p3 || p1 != p4 { - bail!("pads have to be the same {pads:?} {}", node.name) + (0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?) + } else { + (p1, xs.clone()) } - *p1 as usize } Some(pads) => { bail!("more pads than expected in conv2d {pads:?} {}", node.name)