Feature/codegen gather indices greater than rank 1 (#2199)

* implemented muli-dim index for GatherNode

The `NodeCodegen` impl for `GatherNode` now performs gather in complete
accordance with the ONNX Gather spec.
- a `gather` function was added to the gather.rs file
- `gather()` is now called within the codegen instead of `tensor.select()`
- a test with two test cases have been added
    - test axes 0 and 1
    - both use 2D index tensors

* add gather_onnx to numeric api

Added int and float implementations of gather to the burn-tensor numeric
api:
- named the methods `gather_onnx` to not be confused with the current
  `gather`
- these implementations follow the `Gather` ONNX spec

Updated the gather*.py variants and their onnx outputs

* modified files didn't end up in last commit

* tests passing for onnx gather

The implementation of gather for the ONNX `Gather` spec is tentatively
complete:
- py test models are updated
- onnx_tests are modified and passing: `gather`, `gather_scalar`, and
  `gather_shape`
- node/gather tests are passing

NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test
the actual functionality of gather are likely to be deleted, since they
are redundant to the tests in
crates/burn-import/onnx-tests/tests/onnx_tests.rs.

* inlined onnx gather within codegen

* rm gather_onnx from public api; rm unnecessary tests

* add comments to gather py models

* some codegen changes; formatting to appease run-checks

- Some necessary changes and improvements to the codegen inlined code
  after translating from public api (removed in previous commit).
- Changed some formatting that run-checks complained about.

* simplify gather codegen; include 1d and 2d onnx tests

Modified the `Gather` codegen per requested changes:
- combined match statements on index
- remove use of `alloc::vec::Vec`
- use map -> collect instead of procedural
- include a 1d index gather onnx test
- remove superflous tests

* delete unused gather.onnx
This commit is contained in:
AlteredOxide 2024-08-28 04:51:19 -07:00 committed by GitHub
parent 795201dcfc
commit 0292967000
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 318 additions and 119 deletions

View File

@ -32,7 +32,8 @@ fn main() {
.input("tests/exp/exp.onnx") .input("tests/exp/exp.onnx")
.input("tests/expand/expand.onnx") .input("tests/expand/expand.onnx")
.input("tests/flatten/flatten.onnx") .input("tests/flatten/flatten.onnx")
.input("tests/gather/gather.onnx") .input("tests/gather/gather_1d_idx.onnx")
.input("tests/gather/gather_2d_idx.onnx")
.input("tests/gather/gather_scalar.onnx") .input("tests/gather/gather_scalar.onnx")
.input("tests/gather/gather_shape.onnx") .input("tests/gather/gather_shape.onnx")
.input("tests/gather_elements/gather_elements.onnx") .input("tests/gather_elements/gather_elements.onnx")

View File

@ -1,18 +0,0 @@
pytorch2.1.1:¤
A
onnx::Gather_0
onnx::Gather_12/Gather"Gather*
axis 
main_graphZ
onnx::Gather_0


Z
onnx::Gather_1

b
2


B

View File

@ -1,47 +0,0 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/gather/gather.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x, index):
gathered = torch.index_select(x, 1, index)
return gathered
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "gather.onnx"
dummy_input = torch.randn(2, 3, device=device)
dummy_index = torch.tensor([0, 2], device=device, dtype=torch.int64)
torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
test_input = torch.tensor([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
test_index = torch.tensor([0, 2], dtype=torch.int64)
print("Test input data: {}, {}".format(test_input, test_index))
output = model.forward(test_input, test_index)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()

View File

@ -0,0 +1,62 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/gather/gather.onnx
# There is no current support for `Split`, and the `for` loop over the indices
# results in a `Split` node in the ONNX model.
# Therefore, this model is built and exported using ONNX directly.
import onnx
def build_model():
return onnx.helper.make_model(
ir_version=8,
opset_imports=[onnx.helper.make_operatorsetid("", 16)],
graph=onnx.helper.make_graph(name="main_graph", nodes=[
onnx.helper.make_node(
"Gather",
inputs=["input1", "input2"],
outputs=["output1"],
name="/Gather",
axis=1
),
],
inputs=[
onnx.helper.make_value_info(
name="input1",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
),
),
onnx.helper.make_value_info(
name="input2",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.INT64, shape=[2]
),
),
],
outputs=[
onnx.helper.make_value_info(
name="output1",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.FLOAT, shape=[2, 2]
),
)
]),
)
def main():
onnx_model = build_model()
file_name = "gather_1d_idx.onnx"
# Ensure valid ONNX:
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, file_name)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,62 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/gather/gather.onnx
# There is no current support for `Split`, and the `for` loop over the indices
# results in a `Split` node in the ONNX model.
# Therefore, this model is built and exported using ONNX directly.
import onnx
def build_model():
return onnx.helper.make_model(
ir_version=8,
opset_imports=[onnx.helper.make_operatorsetid("", 16)],
graph=onnx.helper.make_graph(name="main_graph", nodes=[
onnx.helper.make_node(
"Gather",
inputs=["input1", "input2"],
outputs=["output1"],
name="/Gather",
axis=0
),
],
inputs=[
onnx.helper.make_value_info(
name="input1",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
),
),
onnx.helper.make_value_info(
name="input2",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.INT64, shape=[2, 2]
),
),
],
outputs=[
onnx.helper.make_value_info(
name="output1",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.FLOAT, shape=[2, 2, 2]
),
)
]),
)
def main():
onnx_model = build_model()
file_name = "gather_2d_idx.onnx"
# Ensure valid ONNX:
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, file_name)
if __name__ == '__main__':
main()

