From fe0544b9eabd24b398cbf4ef148e3a007a45e246 Mon Sep 17 00:00:00 2001
From: johnhuichen
Date: Fri, 5 Jul 2024 18:52:02 +0000
Subject: [PATCH] Add subtract tensor from scalar for ONNX sub op (#1964)

---
 .gitignore                                      |   3 +++
 crates/burn-import/DEVELOPMENT.md               |   2 +-
 .../burn-import/onnx-tests/tests/onnx_tests.rs  |   4 ++--
 .../burn-import/onnx-tests/tests/sub/sub.onnx   | Bin 519 -> 586 bytes
 crates/burn-import/onnx-tests/tests/sub/sub.py  |  10 +++++++---
 .../onnx-tests/tests/sub/sub_int.onnx           | Bin 365 -> 432 bytes
 .../burn-import/onnx-tests/tests/sub/sub_int.py |  10 +++++++---
 crates/burn-import/src/burn/node/binary.rs      |   1 +
 crates/onnx-ir/src/dim_inference.rs             |  16 +++++++++++++++-
 9 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/.gitignore b/.gitignore
index 9d5f51b61..ffa113f25 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,6 @@ target
 .vs
 .fleet
 .ipynb_checkpoints/
+
+# Generated IR and Burn Graph from ONNX
+out
diff --git a/crates/burn-import/DEVELOPMENT.md b/crates/burn-import/DEVELOPMENT.md
index 687428fe1..2d122a042 100644
--- a/crates/burn-import/DEVELOPMENT.md
+++ b/crates/burn-import/DEVELOPMENT.md
@@ -62,7 +62,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
    the Burn model in Rust code, and `my-model.json` includes the model data.
 
 7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`.
-   Further details can be found in the [onnx-tests README](./burn-import/onnx-tests/README.md).
+   Further details can be found in the [onnx-tests README](./onnx-tests/README.md).
 
 ## Testing
 
diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
index ac780b6d0..d010d81f1 100644
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -147,7 +147,7 @@ mod tests {
         let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]], &device);
         let scalar = 3.0f64;
         let output = model.forward(input, scalar);
-        let expected = TensorData::from([[[[6f32, 7., 8., 9.]]]]);
+        let expected = TensorData::from([[[[-12f32, -13., -14., -15.]]]]);
 
         output.to_data().assert_eq(&expected, true);
     }
@@ -162,7 +162,7 @@ mod tests {
         let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device);
         let scalar = 3;
         let output = model.forward(input, scalar);
-        let expected = TensorData::from([[[[6i64, 6, 6, 6]]]]);
+        let expected = TensorData::from([[[[-12i64, -12, -12, -12]]]]);
 
         output.to_data().assert_eq(&expected, true);
     }
diff --git a/crates/burn-import/onnx-tests/tests/sub/sub.onnx b/crates/burn-import/onnx-tests/tests/sub/sub.onnx
index 7ffdfc8083cccc46a0c68316488cd8034846bc1a..60d76ea1c55c212fdc369e578204d7ab7cd2f960 100644
GIT binary patch
delta 140
zcmZo?ImIH%!6C$6P+5{+l$;^OYNThXXJECFWus^qBct8q97e@he*NImqG+GpV

diff --git a/crates/burn-import/onnx-tests/tests/sub/sub_int.py b/crates/burn-import/onnx-tests/tests/sub/sub_int.py
index 17ace09ca..487c66b19 100755
--- a/crates/burn-import/onnx-tests/tests/sub/sub_int.py
+++ b/crates/burn-import/onnx-tests/tests/sub/sub_int.py
@@ -27,6 +27,9 @@ class Model(nn.Module):
         # Sutract a scalar from a tensor
         x = x - d
 
+        # Subtract a tensor from a scalar
+        x = d - x
+
         return x
 
 
@@ -41,8 +44,9 @@ def main():
     test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device)
     scalar = 3
 
-    torch.onnx.export(model, (test_input, scalar), onnx_name,
-                      verbose=False, opset_version=16)
+    torch.onnx.export(
+        model, (test_input, scalar), onnx_name, verbose=False, opset_version=16
+    )
 
     print("Finished exporting model to {}".format(onnx_name))
 
model to {}".format(onnx_name)) @@ -51,5 +55,5 @@ def main(): print("Test output data: {}".format(output)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/crates/burn-import/src/burn/node/binary.rs b/crates/burn-import/src/burn/node/binary.rs index da6b7b930..37983f4c3 100644 --- a/crates/burn-import/src/burn/node/binary.rs +++ b/crates/burn-import/src/burn/node/binary.rs @@ -131,6 +131,7 @@ impl BinaryNode { (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, + (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) }, _ => panic!("Subtraction is supported for tensor and scalar only"), }; diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index e39ef73bd..2e4f67676 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -69,7 +69,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Slice => slice_update_outputs(node), NodeType::Softmax => same_as_input(node), NodeType::Sqrt => same_as_input(node), - NodeType::Sub => same_as_input(node), + NodeType::Sub => sub_update_outputs(node), NodeType::Sum => same_as_input(node), NodeType::Tanh => same_as_input(node), NodeType::Transpose => same_as_input(node), @@ -481,6 +481,20 @@ fn slice_update_outputs(node: &mut Node) { } } +fn sub_update_outputs(node: &mut Node) { + node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) { + (ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs), + (ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs), + (ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs), + // Support broadcasting for lhs/rhs + (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs), + (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs), + _ => { + panic!("Only tensor-scalar inputs are valid."); + } + }; +} + /// Update the output tensor dimension based on the "axes" attribute or the second input fn unsqueeze_update_output(node: &mut Node) { let axes = if node.inputs.len() == 2 {