diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 2a5d3635..133b2782 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -477,6 +477,12 @@ impl Tensor { broadcast_binary_op!(broadcast_div, div); broadcast_binary_op!(broadcast_maximum, maximum); broadcast_binary_op!(broadcast_minimum, minimum); + broadcast_binary_op!(broadcast_eq, eq); + broadcast_binary_op!(broadcast_ne, ne); + broadcast_binary_op!(broadcast_lt, lt); + broadcast_binary_op!(broadcast_le, le); + broadcast_binary_op!(broadcast_gt, gt); + broadcast_binary_op!(broadcast_ge, ge); unary_op!(recip, Recip); unary_op!(neg, Neg); @@ -2406,6 +2412,23 @@ impl Tensor { ) -> Result<Self> { self.apply_op3_arc(t2, t3, Arc::new(Box::new(c))) } + + /// Normalize a 'relative' axis value: positive values are kept, negative + /// values means counting the dimensions from the back. + pub fn normalize_axis(&self, axis: i64) -> Result<usize> { let rank = self.rank() as i64; if rank <= axis { crate::bail!("axis {axis} is too large, tensor rank {rank}") } else if 0 <= axis { Ok(axis as usize) } else { let naxis = rank + axis; if naxis < 0 { crate::bail!("axis {axis} is too small, tensor rank {rank}") } Ok(naxis as usize) } + } } macro_rules! 
bin_trait { diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 2bda25d9..9c839b95 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -69,3 +69,7 @@ required-features = ["pyo3"] [[example]] name = "onnx" required-features = ["onnx"] + +[[example]] +name = "onnx_basics" +required-features = ["onnx"] diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-examples/examples/onnx_basics.rs similarity index 96% rename from candle-onnx/examples/onnx_basics.rs rename to candle-examples/examples/onnx_basics.rs index 43940596..0a173717 100644 --- a/candle-onnx/examples/onnx_basics.rs +++ b/candle-examples/examples/onnx_basics.rs @@ -67,7 +67,7 @@ pub fn main() -> Result<()> { .iter() .map(|dim| match dim.value.as_ref().expect("no dim value") { candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize), - candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name), + candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42), }) .collect::<Result<Vec<usize>>>()?; Tensor::zeros(dims, dt, &Device::Cpu)? 
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 54fae6c1..51e2aa0c 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -101,6 +101,18 @@ fn get_attr_opt<'a, T: Attr + ?Sized>( fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> { let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect(); match DataType::try_from(t.data_type) { + Ok(DataType::Int32) => { + if t.int32_data.is_empty() { + let len = t.raw_data.len() / 4; + let data: &[i32] = + unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) }; + let data = data.iter().map(|v| *v as i64).collect::<Vec<i64>>(); + Tensor::from_vec(data, len, &Device::Cpu) + } else { + let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<i64>>(); + Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu) + } + } Ok(dt) => match dtype(dt) { Some(dt) => { if dt == DType::F32 && !t.float_data.is_empty() { @@ -173,18 +185,34 @@ pub fn simple_eval( }, type_ => bail!("unsupported input type {type_:?}"), }; - let shape = match &tensor_type.shape { + match &tensor_type.shape { None => continue, - Some(shape) => shape - .dim - .iter() - .map(|dim| match dim.value.as_ref().expect("no dim value") { - onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize), - onnx::tensor_shape_proto::dimension::Value::DimParam(_) => { - bail!("DimParam is unsupported for input {}", input.name) + Some(shape) => { + if shape.dim.len() != tensor.rank() { + bail!( + "unexpected rank for {}, got {:?}, expected {:?}", + input.name, + shape.dim, + tensor.shape() + ) + } + for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() { + match &d.value { + Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => { + if *v as usize != dim { + bail!( + "unexpected dim {idx} for {}, got {:?}, expected {:?}", + input.name, + shape.dim, + tensor.shape() + ) + } + } + // We do not check equality constraints for the DimParam dimensions for now. 
+ Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (), } - }) - .collect::<Result<Vec<usize>>>()?, + } + } }; if dt != tensor.dtype() { bail!( @@ -193,13 +221,6 @@ pub fn simple_eval( tensor.dtype() ) } - if shape.as_slice() != tensor.dims() { - bail!( - "unexpected shape for {}, got {:?}, expected {shape:?}", - input.name, - tensor.dims() - ) - } } // The nodes are topologically sorted so we can just process them in order. for node in graph.node.iter() { @@ -236,9 +257,14 @@ "Equal" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; - let output = input0.eq(input1)?; + let output = input0.broadcast_eq(input1)?; values.insert(node.output[0].clone(), output); } + "Not" => { + let xs = get(&node.input[0])?; + let xs = xs.eq(&xs.zeros_like()?)?; + values.insert(node.output[0].clone(), xs); + } "MatMul" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; @@ -430,14 +456,8 @@ get(&node.input[1])? .to_vec1::<i64>()? .iter() - .map(|&i| { - if i < 0 { - (xs.rank() as i64 + i) as usize - } else { - i as usize - } - }) - .collect::<Vec<_>>() + .map(|&i| xs.normalize_axis(i)) + .collect::<Result<Vec<_>>>()? }; axes.sort(); let mut xs = xs.clone(); @@ -446,6 +466,39 @@ } values.insert(node.output[0].clone(), xs); } + "ConstantOfShape" => { + let dims = get(&node.input[0])?; + let shape = dims + .to_vec1::<i64>()? + .into_iter() + .map(|v| v as usize) + .collect::<Vec<usize>>(); + let xs = Tensor::zeros(shape, DType::F32, dims.device())?; + values.insert(node.output[0].clone(), xs); + } + "Unsqueeze" => { + let xs = get(&node.input[0])?; + let axes = match get_attr_opt::<[i64]>(node, "axes")? 
{ + Some(axis) => axis.to_vec(), + None => get(&node.input[1])?.to_vec1::<i64>()?, + }; + let mut axes = axes + .iter() + .map(|&i| { + if i == xs.rank() as i64 { + Ok(xs.rank()) + } else { + xs.normalize_axis(i) + } + }) + .collect::<Result<Vec<_>>>()?; + axes.sort(); + let mut xs = xs.clone(); + for &axis in axes.iter().rev() { + xs = xs.unsqueeze(axis)? + } + values.insert(node.output[0].clone(), xs); + } "Clip" => { let xs = get(&node.input[0])?; let xs = if node.input.len() >= 2 { @@ -462,6 +515,35 @@ }; values.insert(node.output[0].clone(), xs); } + "Gather" => { + let xs = get(&node.input[0])?; + let indices = get(&node.input[1])?; + let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0); + let axis = xs.normalize_axis(axis)?; + // TODO: Provide an op to handle the ONNX generalized gather op ideally in a + // differentiable way. + let xs = if indices.rank() == 0 { + let index = indices.to_vec0::<i64>()? as usize; + xs.narrow(axis, index, 1)?.squeeze(axis)? + } else { + todo!("implement gather for {xs:?} {indices:?} axis {axis}") + }; + values.insert(node.output[0].clone(), xs); + } + "Shape" => { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape + let xs = get(&node.input[0])?; + let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0); + let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1); + let start = xs.normalize_axis(start)?; + let end = xs.normalize_axis(end)?; + let mut dims = vec![]; + for idx in start..=end { + dims.push(xs.dim(idx)? 
as i64) + } + let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?; + values.insert(node.output[0].clone(), dims); + } "Conv" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv let dilations = get_attr_opt::<[i64]>(node, "dilations")?; @@ -670,6 +752,7 @@ pub fn simple_eval( let input = get(&node.input[0])?; let dt: i64 = *get_attr(node, "to")?; let dtype = match DataType::try_from(dt as i32) { + Ok(DataType::Int32) => DType::I64, Ok(dt) => match dtype(dt) { Some(dt) => dt, None => {