[ONNX] Support a couple more ops. (#1284)
* Support the shape op in ONNX. * Share the axis normalization bits. * Add some limited support for gather. * Unsqueeze. * Comparison with broadcasting. * Add Not + handle i32.
This commit is contained in:
parent
5a363dbc26
commit
a773a4b22b
|
@ -477,6 +477,12 @@ impl Tensor {
|
||||||
broadcast_binary_op!(broadcast_div, div);
|
broadcast_binary_op!(broadcast_div, div);
|
||||||
broadcast_binary_op!(broadcast_maximum, maximum);
|
broadcast_binary_op!(broadcast_maximum, maximum);
|
||||||
broadcast_binary_op!(broadcast_minimum, minimum);
|
broadcast_binary_op!(broadcast_minimum, minimum);
|
||||||
|
broadcast_binary_op!(broadcast_eq, eq);
|
||||||
|
broadcast_binary_op!(broadcast_ne, ne);
|
||||||
|
broadcast_binary_op!(broadcast_lt, lt);
|
||||||
|
broadcast_binary_op!(broadcast_le, le);
|
||||||
|
broadcast_binary_op!(broadcast_gt, gt);
|
||||||
|
broadcast_binary_op!(broadcast_ge, ge);
|
||||||
|
|
||||||
unary_op!(recip, Recip);
|
unary_op!(recip, Recip);
|
||||||
unary_op!(neg, Neg);
|
unary_op!(neg, Neg);
|
||||||
|
@ -2406,6 +2412,23 @@ impl Tensor {
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Normalize a 'relative' axis value: positive values are kept, negative
|
||||||
|
/// values means counting the dimensions from the back.
|
||||||
|
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||||
|
let rank = self.rank() as i64;
|
||||||
|
if rank <= axis {
|
||||||
|
crate::bail!("axis {axis} is too large, tensor rank {rank}")
|
||||||
|
} else if 0 <= axis {
|
||||||
|
Ok(axis as usize)
|
||||||
|
} else {
|
||||||
|
let naxis = rank + axis;
|
||||||
|
if naxis < 0 {
|
||||||
|
crate::bail!("axis {axis} is too small, tensor rank {rank}")
|
||||||
|
}
|
||||||
|
Ok(naxis as usize)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
|
|
@ -69,3 +69,7 @@ required-features = ["pyo3"]
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "onnx"
|
name = "onnx"
|
||||||
required-features = ["onnx"]
|
required-features = ["onnx"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "onnx_basics"
|
||||||
|
required-features = ["onnx"]
|
||||||
|
|
|
@ -67,7 +67,7 @@ pub fn main() -> Result<()> {
|
||||||
.iter()
|
.iter()
|
||||||
.map(|dim| match dim.value.as_ref().expect("no dim value") {
|
.map(|dim| match dim.value.as_ref().expect("no dim value") {
|
||||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
|
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
|
||||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
|
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<usize>>>()?;
|
.collect::<Result<Vec<usize>>>()?;
|
||||||
Tensor::zeros(dims, dt, &Device::Cpu)?
|
Tensor::zeros(dims, dt, &Device::Cpu)?
|
|
@ -101,6 +101,18 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
|
||||||
fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||||
match DataType::try_from(t.data_type) {
|
match DataType::try_from(t.data_type) {
|
||||||
|
Ok(DataType::Int32) => {
|
||||||
|
if t.int32_data.is_empty() {
|
||||||
|
let len = t.raw_data.len() / 4;
|
||||||
|
let data: &[i32] =
|
||||||
|
unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
|
||||||
|
let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
|
||||||
|
Tensor::from_vec(data, len, &Device::Cpu)
|
||||||
|
} else {
|
||||||
|
let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
|
||||||
|
Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(dt) => match dtype(dt) {
|
Ok(dt) => match dtype(dt) {
|
||||||
Some(dt) => {
|
Some(dt) => {
|
||||||
if dt == DType::F32 && !t.float_data.is_empty() {
|
if dt == DType::F32 && !t.float_data.is_empty() {
|
||||||
|
@ -173,18 +185,34 @@ pub fn simple_eval(
|
||||||
},
|
},
|
||||||
type_ => bail!("unsupported input type {type_:?}"),
|
type_ => bail!("unsupported input type {type_:?}"),
|
||||||
};
|
};
|
||||||
let shape = match &tensor_type.shape {
|
match &tensor_type.shape {
|
||||||
None => continue,
|
None => continue,
|
||||||
Some(shape) => shape
|
Some(shape) => {
|
||||||
.dim
|
if shape.dim.len() != tensor.rank() {
|
||||||
.iter()
|
bail!(
|
||||||
.map(|dim| match dim.value.as_ref().expect("no dim value") {
|
"unexpected rank for {}, got {:?}, expected {:?}",
|
||||||
onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
|
input.name,
|
||||||
onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
|
shape.dim,
|
||||||
bail!("DimParam is unsupported for input {}", input.name)
|
tensor.shape()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
|
||||||
|
match &d.value {
|
||||||
|
Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
|
||||||
|
if *v as usize != dim {
|
||||||
|
bail!(
|
||||||
|
"unexpected dim {idx} for {}, got {:?}, expected {:?}",
|
||||||
|
input.name,
|
||||||
|
shape.dim,
|
||||||
|
tensor.shape()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// We do not check equality constraints for the DimParam dimensions for now.
|
||||||
|
Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
.collect::<Result<Vec<usize>>>()?,
|
}
|
||||||
};
|
};
|
||||||
if dt != tensor.dtype() {
|
if dt != tensor.dtype() {
|
||||||
bail!(
|
bail!(
|
||||||
|
@ -193,13 +221,6 @@ pub fn simple_eval(
|
||||||
tensor.dtype()
|
tensor.dtype()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if shape.as_slice() != tensor.dims() {
|
|
||||||
bail!(
|
|
||||||
"unexpected shape for {}, got {:?}, expected {shape:?}",
|
|
||||||
input.name,
|
|
||||||
tensor.dims()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// The nodes are topologically sorted so we can just process them in order.
|
// The nodes are topologically sorted so we can just process them in order.
|
||||||
for node in graph.node.iter() {
|
for node in graph.node.iter() {
|
||||||
|
@ -236,9 +257,14 @@ pub fn simple_eval(
|
||||||
"Equal" => {
|
"Equal" => {
|
||||||
let input0 = get(&node.input[0])?;
|
let input0 = get(&node.input[0])?;
|
||||||
let input1 = get(&node.input[1])?;
|
let input1 = get(&node.input[1])?;
|
||||||
let output = input0.eq(input1)?;
|
let output = input0.broadcast_eq(input1)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"Not" => {
|
||||||
|
let xs = get(&node.input[0])?;
|
||||||
|
let xs = xs.eq(&xs.zeros_like()?)?;
|
||||||
|
values.insert(node.output[0].clone(), xs);
|
||||||
|
}
|
||||||
"MatMul" => {
|
"MatMul" => {
|
||||||
let input0 = get(&node.input[0])?;
|
let input0 = get(&node.input[0])?;
|
||||||
let input1 = get(&node.input[1])?;
|
let input1 = get(&node.input[1])?;
|
||||||
|
@ -430,14 +456,8 @@ pub fn simple_eval(
|
||||||
get(&node.input[1])?
|
get(&node.input[1])?
|
||||||
.to_vec1::<i64>()?
|
.to_vec1::<i64>()?
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&i| {
|
.map(|&i| xs.normalize_axis(i))
|
||||||
if i < 0 {
|
.collect::<Result<Vec<_>>>()?
|
||||||
(xs.rank() as i64 + i) as usize
|
|
||||||
} else {
|
|
||||||
i as usize
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
};
|
};
|
||||||
axes.sort();
|
axes.sort();
|
||||||
let mut xs = xs.clone();
|
let mut xs = xs.clone();
|
||||||
|
@ -446,6 +466,39 @@ pub fn simple_eval(
|
||||||
}
|
}
|
||||||
values.insert(node.output[0].clone(), xs);
|
values.insert(node.output[0].clone(), xs);
|
||||||
}
|
}
|
||||||
|
"ConstantOfShape" => {
|
||||||
|
let dims = get(&node.input[0])?;
|
||||||
|
let shape = dims
|
||||||
|
.to_vec1::<i64>()?
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| v as usize)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
|
||||||
|
values.insert(node.output[0].clone(), xs);
|
||||||
|
}
|
||||||
|
"Unsqueeze" => {
|
||||||
|
let xs = get(&node.input[0])?;
|
||||||
|
let axes = match get_attr_opt::<[i64]>(node, "axes")? {
|
||||||
|
Some(axis) => axis.to_vec(),
|
||||||
|
None => get(&node.input[1])?.to_vec1::<i64>()?,
|
||||||
|
};
|
||||||
|
let mut axes = axes
|
||||||
|
.iter()
|
||||||
|
.map(|&i| {
|
||||||
|
if i == xs.rank() as i64 {
|
||||||
|
Ok(xs.rank())
|
||||||
|
} else {
|
||||||
|
xs.normalize_axis(i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
axes.sort();
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for &axis in axes.iter().rev() {
|
||||||
|
xs = xs.unsqueeze(axis)?
|
||||||
|
}
|
||||||
|
values.insert(node.output[0].clone(), xs);
|
||||||
|
}
|
||||||
"Clip" => {
|
"Clip" => {
|
||||||
let xs = get(&node.input[0])?;
|
let xs = get(&node.input[0])?;
|
||||||
let xs = if node.input.len() >= 2 {
|
let xs = if node.input.len() >= 2 {
|
||||||
|
@ -462,6 +515,35 @@ pub fn simple_eval(
|
||||||
};
|
};
|
||||||
values.insert(node.output[0].clone(), xs);
|
values.insert(node.output[0].clone(), xs);
|
||||||
}
|
}
|
||||||
|
"Gather" => {
|
||||||
|
let xs = get(&node.input[0])?;
|
||||||
|
let indices = get(&node.input[1])?;
|
||||||
|
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||||
|
let axis = xs.normalize_axis(axis)?;
|
||||||
|
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
|
||||||
|
// differentiable way.
|
||||||
|
let xs = if indices.rank() == 0 {
|
||||||
|
let index = indices.to_vec0::<i64>()? as usize;
|
||||||
|
xs.narrow(axis, index, 1)?.squeeze(axis)?
|
||||||
|
} else {
|
||||||
|
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
|
||||||
|
};
|
||||||
|
values.insert(node.output[0].clone(), xs);
|
||||||
|
}
|
||||||
|
"Shape" => {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
|
||||||
|
let xs = get(&node.input[0])?;
|
||||||
|
let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
|
||||||
|
let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
|
||||||
|
let start = xs.normalize_axis(start)?;
|
||||||
|
let end = xs.normalize_axis(end)?;
|
||||||
|
let mut dims = vec![];
|
||||||
|
for idx in start..=end {
|
||||||
|
dims.push(xs.dim(idx)? as i64)
|
||||||
|
}
|
||||||
|
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
|
||||||
|
values.insert(node.output[0].clone(), dims);
|
||||||
|
}
|
||||||
"Conv" => {
|
"Conv" => {
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||||
|
@ -670,6 +752,7 @@ pub fn simple_eval(
|
||||||
let input = get(&node.input[0])?;
|
let input = get(&node.input[0])?;
|
||||||
let dt: i64 = *get_attr(node, "to")?;
|
let dt: i64 = *get_attr(node, "to")?;
|
||||||
let dtype = match DataType::try_from(dt as i32) {
|
let dtype = match DataType::try_from(dt as i32) {
|
||||||
|
Ok(DataType::Int32) => DType::I64,
|
||||||
Ok(dt) => match dtype(dt) {
|
Ok(dt) => match dtype(dt) {
|
||||||
Some(dt) => dt,
|
Some(dt) => dt,
|
||||||
None => {
|
None => {
|
||||||
|
|
Loading…
Reference in New Issue