[ONNX] Support a couple more ops. (#1284)

* Support the shape op in ONNX.

* Share the axis normalization bits.

* Add some limited support for gather.

* Unsqueeze.

* Comparison with broadcasting.

* Add Not + handle i32.
This commit is contained in:
Laurent Mazare 2023-11-06 22:44:58 +01:00 committed by GitHub
parent 5a363dbc26
commit a773a4b22b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 137 additions and 27 deletions

View File

@ -477,6 +477,12 @@ impl Tensor {
broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum);
broadcast_binary_op!(broadcast_eq, eq);
broadcast_binary_op!(broadcast_ne, ne);
broadcast_binary_op!(broadcast_lt, lt);
broadcast_binary_op!(broadcast_le, le);
broadcast_binary_op!(broadcast_gt, gt);
broadcast_binary_op!(broadcast_ge, ge);
unary_op!(recip, Recip);
unary_op!(neg, Neg);
@ -2406,6 +2412,23 @@ impl Tensor {
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
crate::bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
crate::bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
}
}
macro_rules! bin_trait {

View File

@ -69,3 +69,7 @@ required-features = ["pyo3"]
[[example]]
name = "onnx"
required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

View File

@ -67,7 +67,7 @@ pub fn main() -> Result<()> {
.iter()
.map(|dim| match dim.value.as_ref().expect("no dim value") {
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
})
.collect::<Result<Vec<usize>>>()?;
Tensor::zeros(dims, dt, &Device::Cpu)?

View File

@ -101,6 +101,18 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
Ok(DataType::Int32) => {
if t.int32_data.is_empty() {
let len = t.raw_data.len() / 4;
let data: &[i32] =
unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
Tensor::from_vec(data, len, &Device::Cpu)
} else {
let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
}
}
Ok(dt) => match dtype(dt) {
Some(dt) => {
if dt == DType::F32 && !t.float_data.is_empty() {
@ -173,18 +185,34 @@ pub fn simple_eval(
},
type_ => bail!("unsupported input type {type_:?}"),
};
let shape = match &tensor_type.shape {
match &tensor_type.shape {
None => continue,
Some(shape) => shape
.dim
.iter()
.map(|dim| match dim.value.as_ref().expect("no dim value") {
onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
bail!("DimParam is unsupported for input {}", input.name)
Some(shape) => {
if shape.dim.len() != tensor.rank() {
bail!(
"unexpected rank for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
match &d.value {
Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
if *v as usize != dim {
bail!(
"unexpected dim {idx} for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
}
// We do not check equality constraints for the DimParam dimensions for now.
Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
}
})
.collect::<Result<Vec<usize>>>()?,
}
}
};
if dt != tensor.dtype() {
bail!(
@ -193,13 +221,6 @@ pub fn simple_eval(
tensor.dtype()
)
}
if shape.as_slice() != tensor.dims() {
bail!(
"unexpected shape for {}, got {:?}, expected {shape:?}",
input.name,
tensor.dims()
)
}
}
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
@ -236,9 +257,14 @@ pub fn simple_eval(
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.eq(input1)?;
let output = input0.broadcast_eq(input1)?;
values.insert(node.output[0].clone(), output);
}
"Not" => {
let xs = get(&node.input[0])?;
let xs = xs.eq(&xs.zeros_like()?)?;
values.insert(node.output[0].clone(), xs);
}
"MatMul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
@ -430,14 +456,8 @@ pub fn simple_eval(
get(&node.input[1])?
.to_vec1::<i64>()?
.iter()
.map(|&i| {
if i < 0 {
(xs.rank() as i64 + i) as usize
} else {
i as usize
}
})
.collect::<Vec<_>>()
.map(|&i| xs.normalize_axis(i))
.collect::<Result<Vec<_>>>()?
};
axes.sort();
let mut xs = xs.clone();
@ -446,6 +466,39 @@ pub fn simple_eval(
}
values.insert(node.output[0].clone(), xs);
}
"ConstantOfShape" => {
let dims = get(&node.input[0])?;
let shape = dims
.to_vec1::<i64>()?
.into_iter()
.map(|v| v as usize)
.collect::<Vec<_>>();
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
values.insert(node.output[0].clone(), xs);
}
"Unsqueeze" => {
let xs = get(&node.input[0])?;
let axes = match get_attr_opt::<[i64]>(node, "axes")? {
Some(axis) => axis.to_vec(),
None => get(&node.input[1])?.to_vec1::<i64>()?,
};
let mut axes = axes
.iter()
.map(|&i| {
if i == xs.rank() as i64 {
Ok(xs.rank())
} else {
xs.normalize_axis(i)
}
})
.collect::<Result<Vec<_>>>()?;
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.unsqueeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if node.input.len() >= 2 {
@ -462,6 +515,35 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), xs);
}
"Gather" => {
let xs = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
// differentiable way.
let xs = if indices.rank() == 0 {
let index = indices.to_vec0::<i64>()? as usize;
xs.narrow(axis, index, 1)?.squeeze(axis)?
} else {
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
};
values.insert(node.output[0].clone(), xs);
}
"Shape" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
let xs = get(&node.input[0])?;
let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
let start = xs.normalize_axis(start)?;
let end = xs.normalize_axis(end)?;
let mut dims = vec![];
for idx in start..=end {
dims.push(xs.dim(idx)? as i64)
}
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
values.insert(node.output[0].clone(), dims);
}
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
@ -670,6 +752,7 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
let dt: i64 = *get_attr(node, "to")?;
let dtype = match DataType::try_from(dt as i32) {
Ok(DataType::Int32) => DType::I64,
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {