From d770b1f470d734afd79253d082ea9d084a8d5431 Mon Sep 17 00:00:00 2001 From: mepatrick73 <114622680+mepatrick73@users.noreply.github.com> Date: Wed, 7 Aug 2024 17:43:59 -0400 Subject: [PATCH] ONNX Tile operation (#2092) * renaming repeat to repeat_dim * implementing repeat function * renaming repeat files to repeat_dim * renaming part 2 * renaming part 3 * renaming part 4 * renaming part 5 * adding test file * adding unit test * adding rust book documentation * adding function args doc * fixing tests * changing repeat api to match pytorch equivalent * fixing clippy error * implementing tile onnx file * temp * working implementation and test * working e2e test * adding new supported onnx operation to the md file --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 2 +- crates/burn-import/onnx-tests/build.rs | 1 + .../onnx-tests/tests/onnx_tests.rs | 18 ++++ .../onnx-tests/tests/tile/tile.onnx | Bin 0 -> 163 bytes .../burn-import/onnx-tests/tests/tile/tile.py | 67 ++++++++++++ crates/burn-import/src/burn/node/base.rs | 5 +- crates/burn-import/src/burn/node/mod.rs | 1 + crates/burn-import/src/burn/node/tile.rs | 97 ++++++++++++++++++ .../burn-import/src/onnx/op_configuration.rs | 22 +++- crates/burn-import/src/onnx/to_burn.rs | 13 ++- 10 files changed, 222 insertions(+), 4 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/tile/tile.onnx create mode 100644 crates/burn-import/onnx-tests/tests/tile/tile.py create mode 100644 crates/burn-import/src/burn/node/tile.rs diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index b5cfd5519..a9517c04c 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -191,7 +191,7 @@ represent the corresponding Burn Op. | [Tanh][182] | ✅ | ✅ | | [TfIdfVectorizer][183] | ❌ | ❌ | | [ThresholdedRelu][184] | ❌ | ❌ | -| [Tile][185] | ❌ | ✅ | +| [Tile][185] | ✅ | ✅ | | [TopK][186] | ❌ | ✅ | | [Transpose][187] | ✅ | ✅ | | [Trilu][188] | ❌ | ✅ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index a41f3df3c..457019318 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -93,6 +93,7 @@ fn main() { .input("tests/sum/sum.onnx") .input("tests/sum/sum_int.onnx") .input("tests/tanh/tanh.onnx") + .input("tests/tile/tile.onnx") .input("tests/transpose/transpose.onnx") .input("tests/unsqueeze/unsqueeze.onnx") .input("tests/unsqueeze/unsqueeze_opset11.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index a108aaf4f..6bdaaaab1 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -102,6 +102,7 @@ include_models!( sum, sum_int, tanh, + tile, transpose, unsqueeze, unsqueeze_opset11, @@ -1712,6 +1713,23 @@ mod tests { output.to_data().assert_eq(&expected, true); } + #[test] + fn tile() { + let device = Default::default(); + let model: tile::Model = tile::Model::new(&device); + + let input = Tensor::::from_floats([[1., 2.], [3., 4.]], &device); + let output = model.forward(input).to_data(); + let expected = TensorData::from([ + [1.0f32, 2.0f32, 1.0f32, 2.0f32], + [3.0f32, 4.0f32, 3.0f32, 4.0f32], + [1.0f32, 2.0f32, 1.0f32, 2.0f32], + [3.0f32, 4.0f32, 3.0f32, 4.0f32], + ]); + + output.assert_eq(&expected, true); + } + #[test] fn unsqueeze() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/tile/tile.onnx b/crates/burn-import/onnx-tests/tests/tile/tile.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a2162746ec601be7b6a1968a7fb0e177c3e71c8c GIT binary patch literal 163 zcmd { Slice(SliceNode), Squeeze(SqueezeNode), Sum(SumNode), + Tile(TileNode), Unary(UnaryNode), Unsqueeze(UnsqueezeNode), Where(WhereNode), @@ -160,6 +161,7 @@ macro_rules! match_all { Node::Slice(node) => $func(node), Node::Squeeze(node) => $func(node), Node::Sum(node) => $func(node), + Node::Tile(node) => $func(node), Node::Unary(node) => $func(node), Node::Unsqueeze(node) => $func(node), Node::Where(node) => $func(node), @@ -215,6 +217,7 @@ impl Node { Node::Slice(_) => "slice", Node::Squeeze(_) => "squeeze", Node::Sum(_) => "add", + Node::Tile(_) => "tile", Node::Unary(unary) => unary.kind.as_str(), Node::Unsqueeze(_) => "unsqueeze", Node::Where(_) => "where", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 875e3e5af..ee294ddfd 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -36,6 +36,7 @@ pub(crate) mod resize; pub(crate) mod slice; pub(crate) mod squeeze; pub(crate) mod sum; +pub(crate) mod tile; pub(crate) mod unary; pub(crate) mod unsqueeze; pub(crate) use base::*; diff --git a/crates/burn-import/src/burn/node/tile.rs b/crates/burn-import/src/burn/node/tile.rs new file mode 100644 index 000000000..cf56f7ad2 --- /dev/null +++ b/crates/burn-import/src/burn/node/tile.rs @@ -0,0 +1,97 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; +use burn::config::Config; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Config, Debug)] +pub struct TileConfig { + pub repeats: Vec, +} + +#[derive(Debug, Clone, new)] +pub struct TileNode { + pub input: TensorType, + pub output: TensorType, + pub config: TileConfig, +} + +impl NodeCodegen for TileNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + + let repeats = self.config.repeats.iter().map(|r| r.to_tokens()); + + quote! { + let #output = #input.repeat(&[#(#repeats),*]); + } + } + + fn into_node(self) -> Node { + Node::Tile(self) + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{test::assert_tokens, tile::TileConfig, tile::TileNode}, + TensorType, + }; + + #[test] + fn test_codegen_tile() { + let mut graph = BurnGraph::::default(); + let config = TileConfig::new(vec![2, 3, 4]); + graph.register(TileNode::new( + TensorType::new_float("input", 3), + TensorType::new_float("output", 3), + config, + )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.repeat(&[2, 3, 4]); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 586493ebe..d701150c8 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -7,7 +7,7 @@ use burn::nn::{ PaddingConfig2d, PaddingConfig3d, }; -use crate::burn::node::pad::PadConfig; +use crate::burn::node::{pad::PadConfig, tile::TileConfig}; use onnx_ir::ir::{ArgType, AttributeValue, Data, Node}; /// Create a Conv1dConfig from the attributes of the node @@ -745,6 +745,26 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) { ) } +/// Create a TileConfig from the attributes of the node +pub fn tile_config(node: &Node) -> TileConfig { + let repeat = node + .inputs + .get(1) + .map(|input| { + if let Some(data) = &input.value { + data.clone() + .into_i64s() + .iter() + .map(|&x| x as usize) + .collect() + } else { + vec![] + } + }) + .unwrap_or_default(); + TileConfig::new(repeat) +} + /// Create a PadConfig from the attributes of the node pub fn pad_config(node: &Node) -> PadConfig { fn get_pads(node: &Node) -> Vec { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 39e9a428c..18d2981e8 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -50,6 +50,7 @@ use crate::{ slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, + tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, }, @@ -66,7 +67,8 @@ use super::op_configuration::{ hard_sigmoid_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config, - shape_config, slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config, + shape_config, slice_config, softmax_config, squeeze_config, tile_config, transpose_config, + unsqueeze_config, }; use onnx_ir::{ convert_constant_value, @@ -335,6 +337,7 @@ impl ParsedOnnxGraph { NodeType::Sign => graph.register(Self::sign_conversion(node)), NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)), NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)), + NodeType::Tile => graph.register(Self::tile_conversion(node)), NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)), NodeType::ConstantOfShape => { graph.register(Self::constant_of_shape_conversion(node)) @@ -1167,6 +1170,14 @@ impl ParsedOnnxGraph { SqueezeNode::new(input, output, axes) } + + fn tile_conversion(node: Node) -> TileNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + let config = tile_config(&node); + + TileNode::new(input, output, config) + } } /// Extract data from node states and convert it to `TensorData`.