diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 73376fbe..bc04eb00 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1,4 +1,5 @@
 use crate::onnx;
+use crate::onnx::attribute_proto::AttributeType;
 use crate::onnx::tensor_proto::DataType;
 use candle::{bail, DType, Device, Result, Tensor};
 use std::collections::HashMap;
@@ -17,6 +18,96 @@ pub fn dtype(dt: DataType) -> Option<DType> {
     }
 }
 
+trait Attr {
+    const TYPE: AttributeType;
+    fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
+}
+
+impl Attr for i64 {
+    const TYPE: AttributeType = AttributeType::Int;
+    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
+        Ok(&attr.i)
+    }
+}
+
+impl Attr for [i64] {
+    const TYPE: AttributeType = AttributeType::Ints;
+    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
+        Ok(attr.ints.as_slice())
+    }
+}
+
+impl Attr for str {
+    const TYPE: AttributeType = AttributeType::String;
+    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
+        std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
+    }
+}
+
+fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
+    match node.attribute.iter().find(|attr| attr.name == name) {
+        None => {
+            bail!(
+                "cannot find the '{name}' attribute in '{}' for {}",
+                node.op_type,
+                node.name
+            )
+        }
+        Some(dt) => Ok(dt),
+    }
+}
+
+fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
+    let attr = get_attr_(node, name)?;
+    if attr.r#type() != T::TYPE {
+        bail!(
+            "unsupported type {:?} for '{name}' attribute in '{}' for {}",
+            attr.r#type,
+            node.op_type,
+            node.name
+        )
+    }
+    T::get(attr)
+}
+
+fn get_attr_opt<'a, T: Attr + ?Sized>(
+    node: &'a onnx::NodeProto,
+    name: &str,
+) -> Result<Option<&'a T>> {
+    match node.attribute.iter().find(|attr| attr.name == name) {
+        None => Ok(None),
+        Some(attr) => {
+            if attr.r#type() != T::TYPE {
+                bail!(
+                    "unsupported type {:?} for '{name}' attribute in '{}' for {}",
+                    attr.r#type,
+                    node.op_type,
+                    node.name
+                )
+            }
+            let val = T::get(attr)?;
+            Ok(Some(val))
+        }
+    }
+}
+
+fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
+    let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
+    match DataType::try_from(t.data_type) {
+        Ok(dt) => match dtype(dt) {
+            Some(dt) => {
+                Tensor::from_raw_buffer(t.raw_data.as_slice(), dt, dims.as_slice(), &Device::Cpu)
+            }
+            None => {
+                bail!("unsupported 'value' data-type {dt:?} for {name}")
+            }
+        },
+        Err(_) => {
+            bail!("unsupported 'value' data-type {} for {name}", t.data_type)
+        }
+    }
+}
+
 // This function provides a direct evaluation of the proto.
 // Longer-term, we should first convert the proto to an intermediate representation of the compute
 // graph so as to make multiple evaluations more efficient.
@@ -26,59 +117,22 @@ pub fn simple_eval(
     model: &onnx::ModelProto,
     inputs: HashMap<String, Value>,
 ) -> Result<HashMap<String, Value>> {
-    use crate::onnx::attribute_proto::AttributeType;
     let graph = match &model.graph {
         None => bail!("no graph defined in proto"),
         Some(graph) => graph,
     };
     // TODO: validate the inputs.
     let mut values = inputs;
+    for t in graph.initializer.iter() {
+        let tensor = get_tensor(t, t.name.as_str())?;
+        values.insert(t.name.to_string(), tensor);
+    }
     // The nodes are topologically sorted so we can just process them in order.
     for node in graph.node.iter() {
         let get = |input_name: &str| match values.get(input_name) {
             Some(value) => Ok(value),
             None => bail!("cannot find {input_name} for op {}", node.name),
         };
-        let get_attr_i = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
-            None => {
-                bail!(
-                    "cannot find the '{name}' attribute in '{}' for {}",
-                    node.op_type,
-                    node.name
-                )
-            }
-            Some(dt) => {
-                match dt.r#type() {
-                    AttributeType::Int => (),
-                    rtype => bail!(
-                        "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
-                        node.op_type,
-                        node.name
-                    ),
-                }
-                Ok(dt.i)
-            }
-        };
-        let get_attr_is = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
-            None => {
-                bail!(
-                    "cannot find the '{name}' attribute in '{}' for {}",
-                    node.op_type,
-                    node.name
-                )
-            }
-            Some(dt) => {
-                match dt.r#type() {
-                    AttributeType::Ints => (),
-                    rtype => bail!(
-                        "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
-                        node.op_type,
-                        node.name
-                    ),
-                }
-                Ok(dt.ints.as_slice())
-            }
-        };
         // TODO: Validate node.input for each operator.
         match node.op_type.as_str() {
             "Add" => {
@@ -136,9 +190,9 @@ pub fn simple_eval(
             }
             "LogSoftmax" => {
                 let input = get(&node.input[0])?;
-                let output = match get_attr_i("axis") {
-                    Err(_) => candle_nn::ops::softmax_last_dim(input)?,
-                    Ok(axis) => {
+                let output = match get_attr_opt::<i64>(node, "axis")? {
+                    None => candle_nn::ops::softmax_last_dim(input)?,
+                    Some(&axis) => {
                         let num_axis = input.rank() as i64;
                         let axis = if axis >= 0 {
                             axis as usize
@@ -154,9 +208,9 @@ pub fn simple_eval(
             }
             "Softmax" => {
                 let input = get(&node.input[0])?;
-                let output = match get_attr_i("axis") {
-                    Err(_) => candle_nn::ops::softmax_last_dim(input)?,
-                    Ok(axis) => {
+                let output = match get_attr_opt::<i64>(node, "axis")? {
+                    None => candle_nn::ops::softmax_last_dim(input)?,
+                    Some(&axis) => {
                         let num_axis = input.rank() as i64;
                         let axis = if axis >= 0 {
                             axis as usize
@@ -172,15 +226,126 @@ pub fn simple_eval(
             }
             "Transpose" => {
                 let input = get(&node.input[0])?;
-                let output = match get_attr_is("perm") {
-                    Err(_) => input.t()?,
-                    Ok(perm) => {
+                let output = match get_attr_opt::<[i64]>(node, "perm")? {
+                    None => input.t()?,
+                    Some(perm) => {
                         let perm = perm.iter().map(|&v| v as usize).collect::<Vec<usize>>();
                         input.permute(perm)?
                     }
                 };
                 values.insert(node.output[0].clone(), output);
             }
+            "Conv" => {
+                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
+                let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
+                let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
+                let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
+                let pads = get_attr_opt::<[i64]>(node, "pads")?;
+                let strides = get_attr_opt::<[i64]>(node, "strides")?;
+                let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
+                match auto_pad {
+                    None | Some("NOTSET") => (),
+                    Some(s) => bail!("unsupported auto_pad {s}"),
+                };
+                let xs = get(&node.input[0])?;
+                let ws = get(&node.input[1])?;
+                let ys = match ws.rank() {
+                    3 => {
+                        // Rank-3 weights -> conv1d; only symmetric padding is supported.
+                        let pads = match pads {
+                            None => 0,
+                            Some([p]) => *p as usize,
+                            Some([p1, p2]) => {
+                                if p1 != p2 {
+                                    bail!(
+                                        "left and right pad ({p1} <> {p2}) have to be the same {}",
+                                        node.name
+                                    )
+                                }
+                                *p1 as usize
+                            }
+                            Some(pads) => {
+                                bail!("more pads than expected in conv1d {pads:?} {}", node.name)
+                            }
+                        };
+                        let strides = match strides {
+                            None => 1,
+                            Some([p]) => *p as usize,
+                            Some(s) => {
+                                bail!("more strides than expected in conv1d {s:?} {}", node.name)
+                            }
+                        };
+                        let dilations = match dilations {
+                            None => 1,
+                            Some([p]) => *p as usize,
+                            Some(s) => {
+                                bail!("more dilations than expected in conv1d {s:?} {}", node.name)
+                            }
+                        };
+                        xs.conv1d(ws, pads, strides, dilations, groups as usize)?
+                    }
+                    4 => {
+                        // Rank-4 weights -> conv2d; per-side pads must all be equal.
+                        let pads = match pads {
+                            None => 0,
+                            Some([p]) => *p as usize,
+                            Some([p1, p2, p3, p4]) => {
+                                if p1 != p2 || p1 != p3 || p1 != p4 {
+                                    bail!("pads have to be the same {pads:?} {}", node.name)
+                                }
+                                *p1 as usize
+                            }
+                            Some(pads) => {
+                                bail!("more pads than expected in conv2d {pads:?} {}", node.name)
+                            }
+                        };
+                        let strides = match strides {
+                            None => 1,
+                            Some([p]) => *p as usize,
+                            Some([p1, p2]) => {
+                                if p1 != p2 {
+                                    bail!(
+                                        "strides have to be the same on both axes ({p1} <> {p2}) {}",
+                                        node.name
+                                    )
+                                }
+                                *p1 as usize
+                            }
+                            Some(s) => {
+                                bail!("more strides than expected in conv2d {s:?} {}", node.name)
+                            }
+                        };
+                        let dilations = match dilations {
+                            None => 1,
+                            Some([p]) => *p as usize,
+                            Some([p1, p2]) => {
+                                if p1 != p2 {
+                                    bail!(
+                                        "dilations have to be the same on both axes ({p1} <> {p2}) {}",
+                                        node.name
+                                    )
+                                }
+                                *p1 as usize
+                            }
+                            Some(s) => {
+                                bail!("more dilations than expected in conv2d {s:?} {}", node.name)
+                            }
+                        };
+                        xs.conv2d(ws, pads, strides, dilations, groups as usize)?
+                    }
+                    rank => bail!(
+                        "unsupported rank for weight matrix {rank} in conv {}",
+                        node.name
+                    ),
+                };
+                let ys = if node.input.len() > 2 {
+                    // Optional bias input, broadcast over the channel dimension.
+                    let bs = get(&node.input[2])?;
+                    let mut bs_shape = vec![1; ys.rank()];
+                    bs_shape[1] = bs.elem_count();
+                    ys.broadcast_add(&bs.reshape(bs_shape)?)?
+                } else {
+                    ys
+                };
+                values.insert(node.output[0].clone(), ys);
+            }
             "Concat" => {
                 // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
                 let inputs = node
                     .input
                     .iter()
                     .map(|n| Ok(get(n.as_str())?.clone()))
                     .collect::<Result<Vec<Value>>>()?;
-                let axis = get_attr_i("axis")?;
+                let axis: i64 = *get_attr(node, "axis")?;
                 let num_axis = if inputs.is_empty() {
                     bail!("empty concat")
                 } else {
@@ -264,27 +429,7 @@ pub fn simple_eval(
                 let output = match value.r#type() {
                     AttributeType::Tensor => {
                         let t = value.t.as_ref().unwrap();
-                        let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
-                        match DataType::try_from(t.data_type) {
-                            Ok(dt) => match dtype(dt) {
-                                Some(dt) => Tensor::from_raw_buffer(
-                                    t.raw_data.as_slice(),
-                                    dt,
-                                    dims.as_slice(),
-                                    &Device::Cpu,
-                                )?,
-                                None => {
-                                    bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
-                                }
-                            },
-                            Err(_) => {
-                                bail!(
-                                    "unsupported 'value' data-type {} for {}",
-                                    t.data_type,
-                                    node.name
-                                )
-                            }
-                        }
+                        get_tensor(t, &node.name)?
                     }
                     rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
                 };
@@ -293,7 +438,7 @@ pub fn simple_eval(
             // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
             "Cast" => {
                 let input = get(&node.input[0])?;
-                let dt = get_attr_i("to")?;
+                let dt: i64 = *get_attr(node, "to")?;
                 let dtype = match DataType::try_from(dt as i32) {
                     Ok(dt) => match dtype(dt) {
                         Some(dt) => dt,
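
The core refactor in this patch is the `Attr` trait: each Rust type that can back an ONNX attribute declares the `AttributeType` tag it expects plus an accessor, so the per-op closures (`get_attr_i`, `get_attr_is`) collapse into the generic `get_attr`/`get_attr_opt` and new attribute kinds (like `str` for `auto_pad`) only need a new impl. Below is a minimal, self-contained sketch of that pattern. `AttrType` and `FakeAttr` are hypothetical stand-ins for the prost-generated `AttributeType` and `onnx::AttributeProto` (not the candle-onnx API), and a plain `Result<_, String>` replaces `candle::Result`.

#[derive(Debug, PartialEq, Clone, Copy)]
enum AttrType {
    Int,
    Ints,
}

// Simplified stand-in for the prost-generated `onnx::AttributeProto`.
struct FakeAttr {
    name: String,
    r#type: AttrType,
    i: i64,
    ints: Vec<i64>,
}

trait Attr {
    const TYPE: AttrType;
    fn get(attr: &FakeAttr) -> &Self;
}

impl Attr for i64 {
    const TYPE: AttrType = AttrType::Int;
    fn get(attr: &FakeAttr) -> &Self {
        &attr.i
    }
}

impl Attr for [i64] {
    const TYPE: AttrType = AttrType::Ints;
    fn get(attr: &FakeAttr) -> &Self {
        attr.ints.as_slice()
    }
}

// Mirrors `get_attr_opt`: a missing attribute is Ok(None) so callers can
// fall back to a default, while a type mismatch is a hard error.
fn get_attr_opt<'a, T: Attr + ?Sized>(
    attrs: &'a [FakeAttr],
    name: &str,
) -> Result<Option<&'a T>, String> {
    match attrs.iter().find(|a| a.name == name) {
        None => Ok(None),
        Some(a) if a.r#type != T::TYPE => {
            Err(format!("unsupported type {:?} for '{name}'", a.r#type))
        }
        Some(a) => Ok(Some(T::get(a))),
    }
}

fn main() {
    let attrs = vec![FakeAttr {
        name: "axis".to_string(),
        r#type: AttrType::Int,
        i: -1,
        ints: vec![],
    }];
    // Present with the right type: Ok(Some(&-1)).
    assert_eq!(get_attr_opt::<i64>(&attrs, "axis").unwrap(), Some(&-1));
    // Absent: Ok(None), which the Softmax/Transpose arms map to a default.
    assert!(get_attr_opt::<[i64]>(&attrs, "perm").unwrap().is_none());
    // Present with the wrong type: an error rather than a silent default.
    assert!(get_attr_opt::<[i64]>(&attrs, "axis").is_err());
}

The `?Sized` bound is what allows impls on the unsized `[i64]` and `str`, so lookups such as `get_attr_opt::<[i64]>(node, "pads")?` in the Conv arm can hand back a borrowed slice straight out of the proto without copying.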