Add subtract tensor from scalar for ONNX sub op (#1964)

This commit is contained in:
johnhuichen 2024-07-05 18:52:02 +00:00 committed by GitHub
parent 1ad2a63f28
commit fe0544b9ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 36 additions and 10 deletions

3
.gitignore vendored
View File

@ -9,3 +9,6 @@ target
.vs .vs
.fleet .fleet
.ipynb_checkpoints/ .ipynb_checkpoints/
# Generated IR and Burn Graph from ONNX
out

View File

@ -62,7 +62,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
the Burn model in Rust code, and `my-model.json` includes the model data. the Burn model in Rust code, and `my-model.json` includes the model data.
7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`. 7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`.
Further details can be found in the [onnx-tests README](./burn-import/onnx-tests/README.md). Further details can be found in the [onnx-tests README](./onnx-tests/README.md).
## Testing ## Testing

View File

@ -147,7 +147,7 @@ mod tests {
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]], &device); let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]], &device);
let scalar = 3.0f64; let scalar = 3.0f64;
let output = model.forward(input, scalar); let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6f32, 7., 8., 9.]]]]); let expected = TensorData::from([[[[-12f32, -13., -14., -15.]]]]);
output.to_data().assert_eq(&expected, true); output.to_data().assert_eq(&expected, true);
} }
@ -162,7 +162,7 @@ mod tests {
let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device); let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device);
let scalar = 3; let scalar = 3;
let output = model.forward(input, scalar); let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6i64, 6, 6, 6]]]]); let expected = TensorData::from([[[[-12i64, -12, -12, -12]]]]);
output.to_data().assert_eq(&expected, true); output.to_data().assert_eq(&expected, true);
} }

View File

@ -26,6 +26,9 @@ class Model(nn.Module):
# Sutract a scalar from a tensor # Sutract a scalar from a tensor
x = x - d x = x - d
# Sutract a tensor from a scalar
x = d - x
return x return x
@ -40,8 +43,9 @@ def main():
scalar = 3.0 scalar = 3.0
torch.onnx.export(model, (dummy_input, scalar), onnx_name, torch.onnx.export(
verbose=False, opset_version=16) model, (dummy_input, scalar), onnx_name, verbose=False, opset_version=16
)
print("Finished exporting model to {}".format(onnx_name)) print("Finished exporting model to {}".format(onnx_name))
@ -53,5 +57,5 @@ def main():
print("Test output data: {}".format(output)) print("Test output data: {}".format(output))
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -27,6 +27,9 @@ class Model(nn.Module):
# Sutract a scalar from a tensor # Sutract a scalar from a tensor
x = x - d x = x - d
# Sutract a tensor from a scalar
x = d - x
return x return x
@ -41,8 +44,9 @@ def main():
test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device) test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device)
scalar = 3 scalar = 3
torch.onnx.export(model, (test_input, scalar), onnx_name, torch.onnx.export(
verbose=False, opset_version=16) model, (test_input, scalar), onnx_name, verbose=False, opset_version=16
)
print("Finished exporting model to {}".format(onnx_name)) print("Finished exporting model to {}".format(onnx_name))
@ -51,5 +55,5 @@ def main():
print("Test output data: {}".format(output)) print("Test output data: {}".format(output))
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -131,6 +131,7 @@ impl BinaryNode {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) },
_ => panic!("Subtraction is supported for tensor and scalar only"), _ => panic!("Subtraction is supported for tensor and scalar only"),
}; };

View File

@ -69,7 +69,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::Slice => slice_update_outputs(node), NodeType::Slice => slice_update_outputs(node),
NodeType::Softmax => same_as_input(node), NodeType::Softmax => same_as_input(node),
NodeType::Sqrt => same_as_input(node), NodeType::Sqrt => same_as_input(node),
NodeType::Sub => same_as_input(node), NodeType::Sub => sub_update_outputs(node),
NodeType::Sum => same_as_input(node), NodeType::Sum => same_as_input(node),
NodeType::Tanh => same_as_input(node), NodeType::Tanh => same_as_input(node),
NodeType::Transpose => same_as_input(node), NodeType::Transpose => same_as_input(node),
@ -481,6 +481,20 @@ fn slice_update_outputs(node: &mut Node) {
} }
} }
fn sub_update_outputs(node: &mut Node) {
node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
(ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs),
(ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
(ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
// Support broadcasting for lhs/rhs
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs),
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs),
_ => {
panic!("Only tensor-scalar inputs are valid.");
}
};
}
/// Update the output tensor dimension based on the "axes" attribute or the second input /// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) { fn unsqueeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 { let axes = if node.inputs.len() == 2 {