Add subtract tensor from scalar for ONNX sub op (#1964)

This commit is contained in:
johnhuichen 2024-07-05 18:52:02 +00:00 committed by GitHub
parent 1ad2a63f28
commit fe0544b9ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 36 additions and 10 deletions

3
.gitignore vendored
View File

@ -9,3 +9,6 @@ target
.vs
.fleet
.ipynb_checkpoints/
# Generated IR and Burn Graph from ONNX
out

View File

@ -62,7 +62,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
the Burn model in Rust code, and `my-model.json` includes the model data.
7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`.
Further details can be found in the [onnx-tests README](./burn-import/onnx-tests/README.md).
Further details can be found in the [onnx-tests README](./onnx-tests/README.md).
## Testing

View File

@ -147,7 +147,7 @@ mod tests {
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]], &device);
let scalar = 3.0f64;
let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6f32, 7., 8., 9.]]]]);
let expected = TensorData::from([[[[-12f32, -13., -14., -15.]]]]);
output.to_data().assert_eq(&expected, true);
}
@ -162,7 +162,7 @@ mod tests {
let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device);
let scalar = 3;
let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6i64, 6, 6, 6]]]]);
let expected = TensorData::from([[[[-12i64, -12, -12, -12]]]]);
output.to_data().assert_eq(&expected, true);
}

View File

@ -26,6 +26,9 @@ class Model(nn.Module):
# Sutract a scalar from a tensor
x = x - d
# Sutract a tensor from a scalar
x = d - x
return x
@ -40,8 +43,9 @@ def main():
scalar = 3.0
torch.onnx.export(model, (dummy_input, scalar), onnx_name,
verbose=False, opset_version=16)
torch.onnx.export(
model, (dummy_input, scalar), onnx_name, verbose=False, opset_version=16
)
print("Finished exporting model to {}".format(onnx_name))
@ -53,5 +57,5 @@ def main():
print("Test output data: {}".format(output))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -27,6 +27,9 @@ class Model(nn.Module):
# Sutract a scalar from a tensor
x = x - d
# Sutract a tensor from a scalar
x = d - x
return x
@ -41,8 +44,9 @@ def main():
test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device)
scalar = 3
torch.onnx.export(model, (test_input, scalar), onnx_name,
verbose=False, opset_version=16)
torch.onnx.export(
model, (test_input, scalar), onnx_name, verbose=False, opset_version=16
)
print("Finished exporting model to {}".format(onnx_name))
@ -51,5 +55,5 @@ def main():
print("Test output data: {}".format(output))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -131,6 +131,7 @@ impl BinaryNode {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) },
_ => panic!("Subtraction is supported for tensor and scalar only"),
};

View File

@ -69,7 +69,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::Slice => slice_update_outputs(node),
NodeType::Softmax => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Sub => same_as_input(node),
NodeType::Sub => sub_update_outputs(node),
NodeType::Sum => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Transpose => same_as_input(node),
@ -481,6 +481,20 @@ fn slice_update_outputs(node: &mut Node) {
}
}
fn sub_update_outputs(node: &mut Node) {
node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
(ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs),
(ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
(ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
// Support broadcasting for lhs/rhs
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs),
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs),
_ => {
panic!("Only tensor-scalar inputs are valid.");
}
};
}
/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {