mirror of https://github.com/tracel-ai/burn.git
Feature/codegen gather indices greater than rank 1 (#2199)
* implemented muli-dim index for GatherNode The `NodeCodegen` impl for `GatherNode` now performs gather in complete accordance with the ONNX Gather spec. - a `gather` function was added to the gather.rs file - `gather()` is now called within the codegen instead of `tensor.select()` - a test with two test cases have been added - test axes 0 and 1 - both use 2D index tensors * add gather_onnx to numeric api Added int and float implementations of gather to the burn-tensor numeric api: - named the methods `gather_onnx` to not be confused with the current `gather` - these implementations follow the `Gather` ONNX spec Updated the gather*.py variants and their onnx outputs * modified files didn't end up in last commit * tests passing for onnx gather The implementation of gather for the ONNX `Gather` spec is tentatively complete: - py test models are updated - onnx_tests are modified and passing: `gather`, `gather_scalar`, and `gather_shape` - node/gather tests are passing NOTE: The two additional tests in crates/burn-import/src/burn/node/gather.rs that test the actual functionality of gather are likely to be deleted, since they are redundant to the tests in crates/burn-import/onnx-tests/tests/onnx_tests.rs. * inlined onnx gather within codegen * rm gather_onnx from public api; rm unnecessary tests * add comments to gather py models * some codegen changes; formatting to appease run-checks - Some necessary changes and improvements to the codegen inlined code after translating from public api (removed in previous commit). - Changed some formatting that run-checks complained about. * simplify gather codegen; include 1d and 2d onnx tests Modified the `Gather` codegen per requested changes: - combined match statements on index - remove use of `alloc::vec::Vec` - use map -> collect instead of procedural - include a 1d index gather onnx test - remove superflous tests * delete unused gather.onnx
This commit is contained in:
parent
795201dcfc
commit
0292967000
|
@ -32,7 +32,8 @@ fn main() {
|
||||||
.input("tests/exp/exp.onnx")
|
.input("tests/exp/exp.onnx")
|
||||||
.input("tests/expand/expand.onnx")
|
.input("tests/expand/expand.onnx")
|
||||||
.input("tests/flatten/flatten.onnx")
|
.input("tests/flatten/flatten.onnx")
|
||||||
.input("tests/gather/gather.onnx")
|
.input("tests/gather/gather_1d_idx.onnx")
|
||||||
|
.input("tests/gather/gather_2d_idx.onnx")
|
||||||
.input("tests/gather/gather_scalar.onnx")
|
.input("tests/gather/gather_scalar.onnx")
|
||||||
.input("tests/gather/gather_shape.onnx")
|
.input("tests/gather/gather_shape.onnx")
|
||||||
.input("tests/gather_elements/gather_elements.onnx")
|
.input("tests/gather_elements/gather_elements.onnx")
|
||||||
|
|
|
@ -1,18 +0,0 @@
|
||||||
pytorch2.1.1:¤
|
|
||||||
A
|
|
||||||
onnx::Gather_0
|
|
||||||
onnx::Gather_12/Gather"Gather*
|
|
||||||
axis
|
|
||||||
main_graphZ
|
|
||||||
onnx::Gather_0
|
|
||||||
|
|
||||||
|
|
||||||
Z
|
|
||||||
onnx::Gather_1
|
|
||||||
|
|
||||||
|
|
||||||
b
|
|
||||||
2
|
|
||||||
|
|
||||||
|
|
||||||
B
|
|
|
@ -1,47 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
# used to generate model: onnx-tests/tests/gather/gather.onnx
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(Model, self).__init__()
|
|
||||||
|
|
||||||
def forward(self, x, index):
|
|
||||||
gathered = torch.index_select(x, 1, index)
|
|
||||||
return gathered
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Set random seed for reproducibility
|
|
||||||
torch.manual_seed(0)
|
|
||||||
|
|
||||||
# Export to onnx
|
|
||||||
model = Model()
|
|
||||||
model.eval()
|
|
||||||
device = torch.device("cpu")
|
|
||||||
onnx_name = "gather.onnx"
|
|
||||||
|
|
||||||
dummy_input = torch.randn(2, 3, device=device)
|
|
||||||
dummy_index = torch.tensor([0, 2], device=device, dtype=torch.int64)
|
|
||||||
|
|
||||||
torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
|
|
||||||
verbose=False, opset_version=16)
|
|
||||||
|
|
||||||
print("Finished exporting model to {}".format(onnx_name))
|
|
||||||
|
|
||||||
# Output some test data for use in the test
|
|
||||||
test_input = torch.tensor([[1.0, 2.0, 3.0],
|
|
||||||
[4.0, 5.0, 6.0]])
|
|
||||||
test_index = torch.tensor([0, 2], dtype=torch.int64)
|
|
||||||
|
|
||||||
print("Test input data: {}, {}".format(test_input, test_index))
|
|
||||||
output = model.forward(test_input, test_index)
|
|
||||||
print("Test output data: {}".format(output))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
Binary file not shown.
|
@ -0,0 +1,62 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# used to generate model: onnx-tests/tests/gather/gather.onnx
|
||||||
|
|
||||||
|
# There is no current support for `Split`, and the `for` loop over the indices
|
||||||
|
# results in a `Split` node in the ONNX model.
|
||||||
|
# Therefore, this model is built and exported using ONNX directly.
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
|
||||||
|
def build_model():
|
||||||
|
return onnx.helper.make_model(
|
||||||
|
ir_version=8,
|
||||||
|
opset_imports=[onnx.helper.make_operatorsetid("", 16)],
|
||||||
|
graph=onnx.helper.make_graph(name="main_graph", nodes=[
|
||||||
|
onnx.helper.make_node(
|
||||||
|
"Gather",
|
||||||
|
inputs=["input1", "input2"],
|
||||||
|
outputs=["output1"],
|
||||||
|
name="/Gather",
|
||||||
|
axis=1
|
||||||
|
),
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
onnx.helper.make_value_info(
|
||||||
|
name="input1",
|
||||||
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
|
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
onnx.helper.make_value_info(
|
||||||
|
name="input2",
|
||||||
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
|
elem_type=onnx.TensorProto.INT64, shape=[2]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
onnx.helper.make_value_info(
|
||||||
|
name="output1",
|
||||||
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
|
elem_type=onnx.TensorProto.FLOAT, shape=[2, 2]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
onnx_model = build_model()
|
||||||
|
file_name = "gather_1d_idx.onnx"
|
||||||
|
|
||||||
|
# Ensure valid ONNX:
|
||||||
|
onnx.checker.check_model(onnx_model)
|
||||||
|
|
||||||
|
onnx.save(onnx_model, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Binary file not shown.
|
@ -0,0 +1,62 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# used to generate model: onnx-tests/tests/gather/gather.onnx
|
||||||
|
|
||||||
|
# There is no current support for `Split`, and the `for` loop over the indices
|
||||||
|
# results in a `Split` node in the ONNX model.
|
||||||
|
# Therefore, this model is built and exported using ONNX directly.
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
|
||||||
|
def build_model():
|
||||||
|
return onnx.helper.make_model(
|
||||||
|
ir_version=8,
|
||||||
|
opset_imports=[onnx.helper.make_operatorsetid("", 16)],
|
||||||
|
graph=onnx.helper.make_graph(name="main_graph", nodes=[
|
||||||
|
onnx.helper.make_node(
|
||||||
|
"Gather",
|
||||||
|
inputs=["input1", "input2"],
|
||||||
|
outputs=["output1"],
|
||||||
|
name="/Gather",
|
||||||
|
axis=0
|
||||||
|
),
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
onnx.helper.make_value_info(
|
||||||
|
name="input1",
|
||||||
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
|
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
onnx.helper.make_value_info(
|
||||||
|
name="input2",
|
||||||
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
|
elem_type=onnx.TensorProto.INT64, shape=[2, 2]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
onnx.helper.make_value_info(
|
||||||
|
name="output1",
|
||||||
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
|
elem_type=onnx.TensorProto.FLOAT, shape=[2, 2, 2]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
onnx_model = build_model()
|
||||||
|
file_name = "gather_2d_idx.onnx"
|
||||||
|
|
||||||
|
# Ensure valid ONNX:
|
||||||
|
onnx.checker.check_model(onnx_model)
|
||||||
|
|
||||||
|
onnx.save(onnx_model, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Binary file not shown.
|
@ -2,45 +2,60 @@
|
||||||
|
|
||||||
# used to generate model: onnx-tests/tests/gather/gather_scalar.onnx
|
# used to generate model: onnx-tests/tests/gather/gather_scalar.onnx
|
||||||
|
|
||||||
import torch
|
# There is no current support for `Split`, and the `for` loop over the indices
|
||||||
import torch.nn as nn
|
# results in a `Split` node in the ONNX model.
|
||||||
|
# Therefore, this model is built and exported using ONNX directly.
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
def build_model():
|
||||||
def __init__(self):
|
return onnx.helper.make_model(
|
||||||
super(Model, self).__init__()
|
ir_version=8,
|
||||||
|
opset_imports=[onnx.helper.make_operatorsetid("", 16)],
|
||||||
|
graph=onnx.helper.make_graph(name="main_graph", nodes=[
|
||||||
|
onnx.helper.make_node(
|
||||||
|
"Gather",
|
||||||
|
inputs=["input1", "input2"],
|
||||||
|
outputs=["output1"],
|
||||||
|
name="/Gather",
|
||||||
|
axis=0
|
||||||
|
),
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
onnx.helper.make_value_info(
|
||||||
|
name="input1",
|
||||||
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
|
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
onnx.helper.make_value_info(
|
||||||
|
name="input2",
|
||||||
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
|
elem_type=onnx.TensorProto.INT64, shape=[]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
|
||||||
def forward(self, x, index):
|
],
|
||||||
gathered = torch.select(x, 0, index)
|
outputs=[
|
||||||
return gathered
|
onnx.helper.make_value_info(
|
||||||
|
name="output1",
|
||||||
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
|
elem_type=onnx.TensorProto.FLOAT, shape=[3]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Set random seed for reproducibility
|
onnx_model = build_model()
|
||||||
torch.manual_seed(0)
|
file_name = "gather_scalar.onnx"
|
||||||
|
|
||||||
# Export to onnx
|
# Ensure valid ONNX:
|
||||||
model = Model()
|
onnx.checker.check_model(onnx_model)
|
||||||
model.eval()
|
|
||||||
device = torch.device("cpu")
|
|
||||||
onnx_name = "gather_scalar.onnx"
|
|
||||||
|
|
||||||
dummy_input = torch.randn(2, 3, device=device)
|
onnx.save(onnx_model, file_name)
|
||||||
dummy_index = 0
|
|
||||||
|
|
||||||
torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
|
|
||||||
verbose=False, opset_version=16)
|
|
||||||
|
|
||||||
print("Finished exporting model to {}".format(onnx_name))
|
|
||||||
|
|
||||||
# Output some test data for use in the test
|
|
||||||
test_input = torch.tensor([[1.0, 2.0, 3.0],
|
|
||||||
[4.0, 5.0, 6.0]])
|
|
||||||
test_index = 0
|
|
||||||
|
|
||||||
print("Test input data: {}, {}".format(test_input, test_index))
|
|
||||||
output = model.forward(test_input, test_index)
|
|
||||||
print("Test output data: {}".format(output))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -33,7 +33,7 @@ def build_model():
|
||||||
onnx.helper.make_value_info(
|
onnx.helper.make_value_info(
|
||||||
name="input1",
|
name="input1",
|
||||||
type_proto=onnx.helper.make_tensor_type_proto(
|
type_proto=onnx.helper.make_tensor_type_proto(
|
||||||
elem_type=onnx.TensorProto.FLOAT, shape=[2,3]
|
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
onnx.helper.make_value_info(
|
onnx.helper.make_value_info(
|
||||||
|
@ -66,4 +66,4 @@ def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -41,7 +41,8 @@ include_models!(
|
||||||
exp,
|
exp,
|
||||||
expand,
|
expand,
|
||||||
flatten,
|
flatten,
|
||||||
gather,
|
gather_1d_idx,
|
||||||
|
gather_2d_idx,
|
||||||
gather_scalar,
|
gather_scalar,
|
||||||
gather_shape,
|
gather_shape,
|
||||||
gather_elements,
|
gather_elements,
|
||||||
|
@ -451,15 +452,29 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gather() {
|
fn gather_1d_idx() {
|
||||||
let model: gather::Model<Backend> = gather::Model::default();
|
let model: gather_1d_idx::Model<Backend> = gather_1d_idx::Model::default();
|
||||||
|
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
|
|
||||||
let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
|
let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
|
||||||
let index = Tensor::<Backend, 1, Int>::from_ints([0, 2], &device);
|
let index = Tensor::<Backend, 1, Int>::from_ints([0, 2], &device);
|
||||||
let output = model.forward(input, index);
|
|
||||||
let expected = TensorData::from([[1f32, 3.], [4., 6.]]);
|
let expected = TensorData::from([[1f32, 3.], [4., 6.]]);
|
||||||
|
let output = model.forward(input, index);
|
||||||
|
|
||||||
|
assert_eq!(output.to_data(), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gather_2d_idx() {
|
||||||
|
let model: gather_2d_idx::Model<Backend> = gather_2d_idx::Model::default();
|
||||||
|
|
||||||
|
let device = Default::default();
|
||||||
|
|
||||||
|
let input = Tensor::<Backend, 2>::from_data([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], &device);
|
||||||
|
let index = Tensor::<Backend, 2, Int>::from_data([[0, 1], [1, 2]], &device);
|
||||||
|
let expected = TensorData::from([[[1f32, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]]);
|
||||||
|
let output = model.forward(input, index);
|
||||||
|
|
||||||
assert_eq!(output.to_data(), expected);
|
assert_eq!(output.to_data(), expected);
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,12 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
|
||||||
node_position: usize,
|
node_position: usize,
|
||||||
) -> proc_macro2::TokenStream {
|
) -> proc_macro2::TokenStream {
|
||||||
let dim = self.dim.to_tokens();
|
let dim = self.dim.to_tokens();
|
||||||
|
let input_rank = match &self.input {
|
||||||
|
Type::Tensor(in_tensor) => in_tensor.dim,
|
||||||
|
Type::Shape(_) => 1,
|
||||||
|
_ => panic!("Gather needs Tensor or Shape input, got {:?}!", self.input),
|
||||||
|
};
|
||||||
|
|
||||||
let input = match &self.input {
|
let input = match &self.input {
|
||||||
Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position),
|
Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position),
|
||||||
Type::Shape(in_shape) => {
|
Type::Shape(in_shape) => {
|
||||||
|
@ -34,10 +40,11 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
|
||||||
// To copy just the values from the shape value without moving it
|
// To copy just the values from the shape value without moving it
|
||||||
// (which could lead to ownership problems if the same Shape is used multiple times)
|
// (which could lead to ownership problems if the same Shape is used multiple times)
|
||||||
// borrow the array as a slice and use that to create the Tensor:
|
// borrow the array as a slice and use that to create the Tensor:
|
||||||
quote! { Tensor::from_data(&#in_shape_name as &[_], &*self.device) }
|
quote! { Tensor::<B, 1, Int>::from_data(&#in_shape_name as &[_], &*self.device) }
|
||||||
}
|
}
|
||||||
_ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input),
|
_ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input),
|
||||||
};
|
};
|
||||||
|
|
||||||
let output = &self.output.name;
|
let output = &self.output.name;
|
||||||
|
|
||||||
match &self.index {
|
match &self.index {
|
||||||
|
@ -46,14 +53,41 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
|
||||||
// convert the 0-D index to a 1-D Tensor with len 1 to use burn's select,
|
// convert the 0-D index to a 1-D Tensor with len 1 to use burn's select,
|
||||||
// then squeeze the dimension to reduce the rank
|
// then squeeze the dimension to reduce the rank
|
||||||
let index = &idx_scalar.name;
|
let index = &idx_scalar.name;
|
||||||
|
let output_rank = input_rank - 1;
|
||||||
quote! {
|
quote! {
|
||||||
let #output = #input.select(#dim, Tensor::from_data([#index], &*self.device)).squeeze(#dim);
|
let indices = Tensor::<B, 1, _>::from_data([#index], &*self.device);
|
||||||
|
let slice = Tensor::select(#input, #dim, indices);
|
||||||
|
let #output = slice.squeeze::<#output_rank>(#dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Type::Tensor(idx_tensor) => {
|
Type::Tensor(idx_tensor) => {
|
||||||
let index = scope.tensor_use_owned(idx_tensor, node_position);
|
let index = scope.tensor_use_owned(idx_tensor, node_position);
|
||||||
quote! {
|
let index_rank = idx_tensor.dim;
|
||||||
let #output = #input.select(#dim, #index);
|
let output_rank = index_rank + input_rank - 1;
|
||||||
|
match index_rank {
|
||||||
|
1 => quote! {
|
||||||
|
let indices = #index;
|
||||||
|
let #output = Tensor::select(#input, #dim, indices);
|
||||||
|
},
|
||||||
|
_ => quote! {
|
||||||
|
let indices = #index;
|
||||||
|
|
||||||
|
let n_dims = indices.dims().len();
|
||||||
|
let index_flat = match n_dims {
|
||||||
|
1 => indices.reshape([1, -1]),
|
||||||
|
n if n >= 2 => indices.flatten::<2>(0, n - 2),
|
||||||
|
_ => panic!("Number of dimensions must be greater than 0"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = index_flat
|
||||||
|
.iter_dim(0)
|
||||||
|
.map(|idxs| {
|
||||||
|
let idxs = idxs.squeeze::<1>(0);
|
||||||
|
Tensor::select(#input.clone(), #dim, idxs)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let #output = Tensor::stack::<#output_rank>(out, #dim);
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => panic!("Gather needs Scalar or Tensor index, got {:?}!", self.index),
|
_ => panic!("Gather needs Scalar or Tensor index, got {:?}!", self.index),
|
||||||
|
@ -78,7 +112,7 @@ mod tests {
|
||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_codegen_gather() {
|
fn test_codegen_gather_1d_idx() {
|
||||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||||
|
|
||||||
graph.register(GatherNode::new(
|
graph.register(GatherNode::new(
|
||||||
|
@ -121,8 +155,77 @@ mod tests {
|
||||||
tensor1: Tensor<B, 2>,
|
tensor1: Tensor<B, 2>,
|
||||||
tensor2: Tensor<B, 1, Int>
|
tensor2: Tensor<B, 1, Int>
|
||||||
) -> Tensor<B, 2> {
|
) -> Tensor<B, 2> {
|
||||||
let tensor3 = tensor1.select(0, tensor2);
|
let indices = tensor2;
|
||||||
|
let tensor3 = Tensor::select(tensor1, 0, indices);
|
||||||
|
tensor3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_tokens(graph.codegen(), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_codegen_gather_2d_idx() {
|
||||||
|
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||||
|
|
||||||
|
graph.register(GatherNode::new(
|
||||||
|
Type::Tensor(TensorType::new_float("tensor1", 2)),
|
||||||
|
Type::Tensor(TensorType::new_int("tensor2", 2)),
|
||||||
|
TensorType::new_float("tensor3", 3),
|
||||||
|
0,
|
||||||
|
));
|
||||||
|
|
||||||
|
graph.register_input_output(
|
||||||
|
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||||
|
vec!["tensor3".to_string()],
|
||||||
|
);
|
||||||
|
|
||||||
|
let expected = quote! {
|
||||||
|
use burn::tensor::Int;
|
||||||
|
use burn::{
|
||||||
|
module::Module,
|
||||||
|
tensor::{backend::Backend, Tensor},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Module, Debug)]
|
||||||
|
pub struct Model<B: Backend> {
|
||||||
|
phantom: core::marker::PhantomData<B>,
|
||||||
|
device: burn::module::Ignored<B::Device>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Model <B> {
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
pub fn new(device: &B::Device) -> Self {
|
||||||
|
Self {
|
||||||
|
phantom: core::marker::PhantomData,
|
||||||
|
device: burn::module::Ignored(device.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||||
|
pub fn forward(
|
||||||
|
&self,
|
||||||
|
tensor1: Tensor<B, 2>,
|
||||||
|
tensor2: Tensor<B, 2, Int>
|
||||||
|
) -> Tensor<B, 3> {
|
||||||
|
let indices = tensor2;
|
||||||
|
|
||||||
|
let n_dims = indices.dims().len();
|
||||||
|
let index_flat = match n_dims {
|
||||||
|
1 => indices.reshape([1, -1]),
|
||||||
|
n if n >= 2 => indices.flatten::<2>(0, n - 2),
|
||||||
|
_ => panic!("Number of dimensions must be greater than 0"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = index_flat
|
||||||
|
.iter_dim(0)
|
||||||
|
.map(|idxs| {
|
||||||
|
let idxs = idxs.squeeze::<1>(0);
|
||||||
|
Tensor::select(tensor1.clone(), 0, idxs)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let tensor3 = Tensor::stack::<3usize>(out, 0);
|
||||||
tensor3
|
tensor3
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -138,7 +241,7 @@ mod tests {
|
||||||
graph.register(GatherNode::new(
|
graph.register(GatherNode::new(
|
||||||
Type::Shape(ShapeType::new("shape1", 3)),
|
Type::Shape(ShapeType::new("shape1", 3)),
|
||||||
Type::Tensor(TensorType::new_int("tensor1", 1)),
|
Type::Tensor(TensorType::new_int("tensor1", 1)),
|
||||||
TensorType::new_float("tensor2", 2),
|
TensorType::new_int("tensor2", 1),
|
||||||
0,
|
0,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -174,8 +277,14 @@ mod tests {
|
||||||
&self,
|
&self,
|
||||||
shape1: [usize; 3],
|
shape1: [usize; 3],
|
||||||
tensor1: Tensor<B, 1, Int>
|
tensor1: Tensor<B, 1, Int>
|
||||||
) -> Tensor<B, 2> {
|
) -> Tensor<B, 1, Int> {
|
||||||
let tensor2 = Tensor::from_data(&shape1 as &[_], &*self.device).select(0, tensor1);
|
let indices = tensor1;
|
||||||
|
|
||||||
|
let tensor2 = Tensor::select(
|
||||||
|
Tensor::<B, 1, Int>::from_data(&shape1 as &[_], &*self.device),
|
||||||
|
0,
|
||||||
|
indices,
|
||||||
|
);
|
||||||
|
|
||||||
tensor2
|
tensor2
|
||||||
}
|
}
|
||||||
|
@ -192,7 +301,7 @@ mod tests {
|
||||||
graph.register(GatherNode::new(
|
graph.register(GatherNode::new(
|
||||||
Type::Tensor(TensorType::new_float("tensor1", 2)),
|
Type::Tensor(TensorType::new_float("tensor1", 2)),
|
||||||
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)),
|
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)),
|
||||||
TensorType::new_float("tensor2", 2),
|
TensorType::new_float("tensor2", 1),
|
||||||
0,
|
0,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -227,8 +336,11 @@ mod tests {
|
||||||
&self,
|
&self,
|
||||||
tensor1: Tensor<B, 2>,
|
tensor1: Tensor<B, 2>,
|
||||||
scalar1: i64
|
scalar1: i64
|
||||||
) -> Tensor<B, 2> {
|
) -> Tensor<B, 1> {
|
||||||
let tensor2 = tensor1.select(0, Tensor::from_data([scalar1], &*self.device)).squeeze(0);
|
let indices = Tensor::<B, 1, _>::from_data([scalar1], &*self.device);
|
||||||
|
|
||||||
|
let slice = Tensor::select(tensor1, 0, indices);
|
||||||
|
let tensor2 = slice.squeeze::<1usize>(0);
|
||||||
|
|
||||||
tensor2
|
tensor2
|
||||||
}
|
}
|
||||||
|
|
|
@ -892,6 +892,7 @@ impl TensorCheck {
|
||||||
|
|
||||||
check
|
check
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn check_prelu_shape<const D: usize>(
|
pub(crate) fn check_prelu_shape<const D: usize>(
|
||||||
shape_tensor: &Shape<D>,
|
shape_tensor: &Shape<D>,
|
||||||
shape_weight: &Shape<1>,
|
shape_weight: &Shape<1>,
|
||||||
|
|
|
@ -816,10 +816,6 @@ fn gather_update_outputs(node: &mut Node) {
|
||||||
_ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty),
|
_ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty),
|
||||||
};
|
};
|
||||||
|
|
||||||
if indices_dim > 1 {
|
|
||||||
panic!("Gather: indices tensor rank above 1 not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
match &node.inputs[0].ty {
|
match &node.inputs[0].ty {
|
||||||
ArgType::Tensor(input_tensor) => {
|
ArgType::Tensor(input_tensor) => {
|
||||||
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
|
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
|
||||||
|
|
Loading…
Reference in New Issue