Add more models to the onnx example. (#1273)
* Add more models to the onnx example. * Input validation. * Input validation. * Bugfix. * Implement clip. * BatchNorm support. * Get the efficientnet onnx to work.
This commit is contained in:
parent
60fdab4e17
commit
f365a075e5
|
@ -5,7 +5,13 @@ extern crate intel_mkl_src;
|
|||
extern crate accelerate_src;
|
||||
|
||||
use candle::{IndexOp, D};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
SqueezeNet,
|
||||
EfficientNet,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
|
@ -14,19 +20,32 @@ struct Args {
|
|||
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// The model to be used.
|
||||
#[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
let image = match args.which {
|
||||
Which::SqueezeNet => image,
|
||||
Which::EfficientNet => image.permute((1, 2, 0))?,
|
||||
};
|
||||
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => hf_hub::api::sync::Api::new()?
|
||||
None => match args.which {
|
||||
Which::SqueezeNet => hf_hub::api::sync::Api::new()?
|
||||
.model("lmz/candle-onnx".into())
|
||||
.get("squeezenet1.1-7.onnx")?,
|
||||
Which::EfficientNet => hf_hub::api::sync::Api::new()?
|
||||
.model("onnx/EfficientNet-Lite4".into())
|
||||
.get("efficientnet-lite4-11.onnx")?,
|
||||
},
|
||||
};
|
||||
|
||||
let model = candle_onnx::read_file(model)?;
|
||||
|
@ -34,10 +53,12 @@ pub fn main() -> anyhow::Result<()> {
|
|||
let mut inputs = std::collections::HashMap::new();
|
||||
inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
|
||||
let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
|
||||
let logits = outputs.remove(&graph.output[0].name).unwrap();
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let output = outputs.remove(&graph.output[0].name).unwrap();
|
||||
let prs = match args.which {
|
||||
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
|
||||
Which::EfficientNet => output,
|
||||
};
|
||||
let prs = prs.i(0)?.to_vec1::<f32>()?;
|
||||
|
||||
// Sort the predictions and take the top 5
|
||||
let mut top: Vec<_> = prs.iter().enumerate().collect();
|
|
@ -30,6 +30,13 @@ impl Attr for i64 {
|
|||
}
|
||||
}
|
||||
|
||||
impl Attr for f32 {
|
||||
const TYPE: AttributeType = AttributeType::Float;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
Ok(&attr.f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Attr for [i64] {
|
||||
const TYPE: AttributeType = AttributeType::Ints;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
|
@ -134,12 +141,66 @@ pub fn simple_eval(
|
|||
None => bail!("no graph defined in proto"),
|
||||
Some(graph) => graph,
|
||||
};
|
||||
// TODO: validate the inputs.
|
||||
let mut values = inputs;
|
||||
for t in graph.initializer.iter() {
|
||||
let tensor = get_tensor(t, t.name.as_str())?;
|
||||
values.insert(t.name.to_string(), tensor);
|
||||
}
|
||||
for input in graph.input.iter() {
|
||||
let input_type = match &input.r#type {
|
||||
Some(input_type) => input_type,
|
||||
None => continue,
|
||||
};
|
||||
let input_type = match &input_type.value {
|
||||
Some(input_type) => input_type,
|
||||
None => continue,
|
||||
};
|
||||
let tensor_type = match input_type {
|
||||
onnx::type_proto::Value::TensorType(tt) => tt,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let tensor = match values.get(&input.name) {
|
||||
None => bail!("missing input {}", input.name),
|
||||
Some(tensor) => tensor,
|
||||
};
|
||||
let dt = match DataType::try_from(tensor_type.elem_type) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
|
||||
}
|
||||
},
|
||||
type_ => bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
let shape = match &tensor_type.shape {
|
||||
None => continue,
|
||||
Some(shape) => shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dim| match dim.value.as_ref().expect("no dim value") {
|
||||
onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
|
||||
onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
|
||||
bail!("DimParam is unsupported for input {}", input.name)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<usize>>>()?,
|
||||
};
|
||||
if dt != tensor.dtype() {
|
||||
bail!(
|
||||
"unexpected dtype for {}, got {:?}, expected {dt:?}",
|
||||
input.name,
|
||||
tensor.dtype()
|
||||
)
|
||||
}
|
||||
if shape.as_slice() != tensor.dims() {
|
||||
bail!(
|
||||
"unexpected shape for {}, got {:?}, expected {shape:?}",
|
||||
input.name,
|
||||
tensor.dims()
|
||||
)
|
||||
}
|
||||
}
|
||||
// The nodes are topologically sorted so we can just process them in order.
|
||||
for node in graph.node.iter() {
|
||||
let get = |input_name: &str| match values.get(input_name) {
|
||||
|
@ -328,6 +389,79 @@ pub fn simple_eval(
|
|||
};
|
||||
values.insert(node.output[0].clone(), ys);
|
||||
}
|
||||
"BatchNormalization" => {
|
||||
let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
|
||||
if training_mode.copied().unwrap_or(0) != 0 {
|
||||
bail!("training mode is not supported for BatchNorm")
|
||||
}
|
||||
let eps = get_attr_opt::<f32>(node, "epsilon")?
|
||||
.copied()
|
||||
.unwrap_or(1e-5);
|
||||
let xs = get(&node.input[0])?;
|
||||
let weight = get(&node.input[1])?;
|
||||
let bias = get(&node.input[2])?;
|
||||
let running_mean = get(&node.input[3])?;
|
||||
let running_var = get(&node.input[4])?;
|
||||
let target_shape: Vec<usize> = xs
|
||||
.dims()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
|
||||
.collect();
|
||||
let target_shape = target_shape.as_slice();
|
||||
let xs = xs
|
||||
.broadcast_sub(&running_mean.reshape(target_shape)?)?
|
||||
.broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
|
||||
let weight = weight.reshape(target_shape)?;
|
||||
let bias = bias.reshape(target_shape)?;
|
||||
let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Squeeze" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let mut axes = if node.input.len() <= 1 {
|
||||
// contract all the dimensions with size 1 except the batch dim.
|
||||
xs.dims()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
|
||||
.collect()
|
||||
} else {
|
||||
get(&node.input[1])?
|
||||
.to_vec1::<i64>()?
|
||||
.iter()
|
||||
.map(|&i| {
|
||||
if i < 0 {
|
||||
(xs.rank() as i64 + i) as usize
|
||||
} else {
|
||||
i as usize
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
axes.sort();
|
||||
let mut xs = xs.clone();
|
||||
for &axis in axes.iter().rev() {
|
||||
xs = xs.squeeze(axis)?
|
||||
}
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Clip" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let xs = if node.input.len() >= 2 {
|
||||
let mins = get(&node.input[1])?;
|
||||
xs.broadcast_maximum(mins)?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
let xs = if node.input.len() >= 3 {
|
||||
let maxs = get(&node.input[2])?;
|
||||
xs.broadcast_minimum(maxs)?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Conv" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
|
@ -344,17 +478,15 @@ pub fn simple_eval(
|
|||
let ws = get(&node.input[1])?;
|
||||
let ys = match ws.rank() {
|
||||
3 => {
|
||||
let pads = match pads {
|
||||
None => 0,
|
||||
Some([p]) => *p as usize,
|
||||
let (pads, xs) = match pads {
|
||||
None => (0, xs.clone()),
|
||||
Some([p]) => (*p as usize, xs.clone()),
|
||||
Some([p1, p2]) => {
|
||||
if p1 != p2 {
|
||||
bail!(
|
||||
"left and right pad ({p1} <> {p2}) have to be the same {}",
|
||||
node.name
|
||||
)
|
||||
(0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
|
||||
} else {
|
||||
(*p1 as usize, xs.clone())
|
||||
}
|
||||
*p1 as usize
|
||||
}
|
||||
Some(pads) => {
|
||||
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
|
||||
|
@ -377,14 +509,19 @@ pub fn simple_eval(
|
|||
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
|
||||
}
|
||||
4 => {
|
||||
let pads = match pads {
|
||||
None => 0,
|
||||
Some([p]) => *p as usize,
|
||||
Some([p1, p2, p3, p4]) => {
|
||||
let (pads, xs) = match pads {
|
||||
None => (0, xs.clone()),
|
||||
Some([p]) => (*p as usize, xs.clone()),
|
||||
Some(&[p1, p2, p3, p4]) => {
|
||||
let p1 = p1 as usize;
|
||||
let p2 = p2 as usize;
|
||||
let p3 = p3 as usize;
|
||||
let p4 = p4 as usize;
|
||||
if p1 != p2 || p1 != p3 || p1 != p4 {
|
||||
bail!("pads have to be the same {pads:?} {}", node.name)
|
||||
(0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
|
||||
} else {
|
||||
(p1, xs.clone())
|
||||
}
|
||||
*p1 as usize
|
||||
}
|
||||
Some(pads) => {
|
||||
bail!("more pads than expected in conv2d {pads:?} {}", node.name)
|
||||
|
|
Loading…
Reference in New Issue