diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 571cd7b56..b880c2c08 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -108,7 +108,7 @@ represent the corresponding Burn Op. | [MaxPool2d][98] | ✅ | ✅ | | [MaxRoiPool][99] | ❌ | ❌ | | [MaxUnpool][100] | ❌ | ❌ | -| [Mean][101] | ❌ | ✅ | +| [Mean][101] | ✅ | ✅ | | [MeanVarianceNormalization][102] | ❌ | ❌ | | [MelWeightMatrix][103] | ❌ | ❌ | | [Min][104] | ✅ | ✅ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 67e2f091a..9ace525b1 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -51,6 +51,7 @@ fn main() { .input("tests/maxpool1d/maxpool1d.onnx") .input("tests/maxpool2d/maxpool2d.onnx") .input("tests/min/min.onnx") + .input("tests/mean/mean.onnx") .input("tests/mul/mul.onnx") .input("tests/neg/neg.onnx") .input("tests/not/not.onnx") diff --git a/crates/burn-import/onnx-tests/tests/mean/mean.onnx b/crates/burn-import/onnx-tests/tests/mean/mean.onnx new file mode 100644 index 000000000..39c86e7fc --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/mean/mean.onnx @@ -0,0 +1,23 @@ + + +mean-model: +& +input1 +input2 +input3output"Mean MeanGraphZ +input1 + + +Z +input2 + + +Z +input3 + + +b +output + + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/mean/mean.py b/crates/burn-import/onnx-tests/tests/mean/mean.py new file mode 100644 index 000000000..dc5b99cea --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/mean/mean.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/mean/mean.onnx + +import onnx +import onnx.helper +import onnx.checker +import numpy as np + +# Create input tensors +input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3]) +input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3]) +input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3]) + +# Create output tensor +output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3]) + +# Create the Mean node +mean_node = onnx.helper.make_node( + 'Mean', + inputs=['input1', 'input2', 'input3'], + outputs=['output'] +) + +# Create the graph (GraphProto) +graph_def = onnx.helper.make_graph( + nodes=[mean_node], + name='MeanGraph', + inputs=[input1, input2, input3], + outputs=[output] +) + +# Create the model (ModelProto) +model_def = onnx.helper.make_model(graph_def, producer_name='mean-model') +onnx.checker.check_model(model_def) + +# Save the ONNX model +onnx.save(model_def, 'mean.onnx') + +print("ONNX model 'mean.onnx' generated successfully.") + diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index 2fc5ec675..31ac73a6f 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -60,6 +60,7 @@ include_models!( maxpool1d, maxpool2d, min, + mean, mul, neg, not, @@ -208,6 +209,21 @@ mod tests { output.to_data().assert_eq(&expected, true); } + #[test] + fn mean_tensor_and_tensor() { + let device = Default::default(); + let model: mean::Model = mean::Model::default(); + + let input1 = Tensor::::from_floats([1., 2., 3., 4.], &device); + let input2 = Tensor::::from_floats([2., 2., 4., 0.], &device); + let input3 = Tensor::::from_floats([3., 2., 5., -4.], &device); + + let output = model.forward(input1, input2, input3); + let expected = TensorData::from([2.0f32, 2., 4., 0.]); + + output.to_data().assert_eq(&expected, true); + } + #[test] fn mul_scalar_with_tensor_and_tensor_with_tensor() { // Initialize the model with weights (loaded from the exported file) diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 751cbcb47..46e1d5e1a 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -8,10 +8,10 @@ use super::{ conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode, gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, - max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode, - random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, - reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, - unary::UnaryNode, unsqueeze::UnsqueezeNode, + max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, + prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, + range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, + squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::backend::NdArray; @@ -105,6 +105,7 @@ pub enum Node { Matmul(MatmulNode), MaxPool1d(MaxPool1dNode), MaxPool2d(MaxPool2dNode), + Mean(MeanNode), Pad(PadNode), Range(RangeNode), Reshape(ReshapeNode), @@ -151,6 +152,7 @@ macro_rules! match_all { Node::Matmul(node) => $func(node), Node::MaxPool1d(node) => $func(node), Node::MaxPool2d(node) => $func(node), + Node::Mean(node) => $func(node), Node::Pad(node) => $func(node), Node::Range(node) => $func(node), Node::Reshape(node) => $func(node), @@ -205,6 +207,7 @@ impl Node { Node::Matmul(_) => "matmul", Node::MaxPool1d(_) => "max_pool1d", Node::MaxPool2d(_) => "max_pool2d", + Node::Mean(_) => "mean", Node::Pad(_) => "pad", Node::Range(_) => "range", Node::Reshape(_) => "reshape", diff --git a/crates/burn-import/src/burn/node/mean.rs b/crates/burn-import/src/burn/node/mean.rs new file mode 100644 index 000000000..17c6b34cb --- /dev/null +++ b/crates/burn-import/src/burn/node/mean.rs @@ -0,0 +1,109 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, Type}; + +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone, new)] +pub struct MeanNode { + pub inputs: Vec, + pub output: TensorType, +} + +impl NodeCodegen for MeanNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + self.inputs + .iter() + .map(|t| Type::Tensor(t.clone())) + .collect() + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let inputs = self + .inputs + .iter() + .map(|t| scope.tensor_use_owned(t, node_position)); + + let output = &self.output.name; + let inputs_len = self.inputs.len() as u32; + + quote! { + let #output = (#(#inputs)+*) / #inputs_len; + } + } + + fn into_node(self) -> Node { + Node::Mean(self) + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{mean::MeanNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_mean() { + let mut graph = BurnGraph::::default(); + + graph.register(MeanNode::new( + vec![ + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + ], + TensorType::new_float("tensor3", 4), + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "tensor2".to_string()], + vec!["tensor3".to_string()], + ); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + tensor1: Tensor, + tensor2: Tensor + ) -> Tensor { + let tensor3 = (tensor1 + tensor2) / 2u32; + + tensor3 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 9d1fdce59..875e3e5af 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -25,6 +25,7 @@ pub(crate) mod mask_where; pub(crate) mod matmul; pub(crate) mod max_pool1d; pub(crate) mod max_pool2d; +pub(crate) mod mean; pub(crate) mod pad; pub(crate) mod prelu; pub(crate) mod random_normal; diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index f8b0e0ad8..bccea8b5b 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -78,6 +78,7 @@ use onnx_ir::{ }; pub use crate::burn::graph::RecordType; +use crate::burn::node::mean::MeanNode; /// Generate code and states from `.onnx` files and save them to the `out_dir`. #[derive(Debug, Default)] @@ -268,6 +269,7 @@ impl ParsedOnnxGraph { NodeType::Max => graph.register(Self::max_conversion(node)), NodeType::MaxPool1d => graph.register(Self::max_pool1d_conversion(node)), NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)), + NodeType::Mean => graph.register(Self::mean_conversion(node)), NodeType::PRelu => graph.register(Self::prelu_conversion::(node)), NodeType::AveragePool1d => graph.register(Self::avg_pool_1d_conversion(node)), NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)), @@ -972,6 +974,13 @@ impl ParsedOnnxGraph { MaxPool2dNode::new(name, input, output, config) } + fn mean_conversion(node: Node) -> MeanNode { + let inputs = node.inputs.iter().map(TensorType::from).collect(); + let output = TensorType::from(node.outputs.first().unwrap()); + + MeanNode::new(inputs, output) + } + fn prelu_conversion(node: Node) -> PReluNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap());