Add more models to the onnx example. (#1273)

* Add more models to the onnx example.

* Input validation.

* Input validation.

* Bugfix.

* Implement clip.

* BatchNorm support.

* Get the efficientnet onnx to work.
This commit is contained in:
Laurent Mazare 2023-11-05 16:57:26 +01:00 committed by GitHub
parent 60fdab4e17
commit f365a075e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 181 additions and 23 deletions

View File

@ -5,7 +5,13 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle::{IndexOp, D};
use clap::Parser;
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SqueezeNet,
EfficientNet,
}
#[derive(Parser)]
struct Args {
@ -14,19 +20,32 @@ struct Args {
#[arg(long)]
model: Option<String>,
/// The model to be used.
#[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let image = candle_examples::imagenet::load_image224(args.image)?;
let image = match args.which {
Which::SqueezeNet => image,
Which::EfficientNet => image.permute((1, 2, 0))?,
};
println!("loaded image {image:?}");
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => hf_hub::api::sync::Api::new()?
.model("lmz/candle-onnx".into())
.get("squeezenet1.1-7.onnx")?,
None => match args.which {
Which::SqueezeNet => hf_hub::api::sync::Api::new()?
.model("lmz/candle-onnx".into())
.get("squeezenet1.1-7.onnx")?,
Which::EfficientNet => hf_hub::api::sync::Api::new()?
.model("onnx/EfficientNet-Lite4".into())
.get("efficientnet-lite4-11.onnx")?,
},
};
let model = candle_onnx::read_file(model)?;
@ -34,10 +53,12 @@ pub fn main() -> anyhow::Result<()> {
let mut inputs = std::collections::HashMap::new();
inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
let logits = outputs.remove(&graph.output[0].name).unwrap();
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let output = outputs.remove(&graph.output[0].name).unwrap();
let prs = match args.which {
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
Which::EfficientNet => output,
};
let prs = prs.i(0)?.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();

View File

@ -30,6 +30,13 @@ impl Attr for i64 {
}
}
impl Attr for f32 {
const TYPE: AttributeType = AttributeType::Float;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(&attr.f)
}
}
impl Attr for [i64] {
const TYPE: AttributeType = AttributeType::Ints;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
@ -134,12 +141,66 @@ pub fn simple_eval(
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
// TODO: validate the inputs.
let mut values = inputs;
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
}
for input in graph.input.iter() {
let input_type = match &input.r#type {
Some(input_type) => input_type,
None => continue,
};
let input_type = match &input_type.value {
Some(input_type) => input_type,
None => continue,
};
let tensor_type = match input_type {
onnx::type_proto::Value::TensorType(tt) => tt,
_ => continue,
};
let tensor = match values.get(&input.name) {
None => bail!("missing input {}", input.name),
Some(tensor) => tensor,
};
let dt = match DataType::try_from(tensor_type.elem_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {
bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
}
},
type_ => bail!("unsupported input type {type_:?}"),
};
let shape = match &tensor_type.shape {
None => continue,
Some(shape) => shape
.dim
.iter()
.map(|dim| match dim.value.as_ref().expect("no dim value") {
onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
bail!("DimParam is unsupported for input {}", input.name)
}
})
.collect::<Result<Vec<usize>>>()?,
};
if dt != tensor.dtype() {
bail!(
"unexpected dtype for {}, got {:?}, expected {dt:?}",
input.name,
tensor.dtype()
)
}
if shape.as_slice() != tensor.dims() {
bail!(
"unexpected shape for {}, got {:?}, expected {shape:?}",
input.name,
tensor.dims()
)
}
}
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
@ -328,6 +389,79 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), ys);
}
"BatchNormalization" => {
let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
if training_mode.copied().unwrap_or(0) != 0 {
bail!("training mode is not supported for BatchNorm")
}
let eps = get_attr_opt::<f32>(node, "epsilon")?
.copied()
.unwrap_or(1e-5);
let xs = get(&node.input[0])?;
let weight = get(&node.input[1])?;
let bias = get(&node.input[2])?;
let running_mean = get(&node.input[3])?;
let running_var = get(&node.input[4])?;
let target_shape: Vec<usize> = xs
.dims()
.iter()
.enumerate()
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
.collect();
let target_shape = target_shape.as_slice();
let xs = xs
.broadcast_sub(&running_mean.reshape(target_shape)?)?
.broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
let weight = weight.reshape(target_shape)?;
let bias = bias.reshape(target_shape)?;
let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
values.insert(node.output[0].clone(), xs);
}
"Squeeze" => {
let xs = get(&node.input[0])?;
let mut axes = if node.input.len() <= 1 {
// contract all the dimensions with size 1 except the batch dim.
xs.dims()
.iter()
.enumerate()
.flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
.collect()
} else {
get(&node.input[1])?
.to_vec1::<i64>()?
.iter()
.map(|&i| {
if i < 0 {
(xs.rank() as i64 + i) as usize
} else {
i as usize
}
})
.collect::<Vec<_>>()
};
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.squeeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if node.input.len() >= 2 {
let mins = get(&node.input[1])?;
xs.broadcast_maximum(mins)?
} else {
xs.clone()
};
let xs = if node.input.len() >= 3 {
let maxs = get(&node.input[2])?;
xs.broadcast_minimum(maxs)?
} else {
xs.clone()
};
values.insert(node.output[0].clone(), xs);
}
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
@ -344,17 +478,15 @@ pub fn simple_eval(
let ws = get(&node.input[1])?;
let ys = match ws.rank() {
3 => {
let pads = match pads {
None => 0,
Some([p]) => *p as usize,
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"left and right pad ({p1} <> {p2}) have to be the same {}",
node.name
)
(0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
} else {
(*p1 as usize, xs.clone())
}
*p1 as usize
}
Some(pads) => {
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
@ -377,14 +509,19 @@ pub fn simple_eval(
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
}
4 => {
let pads = match pads {
None => 0,
Some([p]) => *p as usize,
Some([p1, p2, p3, p4]) => {
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some(&[p1, p2, p3, p4]) => {
let p1 = p1 as usize;
let p2 = p2 as usize;
let p3 = p3 as usize;
let p4 = p4 as usize;
if p1 != p2 || p1 != p3 || p1 != p4 {
bail!("pads have to be the same {pads:?} {}", node.name)
(0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
} else {
(p1, xs.clone())
}
*p1 as usize
}
Some(pads) => {
bail!("more pads than expected in conv2d {pads:?} {}", node.name)