mirror of https://github.com/tracel-ai/burn.git
Add subtract tensor from scalar for ONNX sub op (#1964)
This commit is contained in:
parent
1ad2a63f28
commit
fe0544b9ea
|
@ -9,3 +9,6 @@ target
|
|||
.vs
|
||||
.fleet
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# Generated IR and Burn Graph from ONNX
|
||||
out
|
||||
|
|
|
@ -62,7 +62,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
|
|||
the Burn model in Rust code, and `my-model.json` includes the model data.
|
||||
|
||||
7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`.
|
||||
Further details can be found in the [onnx-tests README](./burn-import/onnx-tests/README.md).
|
||||
Further details can be found in the [onnx-tests README](./onnx-tests/README.md).
|
||||
|
||||
## Testing
|
||||
|
||||
|
|
|
@ -147,7 +147,7 @@ mod tests {
|
|||
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]], &device);
|
||||
let scalar = 3.0f64;
|
||||
let output = model.forward(input, scalar);
|
||||
let expected = TensorData::from([[[[6f32, 7., 8., 9.]]]]);
|
||||
let expected = TensorData::from([[[[-12f32, -13., -14., -15.]]]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
@ -162,7 +162,7 @@ mod tests {
|
|||
let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device);
|
||||
let scalar = 3;
|
||||
let output = model.forward(input, scalar);
|
||||
let expected = TensorData::from([[[[6i64, 6, 6, 6]]]]);
|
||||
let expected = TensorData::from([[[[-12i64, -12, -12, -12]]]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
|
Binary file not shown.
|
@ -26,6 +26,9 @@ class Model(nn.Module):
|
|||
# Sutract a scalar from a tensor
|
||||
x = x - d
|
||||
|
||||
# Sutract a tensor from a scalar
|
||||
x = d - x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
@ -40,8 +43,9 @@ def main():
|
|||
|
||||
scalar = 3.0
|
||||
|
||||
torch.onnx.export(model, (dummy_input, scalar), onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
torch.onnx.export(
|
||||
model, (dummy_input, scalar), onnx_name, verbose=False, opset_version=16
|
||||
)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
|
@ -53,5 +57,5 @@ def main():
|
|||
print("Test output data: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Binary file not shown.
|
@ -27,6 +27,9 @@ class Model(nn.Module):
|
|||
# Sutract a scalar from a tensor
|
||||
x = x - d
|
||||
|
||||
# Sutract a tensor from a scalar
|
||||
x = d - x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
@ -41,8 +44,9 @@ def main():
|
|||
test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device)
|
||||
scalar = 3
|
||||
|
||||
torch.onnx.export(model, (test_input, scalar), onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
torch.onnx.export(
|
||||
model, (test_input, scalar), onnx_name, verbose=False, opset_version=16
|
||||
)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
|
@ -51,5 +55,5 @@ def main():
|
|||
print("Test output data: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -131,6 +131,7 @@ impl BinaryNode {
|
|||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
|
||||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
|
||||
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
|
||||
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) },
|
||||
_ => panic!("Subtraction is supported for tensor and scalar only"),
|
||||
};
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ pub fn dim_inference(node: &mut Node) {
|
|||
NodeType::Slice => slice_update_outputs(node),
|
||||
NodeType::Softmax => same_as_input(node),
|
||||
NodeType::Sqrt => same_as_input(node),
|
||||
NodeType::Sub => same_as_input(node),
|
||||
NodeType::Sub => sub_update_outputs(node),
|
||||
NodeType::Sum => same_as_input(node),
|
||||
NodeType::Tanh => same_as_input(node),
|
||||
NodeType::Transpose => same_as_input(node),
|
||||
|
@ -481,6 +481,20 @@ fn slice_update_outputs(node: &mut Node) {
|
|||
}
|
||||
}
|
||||
|
||||
fn sub_update_outputs(node: &mut Node) {
|
||||
node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
|
||||
(ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs),
|
||||
(ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
|
||||
(ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
|
||||
// Support broadcasting for lhs/rhs
|
||||
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs),
|
||||
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs),
|
||||
_ => {
|
||||
panic!("Only tensor-scalar inputs are valid.");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Update the output tensor dimension based on the "axes" attribute or the second input
|
||||
fn unsqueeze_update_output(node: &mut Node) {
|
||||
let axes = if node.inputs.len() == 2 {
|
||||
|
|
Loading…
Reference in New Issue