[ONNX] Support a couple more ops. (#1284)
* Support the shape op in ONNX. * Share the axis normalization bits. * Add some limited support for gather. * Unsqueeze. * Comparison with broadcasting. * Add Not + handle i32.
This commit is contained in:
parent
5a363dbc26
commit
a773a4b22b
|
@ -477,6 +477,12 @@ impl Tensor {
|
|||
broadcast_binary_op!(broadcast_div, div);
|
||||
broadcast_binary_op!(broadcast_maximum, maximum);
|
||||
broadcast_binary_op!(broadcast_minimum, minimum);
|
||||
broadcast_binary_op!(broadcast_eq, eq);
|
||||
broadcast_binary_op!(broadcast_ne, ne);
|
||||
broadcast_binary_op!(broadcast_lt, lt);
|
||||
broadcast_binary_op!(broadcast_le, le);
|
||||
broadcast_binary_op!(broadcast_gt, gt);
|
||||
broadcast_binary_op!(broadcast_ge, ge);
|
||||
|
||||
unary_op!(recip, Recip);
|
||||
unary_op!(neg, Neg);
|
||||
|
@ -2406,6 +2412,23 @@ impl Tensor {
|
|||
) -> Result<Self> {
|
||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Normalize a 'relative' axis value: positive values are kept, negative
|
||||
/// values means counting the dimensions from the back.
|
||||
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||
let rank = self.rank() as i64;
|
||||
if rank <= axis {
|
||||
crate::bail!("axis {axis} is too large, tensor rank {rank}")
|
||||
} else if 0 <= axis {
|
||||
Ok(axis as usize)
|
||||
} else {
|
||||
let naxis = rank + axis;
|
||||
if naxis < 0 {
|
||||
crate::bail!("axis {axis} is too small, tensor rank {rank}")
|
||||
}
|
||||
Ok(naxis as usize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
|
|
@ -69,3 +69,7 @@ required-features = ["pyo3"]
|
|||
[[example]]
|
||||
name = "onnx"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "onnx_basics"
|
||||
required-features = ["onnx"]
|
||||
|
|
|
@ -67,7 +67,7 @@ pub fn main() -> Result<()> {
|
|||
.iter()
|
||||
.map(|dim| match dim.value.as_ref().expect("no dim value") {
|
||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
|
||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
|
||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
|
||||
})
|
||||
.collect::<Result<Vec<usize>>>()?;
|
||||
Tensor::zeros(dims, dt, &Device::Cpu)?
|
|
@ -101,6 +101,18 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
|
|||
fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||
match DataType::try_from(t.data_type) {
|
||||
Ok(DataType::Int32) => {
|
||||
if t.int32_data.is_empty() {
|
||||
let len = t.raw_data.len() / 4;
|
||||
let data: &[i32] =
|
||||
unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
|
||||
let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
|
||||
Tensor::from_vec(data, len, &Device::Cpu)
|
||||
} else {
|
||||
let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
|
||||
Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
|
||||
}
|
||||
}
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => {
|
||||
if dt == DType::F32 && !t.float_data.is_empty() {
|
||||
|
@ -173,18 +185,34 @@ pub fn simple_eval(
|
|||
},
|
||||
type_ => bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
let shape = match &tensor_type.shape {
|
||||
match &tensor_type.shape {
|
||||
None => continue,
|
||||
Some(shape) => shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dim| match dim.value.as_ref().expect("no dim value") {
|
||||
onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
|
||||
onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
|
||||
bail!("DimParam is unsupported for input {}", input.name)
|
||||
Some(shape) => {
|
||||
if shape.dim.len() != tensor.rank() {
|
||||
bail!(
|
||||
"unexpected rank for {}, got {:?}, expected {:?}",
|
||||
input.name,
|
||||
shape.dim,
|
||||
tensor.shape()
|
||||
)
|
||||
}
|
||||
for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
|
||||
match &d.value {
|
||||
Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
|
||||
if *v as usize != dim {
|
||||
bail!(
|
||||
"unexpected dim {idx} for {}, got {:?}, expected {:?}",
|
||||
input.name,
|
||||
shape.dim,
|
||||
tensor.shape()
|
||||
)
|
||||
}
|
||||
}
|
||||
// We do not check equality constraints for the DimParam dimensions for now.
|
||||
Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<usize>>>()?,
|
||||
}
|
||||
}
|
||||
};
|
||||
if dt != tensor.dtype() {
|
||||
bail!(
|
||||
|
@ -193,13 +221,6 @@ pub fn simple_eval(
|
|||
tensor.dtype()
|
||||
)
|
||||
}
|
||||
if shape.as_slice() != tensor.dims() {
|
||||
bail!(
|
||||
"unexpected shape for {}, got {:?}, expected {shape:?}",
|
||||
input.name,
|
||||
tensor.dims()
|
||||
)
|
||||
}
|
||||
}
|
||||
// The nodes are topologically sorted so we can just process them in order.
|
||||
for node in graph.node.iter() {
|
||||
|
@ -236,9 +257,14 @@ pub fn simple_eval(
|
|||
"Equal" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.eq(input1)?;
|
||||
let output = input0.broadcast_eq(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Not" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let xs = xs.eq(&xs.zeros_like()?)?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"MatMul" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
|
@ -430,14 +456,8 @@ pub fn simple_eval(
|
|||
get(&node.input[1])?
|
||||
.to_vec1::<i64>()?
|
||||
.iter()
|
||||
.map(|&i| {
|
||||
if i < 0 {
|
||||
(xs.rank() as i64 + i) as usize
|
||||
} else {
|
||||
i as usize
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.map(|&i| xs.normalize_axis(i))
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
};
|
||||
axes.sort();
|
||||
let mut xs = xs.clone();
|
||||
|
@ -446,6 +466,39 @@ pub fn simple_eval(
|
|||
}
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"ConstantOfShape" => {
|
||||
let dims = get(&node.input[0])?;
|
||||
let shape = dims
|
||||
.to_vec1::<i64>()?
|
||||
.into_iter()
|
||||
.map(|v| v as usize)
|
||||
.collect::<Vec<_>>();
|
||||
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Unsqueeze" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let axes = match get_attr_opt::<[i64]>(node, "axes")? {
|
||||
Some(axis) => axis.to_vec(),
|
||||
None => get(&node.input[1])?.to_vec1::<i64>()?,
|
||||
};
|
||||
let mut axes = axes
|
||||
.iter()
|
||||
.map(|&i| {
|
||||
if i == xs.rank() as i64 {
|
||||
Ok(xs.rank())
|
||||
} else {
|
||||
xs.normalize_axis(i)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
axes.sort();
|
||||
let mut xs = xs.clone();
|
||||
for &axis in axes.iter().rev() {
|
||||
xs = xs.unsqueeze(axis)?
|
||||
}
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Clip" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let xs = if node.input.len() >= 2 {
|
||||
|
@ -462,6 +515,35 @@ pub fn simple_eval(
|
|||
};
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Gather" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let indices = get(&node.input[1])?;
|
||||
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||
let axis = xs.normalize_axis(axis)?;
|
||||
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
|
||||
// differentiable way.
|
||||
let xs = if indices.rank() == 0 {
|
||||
let index = indices.to_vec0::<i64>()? as usize;
|
||||
xs.narrow(axis, index, 1)?.squeeze(axis)?
|
||||
} else {
|
||||
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
|
||||
};
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Shape" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
|
||||
let xs = get(&node.input[0])?;
|
||||
let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
|
||||
let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
|
||||
let start = xs.normalize_axis(start)?;
|
||||
let end = xs.normalize_axis(end)?;
|
||||
let mut dims = vec![];
|
||||
for idx in start..=end {
|
||||
dims.push(xs.dim(idx)? as i64)
|
||||
}
|
||||
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
|
||||
values.insert(node.output[0].clone(), dims);
|
||||
}
|
||||
"Conv" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
|
@ -670,6 +752,7 @@ pub fn simple_eval(
|
|||
let input = get(&node.input[0])?;
|
||||
let dt: i64 = *get_attr(node, "to")?;
|
||||
let dtype = match DataType::try_from(dt as i32) {
|
||||
Ok(DataType::Int32) => DType::I64,
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
|
|
Loading…
Reference in New Issue