View File

@ -2,45 +2,60 @@
# used to generate model: onnx-tests/tests/gather/gather_scalar.onnx # used to generate model: onnx-tests/tests/gather/gather_scalar.onnx
import torch # There is no current support for `Split`, and the `for` loop over the indices
import torch.nn as nn # results in a `Split` node in the ONNX model.
# Therefore, this model is built and exported using ONNX directly.
import onnx
class Model(nn.Module): def build_model():
def __init__(self): return onnx.helper.make_model(
super(Model, self).__init__() ir_version=8,
opset_imports=[onnx.helper.make_operatorsetid("", 16)],
graph=onnx.helper.make_graph(name="main_graph", nodes=[
onnx.helper.make_node(
"Gather",
inputs=["input1", "input2"],
outputs=["output1"],
name="/Gather",
axis=0
),
],
inputs=[
onnx.helper.make_value_info(
name="input1",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
),
),
onnx.helper.make_value_info(
name="input2",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.INT64, shape=[]
),
),
def forward(self, x, index): ],
gathered = torch.select(x, 0, index) outputs=[
return gathered onnx.helper.make_value_info(
name="output1",
type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.FLOAT, shape=[3]
),
)
]),
)
def main(): def main():
# Set random seed for reproducibility onnx_model = build_model()
torch.manual_seed(0) file_name = "gather_scalar.onnx"
# Export to onnx # Ensure valid ONNX:
model = Model() onnx.checker.check_model(onnx_model)
model.eval()
device = torch.device("cpu")
onnx_name = "gather_scalar.onnx"
dummy_input = torch.randn(2, 3, device=device) onnx.save(onnx_model, file_name)
dummy_index = 0
torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
test_input = torch.tensor([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
test_index = 0
print("Test input data: {}, {}".format(test_input, test_index))
output = model.forward(test_input, test_index)
print("Test output data: {}".format(output))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -33,7 +33,7 @@ def build_model():
onnx.helper.make_value_info( onnx.helper.make_value_info(
name="input1", name="input1",
type_proto=onnx.helper.make_tensor_type_proto( type_proto=onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.FLOAT, shape=[2,3] elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
), ),
), ),
onnx.helper.make_value_info( onnx.helper.make_value_info(
@ -66,4 +66,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -41,7 +41,8 @@ include_models!(
exp, exp,
expand, expand,
flatten, flatten,
gather, gather_1d_idx,
gather_2d_idx,
gather_scalar, gather_scalar,
gather_shape, gather_shape,
gather_elements, gather_elements,
@ -451,15 +452,29 @@ mod tests {
} }
#[test] #[test]
fn gather() { fn gather_1d_idx() {
let model: gather::Model<Backend> = gather::Model::default(); let model: gather_1d_idx::Model<Backend> = gather_1d_idx::Model::default();
let device = Default::default(); let device = Default::default();
let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device); let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
let index = Tensor::<Backend, 1, Int>::from_ints([0, 2], &device); let index = Tensor::<Backend, 1, Int>::from_ints([0, 2], &device);
let output = model.forward(input, index);
let expected = TensorData::from([[1f32, 3.], [4., 6.]]); let expected = TensorData::from([[1f32, 3.], [4., 6.]]);
let output = model.forward(input, index);
assert_eq!(output.to_data(), expected);
}
#[test]
fn gather_2d_idx() {
let model: gather_2d_idx::Model<Backend> = gather_2d_idx::Model::default();
let device = Default::default();
let input = Tensor::<Backend, 2>::from_data([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], &device);
let index = Tensor::<Backend, 2, Int>::from_data([[0, 1], [1, 2]], &device);
let expected = TensorData::from([[[1f32, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]]);
let output = model.forward(input, index);
assert_eq!(output.to_data(), expected); assert_eq!(output.to_data(), expected);
} }

View File

@ -27,6 +27,12 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
node_position: usize, node_position: usize,
) -> proc_macro2::TokenStream { ) -> proc_macro2::TokenStream {
let dim = self.dim.to_tokens(); let dim = self.dim.to_tokens();
let input_rank = match &self.input {
Type::Tensor(in_tensor) => in_tensor.dim,
Type::Shape(_) => 1,
_ => panic!("Gather needs Tensor or Shape input, got {:?}!", self.input),
};
let input = match &self.input { let input = match &self.input {
Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position), Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position),
Type::Shape(in_shape) => { Type::Shape(in_shape) => {
@ -34,10 +40,11 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
// To copy just the values from the shape value without moving it // To copy just the values from the shape value without moving it
// (which could lead to ownership problems if the same Shape is used multiple times) // (which could lead to ownership problems if the same Shape is used multiple times)
// borrow the array as a slice and use that to create the Tensor: // borrow the array as a slice and use that to create the Tensor:
quote! { Tensor::from_data(&#in_shape_name as &[_], &*self.device) } quote! { Tensor::<B, 1, Int>::from_data(&#in_shape_name as &[_], &*self.device) }
} }
_ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input), _ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input),
}; };
let output = &self.output.name; let output = &self.output.name;
match &self.index { match &self.index {
@ -46,14 +53,41 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
// convert the 0-D index to a 1-D Tensor with len 1 to use burn's select, // convert the 0-D index to a 1-D Tensor with len 1 to use burn's select,
// then squeeze the dimension to reduce the rank // then squeeze the dimension to reduce the rank
let index = &idx_scalar.name; let index = &idx_scalar.name;
let output_rank = input_rank - 1;
quote! { quote! {
let #output = #input.select(#dim, Tensor::from_data([#index], &*self.device)).squeeze(#dim); let indices = Tensor::<B, 1, _>::from_data([#index], &*self.device);
let slice = Tensor::select(#input, #dim, indices);
let #output = slice.squeeze::<#output_rank>(#dim);
} }
} }
Type::Tensor(idx_tensor) => { Type::Tensor(idx_tensor) => {
let index = scope.tensor_use_owned(idx_tensor, node_position); let index = scope.tensor_use_owned(idx_tensor, node_position);
quote! { let index_rank = idx_tensor.dim;
let #output = #input.select(#dim, #index); let output_rank = index_rank + input_rank - 1;
match index_rank {
1 => quote! {
let indices = #index;
let #output = Tensor::select(#input, #dim, indices);
},
_ => quote! {
let indices = #index;
let n_dims = indices.dims().len();
let index_flat = match n_dims {
1 => indices.reshape([1, -1]),
n if n >= 2 => indices.flatten::<2>(0, n - 2),
_ => panic!("Number of dimensions must be greater than 0"),
};
let out = index_flat
.iter_dim(0)
.map(|idxs| {
let idxs = idxs.squeeze::<1>(0);
Tensor::select(#input.clone(), #dim, idxs)
})
.collect();
let #output = Tensor::stack::<#output_rank>(out, #dim);
},
} }
} }
_ => panic!("Gather needs Scalar or Tensor index, got {:?}!", self.index), _ => panic!("Gather needs Scalar or Tensor index, got {:?}!", self.index),
@ -78,7 +112,7 @@ mod tests {
}; };
#[test] #[test]
fn test_codegen_gather() { fn test_codegen_gather_1d_idx() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default(); let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(GatherNode::new( graph.register(GatherNode::new(
@ -121,8 +155,77 @@ mod tests {
tensor1: Tensor<B, 2>, tensor1: Tensor<B, 2>,
tensor2: Tensor<B, 1, Int> tensor2: Tensor<B, 1, Int>
) -> Tensor<B, 2> { ) -> Tensor<B, 2> {
let tensor3 = tensor1.select(0, tensor2); let indices = tensor2;
let tensor3 = Tensor::select(tensor1, 0, indices);
tensor3
}
}
};
assert_tokens(graph.codegen(), expected);
}
#[test]
fn test_codegen_gather_2d_idx() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(GatherNode::new(
Type::Tensor(TensorType::new_float("tensor1", 2)),
Type::Tensor(TensorType::new_int("tensor2", 2)),
TensorType::new_float("tensor3", 3),
0,
));
graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);
let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2>,
tensor2: Tensor<B, 2, Int>
) -> Tensor<B, 3> {
let indices = tensor2;
let n_dims = indices.dims().len();
let index_flat = match n_dims {
1 => indices.reshape([1, -1]),
n if n >= 2 => indices.flatten::<2>(0, n - 2),
_ => panic!("Number of dimensions must be greater than 0"),
};
let out = index_flat
.iter_dim(0)
.map(|idxs| {
let idxs = idxs.squeeze::<1>(0);
Tensor::select(tensor1.clone(), 0, idxs)
})
.collect();
let tensor3 = Tensor::stack::<3usize>(out, 0);
tensor3 tensor3
} }
} }
@ -138,7 +241,7 @@ mod tests {
graph.register(GatherNode::new( graph.register(GatherNode::new(
Type::Shape(ShapeType::new("shape1", 3)), Type::Shape(ShapeType::new("shape1", 3)),
Type::Tensor(TensorType::new_int("tensor1", 1)), Type::Tensor(TensorType::new_int("tensor1", 1)),
TensorType::new_float("tensor2", 2), TensorType::new_int("tensor2", 1),
0, 0,
)); ));
@ -174,8 +277,14 @@ mod tests {
&self, &self,
shape1: [usize; 3], shape1: [usize; 3],
tensor1: Tensor<B, 1, Int> tensor1: Tensor<B, 1, Int>
) -> Tensor<B, 2> { ) -> Tensor<B, 1, Int> {
let tensor2 = Tensor::from_data(&shape1 as &[_], &*self.device).select(0, tensor1); let indices = tensor1;
let tensor2 = Tensor::select(
Tensor::<B, 1, Int>::from_data(&shape1 as &[_], &*self.device),
0,
indices,
);
tensor2 tensor2
} }
@ -192,7 +301,7 @@ mod tests {
graph.register(GatherNode::new( graph.register(GatherNode::new(
Type::Tensor(TensorType::new_float("tensor1", 2)), Type::Tensor(TensorType::new_float("tensor1", 2)),
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)), Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)),
TensorType::new_float("tensor2", 2), TensorType::new_float("tensor2", 1),
0, 0,
)); ));
@ -227,8 +336,11 @@ mod tests {
&self, &self,
tensor1: Tensor<B, 2>, tensor1: Tensor<B, 2>,
scalar1: i64 scalar1: i64
) -> Tensor<B, 2> { ) -> Tensor<B, 1> {
let tensor2 = tensor1.select(0, Tensor::from_data([scalar1], &*self.device)).squeeze(0); let indices = Tensor::<B, 1, _>::from_data([scalar1], &*self.device);
let slice = Tensor::select(tensor1, 0, indices);
let tensor2 = slice.squeeze::<1usize>(0);
tensor2 tensor2
} }

View File

@ -892,6 +892,7 @@ impl TensorCheck {
check check
} }
pub(crate) fn check_prelu_shape<const D: usize>( pub(crate) fn check_prelu_shape<const D: usize>(
shape_tensor: &Shape<D>, shape_tensor: &Shape<D>,
shape_weight: &Shape<1>, shape_weight: &Shape<1>,

View File

@ -816,10 +816,6 @@ fn gather_update_outputs(node: &mut Node) {
_ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty), _ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty),
}; };
if indices_dim > 1 {
panic!("Gather: indices tensor rank above 1 not supported")
}
match &node.inputs[0].ty { match &node.inputs[0].ty {
ArgType::Tensor(input_tensor) => { ArgType::Tensor(input_tensor) => {
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input