Add more models to the onnx example. (#1273)
* Add more models to the onnx example.
* Input validation.
* Bugfix.
* Implement clip.
* BatchNorm support.
* Get the efficientnet onnx to work.
This commit is contained in:
parent 60fdab4e17
commit f365a075e5
@@ -5,7 +5,13 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle::{IndexOp, D};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+    SqueezeNet,
+    EfficientNet,
+}
+
 #[derive(Parser)]
 struct Args {
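The new `Which` enum hooks into clap's derive support. A minimal sketch (assuming clap 4 with the `derive` feature, which `#[derive(Parser)]` implies) of how `ValueEnum` turns the variant names into kebab-case CLI values, so the second model is selected with `--which efficient-net`:

use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, PartialEq, ValueEnum)]
enum Which {
    SqueezeNet,
    EfficientNet,
}

#[derive(Parser)]
struct Args {
    /// The model to be used.
    #[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
    which: Which,
}

fn main() {
    // ValueEnum derives kebab-case value names from the variants, so this parses:
    let args = Args::try_parse_from(["onnx", "--which", "efficient-net"]).unwrap();
    assert_eq!(args.which, Which::EfficientNet);
}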
@@ -14,19 +20,32 @@ struct Args {
     #[arg(long)]
     model: Option<String>,
 
+    /// The model to be used.
+    #[arg(value_enum, long, default_value_t = Which::SqueezeNet)]
+    which: Which,
 }
 
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     let image = candle_examples::imagenet::load_image224(args.image)?;
+    let image = match args.which {
+        Which::SqueezeNet => image,
+        Which::EfficientNet => image.permute((1, 2, 0))?,
+    };
 
     println!("loaded image {image:?}");
 
     let model = match args.model {
         Some(model) => std::path::PathBuf::from(model),
-        None => hf_hub::api::sync::Api::new()?
-            .model("lmz/candle-onnx".into())
-            .get("squeezenet1.1-7.onnx")?,
+        None => match args.which {
+            Which::SqueezeNet => hf_hub::api::sync::Api::new()?
+                .model("lmz/candle-onnx".into())
+                .get("squeezenet1.1-7.onnx")?,
+            Which::EfficientNet => hf_hub::api::sync::Api::new()?
+                .model("onnx/EfficientNet-Lite4".into())
+                .get("efficientnet-lite4-11.onnx")?,
+        },
     };
 
     let model = candle_onnx::read_file(model)?;
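The `permute` is a pure layout change: `load_image224` produces a CHW tensor suited to the squeezenet graph, while the EfficientNet-Lite4 onnx file takes its image in HWC order. A small shape-level sketch, assuming the `candle` crate alias used throughout these examples:

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // CHW, as produced by candle_examples::imagenet::load_image224.
    let chw = Tensor::zeros((3, 224, 224), DType::F32, &Device::Cpu)?;
    // HWC, as expected by the efficientnet-lite4-11.onnx graph.
    let hwc = chw.permute((1, 2, 0))?;
    assert_eq!(hwc.dims(), &[224, 224, 3]);
    Ok(())
}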
@@ -34,10 +53,12 @@ pub fn main() -> anyhow::Result<()> {
     let mut inputs = std::collections::HashMap::new();
     inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
     let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
-    let logits = outputs.remove(&graph.output[0].name).unwrap();
-    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
-        .i(0)?
-        .to_vec1::<f32>()?;
+    let output = outputs.remove(&graph.output[0].name).unwrap();
+    let prs = match args.which {
+        Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
+        Which::EfficientNet => output,
+    };
+    let prs = prs.i(0)?.to_vec1::<f32>()?;
 
     // Sort the predictions and take the top 5
     let mut top: Vec<_> = prs.iter().enumerate().collect();
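Only the squeezenet path applies a softmax; the EfficientNet-Lite4 output is used as probabilities directly, presumably because that graph already ends in a softmax. A sketch of the shared post-processing tail, assuming a hypothetical (1, 1000) score tensor and the `candle`/`candle_nn` crates:

use candle::{Device, IndexOp, Tensor, D};

fn main() -> candle::Result<()> {
    let logits = Tensor::randn(0f32, 1f32, (1, 1000), &Device::Cpu)?;
    // SqueezeNet path: logits -> probabilities.
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
    // Shared tail: drop the batch dim and pull the scores out.
    let prs = prs.i(0)?.to_vec1::<f32>()?;
    let total: f32 = prs.iter().sum();
    assert!((total - 1.0).abs() < 1e-4);
    Ok(())
}

The remaining hunks are in the candle-onnx evaluator: the `Attr` attribute helpers and `simple_eval`.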
@@ -30,6 +30,13 @@ impl Attr for i64 {
     }
 }
 
+impl Attr for f32 {
+    const TYPE: AttributeType = AttributeType::Float;
+    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
+        Ok(&attr.f)
+    }
+}
+
 impl Attr for [i64] {
     const TYPE: AttributeType = AttributeType::Ints;
     fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
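This impl extends the typed-attribute layer so `get_attr_opt::<f32>` can read float attributes such as BatchNormalization's `epsilon`. Below is a self-contained toy of the pattern; every type and helper in it is illustrative, not candle-onnx's real API:

#[derive(Debug, PartialEq)]
enum AttributeType {
    Float,
    Ints,
}

struct AttributeProto {
    r#type: AttributeType,
    f: f32,
    ints: Vec<i64>,
}

trait Attr {
    const TYPE: AttributeType;
    fn get(attr: &AttributeProto) -> &Self;
}

impl Attr for f32 {
    const TYPE: AttributeType = AttributeType::Float;
    fn get(attr: &AttributeProto) -> &Self {
        &attr.f
    }
}

impl Attr for [i64] {
    const TYPE: AttributeType = AttributeType::Ints;
    fn get(attr: &AttributeProto) -> &Self {
        &attr.ints
    }
}

// Returns the attribute only when its wire type matches T's expected type.
fn get_attr_opt<T: Attr + ?Sized>(attr: &AttributeProto) -> Option<&T> {
    (attr.r#type == T::TYPE).then(|| T::get(attr))
}

fn main() {
    let eps = AttributeProto { r#type: AttributeType::Float, f: 1e-5, ints: vec![] };
    assert_eq!(get_attr_opt::<f32>(&eps).copied(), Some(1e-5));
}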
@@ -134,12 +141,66 @@ pub fn simple_eval(
         None => bail!("no graph defined in proto"),
         Some(graph) => graph,
     };
-    // TODO: validate the inputs.
     let mut values = inputs;
     for t in graph.initializer.iter() {
         let tensor = get_tensor(t, t.name.as_str())?;
         values.insert(t.name.to_string(), tensor);
     }
+    for input in graph.input.iter() {
+        let input_type = match &input.r#type {
+            Some(input_type) => input_type,
+            None => continue,
+        };
+        let input_type = match &input_type.value {
+            Some(input_type) => input_type,
+            None => continue,
+        };
+        let tensor_type = match input_type {
+            onnx::type_proto::Value::TensorType(tt) => tt,
+            _ => continue,
+        };
+
+        let tensor = match values.get(&input.name) {
+            None => bail!("missing input {}", input.name),
+            Some(tensor) => tensor,
+        };
+        let dt = match DataType::try_from(tensor_type.elem_type) {
+            Ok(dt) => match dtype(dt) {
+                Some(dt) => dt,
+                None => {
+                    bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
+                }
+            },
+            type_ => bail!("unsupported input type {type_:?}"),
+        };
+        let shape = match &tensor_type.shape {
+            None => continue,
+            Some(shape) => shape
+                .dim
+                .iter()
+                .map(|dim| match dim.value.as_ref().expect("no dim value") {
+                    onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
+                    onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
+                        bail!("DimParam is unsupported for input {}", input.name)
+                    }
+                })
+                .collect::<Result<Vec<usize>>>()?,
+        };
+        if dt != tensor.dtype() {
+            bail!(
+                "unexpected dtype for {}, got {:?}, expected {dt:?}",
+                input.name,
+                tensor.dtype()
+            )
+        }
+        if shape.as_slice() != tensor.dims() {
+            bail!(
+                "unexpected shape for {}, got {:?}, expected {shape:?}",
+                input.name,
+                tensor.dims()
+            )
+        }
+    }
     // The nodes are topologically sorted so we can just process them in order.
     for node in graph.node.iter() {
         let get = |input_name: &str| match values.get(input_name) {
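The new pass replaces the old `// TODO: validate the inputs.` comment: it walks `graph.input` and, for every tensor the caller supplied, compares dtype and shape against what the graph declares, bailing with the input's name instead of failing deep inside some node (dynamic `DimParam` dimensions are rejected rather than guessed). A hedged sketch of the call site this protects, mirroring the example's main; the `run` helper and its signature are illustrative:

use candle::Tensor;

fn run(model_path: &std::path::Path, image: Tensor) -> anyhow::Result<Tensor> {
    let model = candle_onnx::read_file(model_path)?;
    let graph = model.graph.as_ref().expect("no graph in proto");
    let mut inputs = std::collections::HashMap::new();
    // This tensor must now match the dtype/shape the graph declares for its
    // first input (e.g. f32 with shape (1, 3, 224, 224)), or simple_eval
    // bails with an error naming the offending input.
    inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?);
    let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
    Ok(outputs.remove(&graph.output[0].name).expect("missing output"))
}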
@@ -328,6 +389,79 @@ pub fn simple_eval(
                 };
                 values.insert(node.output[0].clone(), ys);
             }
+            "BatchNormalization" => {
+                let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
+                if training_mode.copied().unwrap_or(0) != 0 {
+                    bail!("training mode is not supported for BatchNorm")
+                }
+                let eps = get_attr_opt::<f32>(node, "epsilon")?
+                    .copied()
+                    .unwrap_or(1e-5);
+                let xs = get(&node.input[0])?;
+                let weight = get(&node.input[1])?;
+                let bias = get(&node.input[2])?;
+                let running_mean = get(&node.input[3])?;
+                let running_var = get(&node.input[4])?;
+                let target_shape: Vec<usize> = xs
+                    .dims()
+                    .iter()
+                    .enumerate()
+                    .map(|(idx, v)| if idx == 1 { *v } else { 1 })
+                    .collect();
+                let target_shape = target_shape.as_slice();
+                let xs = xs
+                    .broadcast_sub(&running_mean.reshape(target_shape)?)?
+                    .broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
+                let weight = weight.reshape(target_shape)?;
+                let bias = bias.reshape(target_shape)?;
+                let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
+                values.insert(node.output[0].clone(), xs);
+            }
+            "Squeeze" => {
+                let xs = get(&node.input[0])?;
+                let mut axes = if node.input.len() <= 1 {
+                    // contract all the dimensions with size 1 except the batch dim.
+                    xs.dims()
+                        .iter()
+                        .enumerate()
+                        .flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
+                        .collect()
+                } else {
+                    get(&node.input[1])?
+                        .to_vec1::<i64>()?
+                        .iter()
+                        .map(|&i| {
+                            if i < 0 {
+                                (xs.rank() as i64 + i) as usize
+                            } else {
+                                i as usize
+                            }
+                        })
+                        .collect::<Vec<_>>()
+                };
+                axes.sort();
+                let mut xs = xs.clone();
+                for &axis in axes.iter().rev() {
+                    xs = xs.squeeze(axis)?
+                }
+                values.insert(node.output[0].clone(), xs);
+            }
+            "Clip" => {
+                let xs = get(&node.input[0])?;
+                let xs = if node.input.len() >= 2 {
+                    let mins = get(&node.input[1])?;
+                    xs.broadcast_maximum(mins)?
+                } else {
+                    xs.clone()
+                };
+                let xs = if node.input.len() >= 3 {
+                    let maxs = get(&node.input[2])?;
+                    xs.broadcast_minimum(maxs)?
+                } else {
+                    xs.clone()
+                };
+                values.insert(node.output[0].clone(), xs);
+            }
             "Conv" => {
                 // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
                 let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
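The `BatchNormalization` arm implements the inference-time formula `y = (x - mean) / sqrt(var + eps) * weight + bias`, reshaping the per-channel statistics to `(1, C, 1, ..)` so they broadcast along dim 1; `Clip` similarly lowers to `broadcast_maximum`/`broadcast_minimum` against its optional min/max inputs. A minimal numeric sketch of the batch-norm broadcast, with made-up stats for a two-channel input:

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::randn(0f32, 1f32, (1, 2, 4, 4), &dev)?; // NCHW, C = 2
    // Per-channel stats and affine params, reshaped to broadcast over dim 1.
    let mean = Tensor::new(&[0.1f32, -0.2], &dev)?.reshape((1, 2, 1, 1))?;
    let var = Tensor::new(&[1.0f32, 4.0], &dev)?.reshape((1, 2, 1, 1))?;
    let weight = Tensor::new(&[1.5f32, 0.5], &dev)?.reshape((1, 2, 1, 1))?;
    let bias = Tensor::new(&[0.0f32, 1.0], &dev)?.reshape((1, 2, 1, 1))?;
    let ys = xs
        .broadcast_sub(&mean)?
        .broadcast_div(&(var + 1e-5f64)?.sqrt()?)?
        .broadcast_mul(&weight)?
        .broadcast_add(&bias)?;
    assert_eq!(ys.dims(), &[1, 2, 4, 4]);
    Ok(())
}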
@@ -344,17 +478,15 @@ pub fn simple_eval(
                 let ws = get(&node.input[1])?;
                 let ys = match ws.rank() {
                     3 => {
-                        let pads = match pads {
-                            None => 0,
-                            Some([p]) => *p as usize,
+                        let (pads, xs) = match pads {
+                            None => (0, xs.clone()),
+                            Some([p]) => (*p as usize, xs.clone()),
                             Some([p1, p2]) => {
                                 if p1 != p2 {
-                                    bail!(
-                                        "left and right pad ({p1} <> {p2}) have to be the same {}",
-                                        node.name
-                                    )
+                                    (0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
+                                } else {
+                                    (*p1 as usize, xs.clone())
                                 }
-                                *p1 as usize
                             }
                             Some(pads) => {
                                 bail!("more pads than expected in conv1d {pads:?} {}", node.name)
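Rather than rejecting asymmetric pads, the conv1d path now zero-pads the length dimension explicitly and runs the convolution with `pads = 0`. A small sketch of that fallback on a made-up NCL tensor:

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::randn(0f32, 1f32, (1, 4, 10), &dev)?; // NCL
    let (p1, p2) = (0usize, 1usize); // asymmetric: 0 on the left, 1 on the right
    let xs = xs.pad_with_zeros(2, p1, p2)?; // dim 2 is the length axis
    assert_eq!(xs.dims(), &[1, 4, 11]);
    Ok(())
}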
@@ -377,14 +509,19 @@ pub fn simple_eval(
                         xs.conv1d(ws, pads, strides, dilations, groups as usize)?
                     }
                     4 => {
-                        let pads = match pads {
-                            None => 0,
-                            Some([p]) => *p as usize,
-                            Some([p1, p2, p3, p4]) => {
+                        let (pads, xs) = match pads {
+                            None => (0, xs.clone()),
+                            Some([p]) => (*p as usize, xs.clone()),
+                            Some(&[p1, p2, p3, p4]) => {
+                                let p1 = p1 as usize;
+                                let p2 = p2 as usize;
+                                let p3 = p3 as usize;
+                                let p4 = p4 as usize;
                                 if p1 != p2 || p1 != p3 || p1 != p4 {
-                                    bail!("pads have to be the same {pads:?} {}", node.name)
+                                    (0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
+                                } else {
+                                    (p1, xs.clone())
                                 }
-                                *p1 as usize
                             }
                             Some(pads) => {
                                 bail!("more pads than expected in conv2d {pads:?} {}", node.name)
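The conv2d case follows the same scheme; ONNX orders `pads` as `[h_begin, w_begin, h_end, w_end]`, which is why height (dim 2) is padded with `(p1, p3)` and width (dim 3) with `(p2, p4)`. A shape check of that mapping on a made-up NCHW tensor:

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::randn(0f32, 1f32, (1, 3, 8, 8), &dev)?; // NCHW
    let (p1, p2, p3, p4) = (0usize, 1usize, 2usize, 3usize);
    let xs = xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?;
    assert_eq!(xs.dims(), &[1, 3, 10, 12]); // 8+0+2 high, 8+1+3 wide
    Ok(())
}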