diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 2defcbe30..23312bec9 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -32,7 +32,8 @@ fn main() { .input("tests/exp/exp.onnx") .input("tests/expand/expand.onnx") .input("tests/flatten/flatten.onnx") - .input("tests/gather/gather.onnx") + .input("tests/gather/gather_1d_idx.onnx") + .input("tests/gather/gather_2d_idx.onnx") .input("tests/gather/gather_scalar.onnx") .input("tests/gather/gather_shape.onnx") .input("tests/gather_elements/gather_elements.onnx") diff --git a/crates/burn-import/onnx-tests/tests/gather/gather.onnx b/crates/burn-import/onnx-tests/tests/gather/gather.onnx deleted file mode 100644 index 9589d8410..000000000 --- a/crates/burn-import/onnx-tests/tests/gather/gather.onnx +++ /dev/null @@ -1,18 +0,0 @@ -pytorch2.1.1:¤ -A -onnx::Gather_0 -onnx::Gather_12/Gather"Gather* -axis  -main_graphZ -onnx::Gather_0 -  - -Z -onnx::Gather_1 - - -b -2 -  - -B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/gather/gather.py b/crates/burn-import/onnx-tests/tests/gather/gather.py deleted file mode 100644 index 39688d34d..000000000 --- a/crates/burn-import/onnx-tests/tests/gather/gather.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 - -# used to generate model: onnx-tests/tests/gather/gather.onnx - -import torch -import torch.nn as nn - - -class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - - def forward(self, x, index): - gathered = torch.index_select(x, 1, index) - return gathered - - -def main(): - # Set random seed for reproducibility - torch.manual_seed(0) - - # Export to onnx - model = Model() - model.eval() - device = torch.device("cpu") - onnx_name = "gather.onnx" - - dummy_input = torch.randn(2, 3, device=device) - dummy_index = torch.tensor([0, 2], device=device, dtype=torch.int64) - - torch.onnx.export(model, (dummy_input, dummy_index), 
onnx_name, - verbose=False, opset_version=16) - - print("Finished exporting model to {}".format(onnx_name)) - - # Output some test data for use in the test - test_input = torch.tensor([[1.0, 2.0, 3.0], - [4.0, 5.0, 6.0]]) - test_index = torch.tensor([0, 2], dtype=torch.int64) - - print("Test input data: {}, {}".format(test_input, test_index)) - output = model.forward(test_input, test_index) - print("Test output data: {}".format(output)) - - -if __name__ == '__main__': - main() diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.onnx b/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.onnx new file mode 100644 index 000000000..97b0ddefe Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.py b/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.py new file mode 100644 index 000000000..b4e4a3bd1 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/gather/gather_1d_idx.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/gather/gather_1d_idx.onnx + +# There is no current support for `Split`, and the `for` loop over the indices +# results in a `Split` node in the ONNX model. +# Therefore, this model is built and exported using ONNX directly. 
+ +import onnx + + +def build_model(): + return onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 16)], + graph=onnx.helper.make_graph(name="main_graph", nodes=[ + onnx.helper.make_node( + "Gather", + inputs=["input1", "input2"], + outputs=["output1"], + name="/Gather", + axis=1 + ), + ], + inputs=[ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 3] + ), + ), + onnx.helper.make_value_info( + name="input2", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT64, shape=[2] + ), + ), + + ], + outputs=[ + onnx.helper.make_value_info( + name="output1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 2] + ), + ) + ]), + ) + + +def main(): + onnx_model = build_model() + file_name = "gather_1d_idx.onnx" + + # Ensure valid ONNX: + onnx.checker.check_model(onnx_model) + + onnx.save(onnx_model, file_name) + + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.onnx b/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.onnx new file mode 100644 index 000000000..ff64b029d Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.py b/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.py new file mode 100644 index 000000000..767168eef --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/gather/gather_2d_idx.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/gather/gather_2d_idx.onnx + +# There is no current support for `Split`, and the `for` loop over the indices +# results in a `Split` node in the ONNX model. +# Therefore, this model is built and exported using ONNX directly. 
+ +import onnx + + +def build_model(): + return onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 16)], + graph=onnx.helper.make_graph(name="main_graph", nodes=[ + onnx.helper.make_node( + "Gather", + inputs=["input1", "input2"], + outputs=["output1"], + name="/Gather", + axis=0 + ), + ], + inputs=[ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 3] + ), + ), + onnx.helper.make_value_info( + name="input2", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT64, shape=[2, 2] + ), + ), + + ], + outputs=[ + onnx.helper.make_value_info( + name="output1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 2, 2] + ), + ) + ]), + ) + + +def main(): + onnx_model = build_model() + file_name = "gather_2d_idx.onnx" + + # Ensure valid ONNX: + onnx.checker.check_model(onnx_model) + + onnx.save(onnx_model, file_name) + + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_scalar.onnx b/crates/burn-import/onnx-tests/tests/gather/gather_scalar.onnx index 7afca04a8..a8a5586c0 100644 Binary files a/crates/burn-import/onnx-tests/tests/gather/gather_scalar.onnx and b/crates/burn-import/onnx-tests/tests/gather/gather_scalar.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/gather/gather_scalar.py b/crates/burn-import/onnx-tests/tests/gather/gather_scalar.py index b6dabef62..70373762e 100644 --- a/crates/burn-import/onnx-tests/tests/gather/gather_scalar.py +++ b/crates/burn-import/onnx-tests/tests/gather/gather_scalar.py @@ -2,45 +2,60 @@ # used to generate model: onnx-tests/tests/gather/gather_scalar.onnx -import torch -import torch.nn as nn +# There is no current support for `Split`, and the `for` loop over the indices +# results in a `Split` node in the ONNX model. 
+# Therefore, this model is built and exported using ONNX directly. + +import onnx -class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() +def build_model(): + return onnx.helper.make_model( + ir_version=8, + opset_imports=[onnx.helper.make_operatorsetid("", 16)], + graph=onnx.helper.make_graph(name="main_graph", nodes=[ + onnx.helper.make_node( + "Gather", + inputs=["input1", "input2"], + outputs=["output1"], + name="/Gather", + axis=0 + ), + ], + inputs=[ + onnx.helper.make_value_info( + name="input1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[2, 3] + ), + ), + onnx.helper.make_value_info( + name="input2", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.INT64, shape=[] + ), + ), - def forward(self, x, index): - gathered = torch.select(x, 0, index) - return gathered + ], + outputs=[ + onnx.helper.make_value_info( + name="output1", + type_proto=onnx.helper.make_tensor_type_proto( + elem_type=onnx.TensorProto.FLOAT, shape=[3] + ), + ) + ]), + ) def main(): - # Set random seed for reproducibility - torch.manual_seed(0) + onnx_model = build_model() + file_name = "gather_scalar.onnx" - # Export to onnx - model = Model() - model.eval() - device = torch.device("cpu") - onnx_name = "gather_scalar.onnx" + # Ensure valid ONNX: + onnx.checker.check_model(onnx_model) - dummy_input = torch.randn(2, 3, device=device) - dummy_index = 0 - - torch.onnx.export(model, (dummy_input, dummy_index), onnx_name, - verbose=False, opset_version=16) - - print("Finished exporting model to {}".format(onnx_name)) - - # Output some test data for use in the test - test_input = torch.tensor([[1.0, 2.0, 3.0], - [4.0, 5.0, 6.0]]) - test_index = 0 - - print("Test input data: {}, {}".format(test_input, test_index)) - output = model.forward(test_input, test_index) - print("Test output data: {}".format(output)) + onnx.save(onnx_model, file_name) if __name__ == '__main__': diff --git 
a/crates/burn-import/onnx-tests/tests/gather/gather_shape.py b/crates/burn-import/onnx-tests/tests/gather/gather_shape.py index 437bd34a8..8fe0e4541 100644 --- a/crates/burn-import/onnx-tests/tests/gather/gather_shape.py +++ b/crates/burn-import/onnx-tests/tests/gather/gather_shape.py @@ -33,7 +33,7 @@ def build_model(): onnx.helper.make_value_info( name="input1", type_proto=onnx.helper.make_tensor_type_proto( - elem_type=onnx.TensorProto.FLOAT, shape=[2,3] + elem_type=onnx.TensorProto.FLOAT, shape=[2, 3] ), ), onnx.helper.make_value_info( @@ -66,4 +66,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index e9959e819..9f8dcb448 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -41,7 +41,8 @@ include_models!( exp, expand, flatten, - gather, + gather_1d_idx, + gather_2d_idx, gather_scalar, gather_shape, gather_elements, @@ -451,15 +452,29 @@ mod tests { } #[test] - fn gather() { - let model: gather::Model = gather::Model::default(); + fn gather_1d_idx() { + let model: gather_1d_idx::Model = gather_1d_idx::Model::default(); let device = Default::default(); let input = Tensor::::from_floats([[1., 2., 3.], [4., 5., 6.]], &device); let index = Tensor::::from_ints([0, 2], &device); - let output = model.forward(input, index); let expected = TensorData::from([[1f32, 3.], [4., 6.]]); + let output = model.forward(input, index); + + assert_eq!(output.to_data(), expected); + } + + #[test] + fn gather_2d_idx() { + let model: gather_2d_idx::Model = gather_2d_idx::Model::default(); + + let device = Default::default(); + + let input = Tensor::::from_data([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], &device); + let index = Tensor::::from_data([[0, 1], [1, 2]], &device); + let expected = TensorData::from([[[1f32, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]]); + let 
output = model.forward(input, index); assert_eq!(output.to_data(), expected); } diff --git a/crates/burn-import/src/burn/node/gather.rs b/crates/burn-import/src/burn/node/gather.rs index b79757518..a9b66e1d0 100644 --- a/crates/burn-import/src/burn/node/gather.rs +++ b/crates/burn-import/src/burn/node/gather.rs @@ -27,6 +27,12 @@ impl NodeCodegen for GatherNode { node_position: usize, ) -> proc_macro2::TokenStream { let dim = self.dim.to_tokens(); + let input_rank = match &self.input { + Type::Tensor(in_tensor) => in_tensor.dim, + Type::Shape(_) => 1, + _ => panic!("Gather needs Tensor or Shape input, got {:?}!", self.input), + }; + let input = match &self.input { Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position), Type::Shape(in_shape) => { @@ -34,10 +40,11 @@ impl NodeCodegen for GatherNode { // To copy just the values from the shape value without moving it // (which could lead to ownership problems if the same Shape is used multiple times) // borrow the array as a slice and use that to create the Tensor: - quote! { Tensor::from_data(&#in_shape_name as &[_], &*self.device) } + quote! { Tensor::::from_data(&#in_shape_name as &[_], &*self.device) } } _ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input), }; + let output = &self.output.name; match &self.index { @@ -46,14 +53,41 @@ impl NodeCodegen for GatherNode { // convert the 0-D index to a 1-D Tensor with len 1 to use burn's select, // then squeeze the dimension to reduce the rank let index = &idx_scalar.name; + let output_rank = input_rank - 1; quote! { - let #output = #input.select(#dim, Tensor::from_data([#index], &*self.device)).squeeze(#dim); + let indices = Tensor::::from_data([#index], &*self.device); + let slice = Tensor::select(#input, #dim, indices); + let #output = slice.squeeze::<#output_rank>(#dim); } } Type::Tensor(idx_tensor) => { let index = scope.tensor_use_owned(idx_tensor, node_position); - quote! 
{ - let #output = #input.select(#dim, #index); + let index_rank = idx_tensor.dim; + let output_rank = index_rank + input_rank - 1; + match index_rank { + 1 => quote! { + let indices = #index; + let #output = Tensor::select(#input, #dim, indices); + }, + _ => quote! { + let indices = #index; + + let n_dims = indices.dims().len(); + let index_flat = match n_dims { + 1 => indices.reshape([1, -1]), + n if n >= 2 => indices.flatten::<2>(0, n - 2), + _ => panic!("Number of dimensions must be greater than 0"), + }; + + let out = index_flat + .iter_dim(0) + .map(|idxs| { + let idxs = idxs.squeeze::<1>(0); + Tensor::select(#input.clone(), #dim, idxs) + }) + .collect(); + let #output = Tensor::stack::<#output_rank>(out, #dim); + }, } } _ => panic!("Gather needs Scalar or Tensor index, got {:?}!", self.index), @@ -78,7 +112,7 @@ mod tests { }; #[test] - fn test_codegen_gather() { + fn test_codegen_gather_1d_idx() { let mut graph = BurnGraph::::default(); graph.register(GatherNode::new( @@ -121,8 +155,77 @@ mod tests { tensor1: Tensor, tensor2: Tensor ) -> Tensor { - let tensor3 = tensor1.select(0, tensor2); + let indices = tensor2; + let tensor3 = Tensor::select(tensor1, 0, indices); + tensor3 + } + } + }; + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_gather_2d_idx() { + let mut graph = BurnGraph::::default(); + + graph.register(GatherNode::new( + Type::Tensor(TensorType::new_float("tensor1", 2)), + Type::Tensor(TensorType::new_int("tensor2", 2)), + TensorType::new_float("tensor3", 3), + 0, + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! 
{ + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + tensor1: Tensor, + tensor2: Tensor + ) -> Tensor { + let indices = tensor2; + + let n_dims = indices.dims().len(); + let index_flat = match n_dims { + 1 => indices.reshape([1, -1]), + n if n >= 2 => indices.flatten::<2>(0, n - 2), + _ => panic!("Number of dimensions must be greater than 0"), + }; + + let out = index_flat + .iter_dim(0) + .map(|idxs| { + let idxs = idxs.squeeze::<1>(0); + Tensor::select(tensor1.clone(), 0, idxs) + }) + .collect(); + let tensor3 = Tensor::stack::<3usize>(out, 0); tensor3 } } @@ -138,7 +241,7 @@ mod tests { graph.register(GatherNode::new( Type::Shape(ShapeType::new("shape1", 3)), Type::Tensor(TensorType::new_int("tensor1", 1)), - TensorType::new_float("tensor2", 2), + TensorType::new_int("tensor2", 1), 0, )); @@ -174,8 +277,14 @@ mod tests { &self, shape1: [usize; 3], tensor1: Tensor - ) -> Tensor { - let tensor2 = Tensor::from_data(&shape1 as &[_], &*self.device).select(0, tensor1); + ) -> Tensor { + let indices = tensor1; + + let tensor2 = Tensor::select( + Tensor::::from_data(&shape1 as &[_], &*self.device), + 0, + indices, + ); tensor2 } @@ -192,7 +301,7 @@ mod tests { graph.register(GatherNode::new( Type::Tensor(TensorType::new_float("tensor1", 2)), Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)), - TensorType::new_float("tensor2", 2), + TensorType::new_float("tensor2", 1), 0, )); @@ -227,8 +336,11 @@ mod tests { &self, tensor1: Tensor, scalar1: i64 - ) -> Tensor { - let tensor2 = tensor1.select(0, 
Tensor::from_data([scalar1], &*self.device)).squeeze(0); + ) -> Tensor { + let indices = Tensor::::from_data([scalar1], &*self.device); + + let slice = Tensor::select(tensor1, 0, indices); + let tensor2 = slice.squeeze::<1usize>(0); tensor2 } diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index 871a51e4d..b62720c10 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -892,6 +892,7 @@ impl TensorCheck { check } + pub(crate) fn check_prelu_shape( shape_tensor: &Shape, shape_weight: &Shape<1>, diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 646887128..97b251e0a 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -816,10 +816,6 @@ fn gather_update_outputs(node: &mut Node) { _ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty), }; - if indices_dim > 1 { - panic!("Gather: indices tensor rank above 1 not supported") - } - match &node.inputs[0].ty { ArgType::Tensor(input_tensor) => { // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input