mirror of https://github.com/tracel-ai/burn.git
Implement ONNX Gather for scalar indices (#2141)
* Implement ONNX Gather for scalars * Fix ONNX gather_scalar codegen test
This commit is contained in:
parent
724bfbc73b
commit
5a0c1dcead
|
@ -33,6 +33,7 @@ fn main() {
|
|||
.input("tests/expand/expand.onnx")
|
||||
.input("tests/flatten/flatten.onnx")
|
||||
.input("tests/gather/gather.onnx")
|
||||
.input("tests/gather/gather_scalar.onnx")
|
||||
.input("tests/gather_elements/gather_elements.onnx")
|
||||
.input("tests/gelu/gelu.onnx")
|
||||
.input("tests/global_avr_pool/global_avr_pool.onnx")
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/gather/gather_scalar.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, index):
|
||||
gathered = torch.select(x, 0, index)
|
||||
return gathered
|
||||
|
||||
|
||||
def main():
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
onnx_name = "gather_scalar.onnx"
|
||||
|
||||
dummy_input = torch.randn(2, 3, device=device)
|
||||
dummy_index = 0
|
||||
|
||||
torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
test_input = torch.tensor([[1.0, 2.0, 3.0],
|
||||
[4.0, 5.0, 6.0]])
|
||||
test_index = 0
|
||||
|
||||
print("Test input data: {}, {}".format(test_input, test_index))
|
||||
output = model.forward(test_input, test_index)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -42,6 +42,7 @@ include_models!(
|
|||
expand,
|
||||
flatten,
|
||||
gather,
|
||||
gather_scalar,
|
||||
gather_elements,
|
||||
gelu,
|
||||
global_avr_pool,
|
||||
|
@ -458,6 +459,20 @@ mod tests {
|
|||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gather_scalar() {
|
||||
let model: gather_scalar::Model<Backend> = gather_scalar::Model::default();
|
||||
|
||||
let device = Default::default();
|
||||
|
||||
let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
|
||||
let index = 0;
|
||||
let output = model.forward(input, index);
|
||||
let expected = TensorData::from([1f32, 2., 3.]);
|
||||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gather_elements() {
|
||||
// Initialize the model with weights (loaded from the exported file)
|
||||
|
|
|
@ -7,7 +7,7 @@ use quote::quote;
|
|||
#[derive(Debug, Clone, new)]
|
||||
pub struct GatherNode {
|
||||
pub input: TensorType,
|
||||
pub index: TensorType,
|
||||
pub index: Type,
|
||||
pub output: TensorType,
|
||||
pub dim: usize,
|
||||
}
|
||||
|
@ -18,10 +18,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
|
|||
}
|
||||
|
||||
fn input_types(&self) -> Vec<crate::burn::Type> {
|
||||
vec![
|
||||
Type::Tensor(self.input.clone()),
|
||||
Type::Tensor(self.index.clone()),
|
||||
]
|
||||
vec![Type::Tensor(self.input.clone()), self.index.clone()]
|
||||
}
|
||||
|
||||
fn forward(
|
||||
|
@ -31,11 +28,25 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
|
|||
) -> proc_macro2::TokenStream {
|
||||
let dim = self.dim.to_tokens();
|
||||
let input = scope.tensor_use_owned(&self.input, node_position);
|
||||
let index = scope.tensor_use_owned(&self.index, node_position);
|
||||
let output = &self.output.name;
|
||||
|
||||
quote! {
|
||||
let #output = #input.select(#dim, #index);
|
||||
match &self.index {
|
||||
Type::Scalar(idx_scalar) => {
|
||||
// To do a scalar select (select just a single index in one dim),
|
||||
// convert the 0-D index to a 1-D Tensor with len 1 to use burn's select,
|
||||
// then squeeze the dimension to reduce the rank
|
||||
let index = &idx_scalar.name;
|
||||
quote! {
|
||||
let #output = #input.select(#dim, Tensor::from_data([#index], &*self.device)).squeeze(#dim);
|
||||
}
|
||||
}
|
||||
Type::Tensor(idx_tensor) => {
|
||||
let index = scope.tensor_use_owned(idx_tensor, node_position);
|
||||
quote! {
|
||||
let #output = #input.select(#dim, #index);
|
||||
}
|
||||
}
|
||||
_ => panic!("Gather needs Scalar or Tensor index!"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,7 +64,7 @@ mod tests {
|
|||
use crate::burn::{
|
||||
graph::BurnGraph,
|
||||
node::{gather::GatherNode, test::assert_tokens},
|
||||
TensorType,
|
||||
ScalarKind, ScalarType, TensorType,
|
||||
};
|
||||
|
||||
#[test]
|
||||
|
@ -62,7 +73,7 @@ mod tests {
|
|||
|
||||
graph.register(GatherNode::new(
|
||||
TensorType::new_float("tensor1", 2),
|
||||
TensorType::new_int("tensor2", 1),
|
||||
Type::Tensor(TensorType::new_int("tensor2", 1)),
|
||||
TensorType::new_float("tensor3", 2),
|
||||
0,
|
||||
));
|
||||
|
@ -109,4 +120,57 @@ mod tests {
|
|||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codegen_gather_scalar() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(GatherNode::new(
|
||||
TensorType::new_float("tensor1", 2),
|
||||
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)),
|
||||
TensorType::new_float("tensor2", 2),
|
||||
0,
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec!["tensor1".to_string(), "scalar1".to_string()],
|
||||
vec!["tensor2".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
tensor1: Tensor<B, 2>,
|
||||
scalar1: i64
|
||||
) -> Tensor<B, 2> {
|
||||
let tensor2 = tensor1.select(0, Tensor::from_data([scalar1], &*self.device)).squeeze(0);
|
||||
|
||||
tensor2
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -617,7 +617,7 @@ impl ParsedOnnxGraph {
|
|||
|
||||
fn gather_conversion(node: Node) -> GatherNode {
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let index = TensorType::from(node.inputs.get(1).unwrap());
|
||||
let index = Type::from(node.inputs.get(1).unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let dim = gather_config(&node);
|
||||
|
||||
|
|
|
@ -815,7 +815,7 @@ fn gather_update_outputs(node: &mut Node) {
|
|||
_ => panic!("Only tensor indices is valid"),
|
||||
};
|
||||
|
||||
if indices_tensor.dim != 1 {
|
||||
if indices_tensor.dim > 1 {
|
||||
panic!("Gather: indices tensor rank above 1 not supported")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue