Implement ONNX Gather for scalar indices (#2141)

* Implement ONNX Gather for scalars

* Fix ONNX gather_scalar codegen test
This commit is contained in:
Adrian Müller 2024-08-09 17:53:01 +02:00 committed by GitHub
parent 724bfbc73b
commit 5a0c1dcead
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 139 additions and 12 deletions

View File

@ -33,6 +33,7 @@ fn main() {
.input("tests/expand/expand.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/gather/gather.onnx")
.input("tests/gather/gather_scalar.onnx")
.input("tests/gather_elements/gather_elements.onnx")
.input("tests/gelu/gelu.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/gather/gather_scalar.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
    """Tiny module used to produce the ``gather_scalar.onnx`` test model.

    It has no parameters; its forward pass only selects a single index
    along dimension 0 of the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, index):
        """Return the slice of ``x`` at position ``index`` along dimension 0.

        ``torch.select`` drops the selected dimension, so the result has one
        dimension fewer than ``x``.
        """
        return torch.select(x, 0, index)
def main():
    """Export ``Model`` to ``gather_scalar.onnx`` and print reference test data.

    The printed input/index/output values are meant to be copied into the
    test that consumes the exported model.
    """
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "gather_scalar.onnx"
    dummy_input = torch.randn(2, 3, device=device)
    # A plain Python int (not a tensor) is passed as the index argument.
    dummy_index = 0
    torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
                      verbose=False, opset_version=16)
    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    test_input = torch.tensor([[1.0, 2.0, 3.0],
                               [4.0, 5.0, 6.0]])
    test_index = 0
    print("Test input data: {}, {}".format(test_input, test_index))
    output = model.forward(test_input, test_index)
    print("Test output data: {}".format(output))
# Run the export only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

View File

@ -42,6 +42,7 @@ include_models!(
expand,
flatten,
gather,
gather_scalar,
gather_elements,
gelu,
global_avr_pool,
@ -458,6 +459,20 @@ mod tests {
assert_eq!(output.to_data(), expected);
}
#[test]
fn gather_scalar() {
    // Gather with a scalar index: selecting index 0 along the gather axis of a
    // 2x3 input is expected to yield the first row.
    let device = Default::default();
    let model = gather_scalar::Model::<Backend>::default();

    let rows = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
    let selected = model.forward(rows, 0);

    assert_eq!(selected.to_data(), TensorData::from([1f32, 2., 3.]));
}
#[test]
fn gather_elements() {
// Initialize the model with weights (loaded from the exported file)

View File

@ -7,7 +7,7 @@ use quote::quote;
#[derive(Debug, Clone, new)]
pub struct GatherNode {
pub input: TensorType,
pub index: TensorType,
pub index: Type,
pub output: TensorType,
pub dim: usize,
}
@ -18,10 +18,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
}
fn input_types(&self) -> Vec<crate::burn::Type> {
vec![
Type::Tensor(self.input.clone()),
Type::Tensor(self.index.clone()),
]
vec![Type::Tensor(self.input.clone()), self.index.clone()]
}
fn forward(
@ -31,11 +28,25 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
) -> proc_macro2::TokenStream {
let dim = self.dim.to_tokens();
let input = scope.tensor_use_owned(&self.input, node_position);
let index = scope.tensor_use_owned(&self.index, node_position);
let output = &self.output.name;
quote! {
let #output = #input.select(#dim, #index);
match &self.index {
Type::Scalar(idx_scalar) => {
// To do a scalar select (select just a single index in one dim),
// convert the 0-D index to a 1-D Tensor with len 1 to use burn's select,
// then squeeze the dimension to reduce the rank
let index = &idx_scalar.name;
quote! {
let #output = #input.select(#dim, Tensor::from_data([#index], &*self.device)).squeeze(#dim);
}
}
Type::Tensor(idx_tensor) => {
let index = scope.tensor_use_owned(idx_tensor, node_position);
quote! {
let #output = #input.select(#dim, #index);
}
}
_ => panic!("Gather needs Scalar or Tensor index!"),
}
}
@ -53,7 +64,7 @@ mod tests {
use crate::burn::{
graph::BurnGraph,
node::{gather::GatherNode, test::assert_tokens},
TensorType,
ScalarKind, ScalarType, TensorType,
};
#[test]
@ -62,7 +73,7 @@ mod tests {
graph.register(GatherNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 1),
Type::Tensor(TensorType::new_int("tensor2", 1)),
TensorType::new_float("tensor3", 2),
0,
));
@ -109,4 +120,57 @@ mod tests {
assert_tokens(graph.codegen(), expected);
}
#[test]
fn test_codegen_gather_scalar() {
    // Node under test: a rank-2 float tensor gathered along dim 0 using an
    // i64 scalar as the index.
    // NOTE(review): the output is declared with rank 2 even though the generated
    // code ends in `squeeze(0)` — confirm the intended output rank.
    let gather = GatherNode::new(
        TensorType::new_float("tensor1", 2),
        Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)),
        TensorType::new_float("tensor2", 2),
        0,
    );

    let mut burn_graph = BurnGraph::<FullPrecisionSettings>::default();
    burn_graph.register(gather);

    let graph_inputs = vec![String::from("tensor1"), String::from("scalar1")];
    let graph_outputs = vec![String::from("tensor2")];
    burn_graph.register_input_output(graph_inputs, graph_outputs);

    // The scalar index is wrapped in a one-element tensor for `select`, and the
    // selected dimension is squeezed away afterwards.
    let expected_tokens = quote! {
        use burn::{
            module::Module,
            tensor::{backend::Backend, Tensor},
        };

        #[derive(Module, Debug)]
        pub struct Model<B: Backend> {
            phantom: core::marker::PhantomData<B>,
            device: burn::module::Ignored<B::Device>,
        }

        impl<B: Backend> Model <B> {
            #[allow(unused_variables)]
            pub fn new(device: &B::Device) -> Self {
                Self {
                    phantom: core::marker::PhantomData,
                    device: burn::module::Ignored(device.clone()),
                }
            }

            #[allow(clippy::let_and_return, clippy::approx_constant)]
            pub fn forward(
                &self,
                tensor1: Tensor<B, 2>,
                scalar1: i64
            ) -> Tensor<B, 2> {
                let tensor2 = tensor1.select(0, Tensor::from_data([scalar1], &*self.device)).squeeze(0);
                tensor2
            }
        }
    };

    assert_tokens(burn_graph.codegen(), expected_tokens);
}
}

View File

@ -617,7 +617,7 @@ impl ParsedOnnxGraph {
fn gather_conversion(node: Node) -> GatherNode {
let input = TensorType::from(node.inputs.first().unwrap());
let index = TensorType::from(node.inputs.get(1).unwrap());
let index = Type::from(node.inputs.get(1).unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let dim = gather_config(&node);

View File

@ -815,7 +815,7 @@ fn gather_update_outputs(node: &mut Node) {
_ => panic!("Only tensor indices is valid"),
};
if indices_tensor.dim != 1 {
if indices_tensor.dim > 1 {
panic!("Gather: indices tensor rank above 1 not supported")
}