Add onnx mean (#2119)

* make contacts deterministic across Worlds

* add top k acc

* add onnx mean

* fix

* push fix

* format

---------

Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
This commit is contained in:
Periwink 2024-08-07 13:03:59 -04:00 committed by GitHub
parent cd848b1c94
commit dad85e0709
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 208 additions and 5 deletions

View File

@ -108,7 +108,7 @@ represent the corresponding Burn Op.
| [MaxPool2d][98] | ✅ | ✅ |
| [MaxRoiPool][99] | ❌ | ❌ |
| [MaxUnpool][100] | ❌ | ❌ |
| [Mean][101] | | ✅ |
| [Mean][101] | | ✅ |
| [MeanVarianceNormalization][102] | ❌ | ❌ |
| [MelWeightMatrix][103] | ❌ | ❌ |
| [Min][104] | ✅ | ✅ |

View File

@ -51,6 +51,7 @@ fn main() {
.input("tests/maxpool1d/maxpool1d.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/min/min.onnx")
.input("tests/mean/mean.onnx")
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")

View File

@ -0,0 +1,23 @@


mean-model:
&
input1
input2
input3output"Mean MeanGraphZ
input1

Z
input2

Z
input3

b
output

B

View File

@ -0,0 +1,41 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/mean/mean.onnx
import onnx
import onnx.helper
import onnx.checker
import numpy as np
# Create input tensors
input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3])
input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3])
input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3])
# Create output tensor
output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3])
# Create the Mean node
mean_node = onnx.helper.make_node(
'Mean',
inputs=['input1', 'input2', 'input3'],
outputs=['output']
)
# Create the graph (GraphProto)
graph_def = onnx.helper.make_graph(
nodes=[mean_node],
name='MeanGraph',
inputs=[input1, input2, input3],
outputs=[output]
)
# Create the model (ModelProto)
model_def = onnx.helper.make_model(graph_def, producer_name='mean-model')
onnx.checker.check_model(model_def)
# Save the ONNX model
onnx.save(model_def, 'mean.onnx')
print("ONNX model 'mean.onnx' generated successfully.")

View File

@ -60,6 +60,7 @@ include_models!(
maxpool1d,
maxpool2d,
min,
mean,
mul,
neg,
not,
@ -208,6 +209,21 @@ mod tests {
output.to_data().assert_eq(&expected, true);
}
#[test]
fn mean_tensor_and_tensor() {
let device = Default::default();
let model: mean::Model<Backend> = mean::Model::default();
let input1 = Tensor::<Backend, 1>::from_floats([1., 2., 3., 4.], &device);
let input2 = Tensor::<Backend, 1>::from_floats([2., 2., 4., 0.], &device);
let input3 = Tensor::<Backend, 1>::from_floats([3., 2., 5., -4.], &device);
let output = model.forward(input1, input2, input3);
let expected = TensorData::from([2.0f32, 2., 4., 0.]);
output.to_data().assert_eq(&expected, true);
}
#[test]
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
// Initialize the model with weights (loaded from the exported file)

View File

@ -8,10 +8,10 @@ use super::{
conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@ -105,6 +105,7 @@ pub enum Node<PS: PrecisionSettings> {
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Mean(MeanNode),
Pad(PadNode),
Range(RangeNode),
Reshape(ReshapeNode),
@ -151,6 +152,7 @@ macro_rules! match_all {
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Mean(node) => $func(node),
Node::Pad(node) => $func(node),
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
@ -205,6 +207,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Mean(_) => "mean",
Node::Pad(_) => "pad",
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",

View File

@ -0,0 +1,109 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
#[derive(Debug, Clone, new)]
pub struct MeanNode {
pub inputs: Vec<TensorType>,
pub output: TensorType,
}
impl<PS: PrecisionSettings> NodeCodegen<PS> for MeanNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn input_types(&self) -> Vec<Type> {
self.inputs
.iter()
.map(|t| Type::Tensor(t.clone()))
.collect()
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let inputs = self
.inputs
.iter()
.map(|t| scope.tensor_use_owned(t, node_position));
let output = &self.output.name;
let inputs_len = self.inputs.len() as u32;
quote! {
let #output = (#(#inputs)+*) / #inputs_len;
}
}
fn into_node(self) -> Node<PS> {
Node::Mean(self)
}
}
#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;
use super::*;
use crate::burn::{
graph::BurnGraph,
node::{mean::MeanNode, test::assert_tokens},
TensorType,
};
#[test]
fn test_codegen_mean() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(MeanNode::new(
vec![
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
],
TensorType::new_float("tensor3", 4),
));
graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);
let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
tensor2: Tensor<B, 4>
) -> Tensor<B, 4> {
let tensor3 = (tensor1 + tensor2) / 2u32;
tensor3
}
}
};
assert_tokens(graph.codegen(), expected);
}
}

View File

@ -25,6 +25,7 @@ pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod mean;
pub(crate) mod pad;
pub(crate) mod prelu;
pub(crate) mod random_normal;

View File

@ -78,6 +78,7 @@ use onnx_ir::{
};
pub use crate::burn::graph::RecordType;
use crate::burn::node::mean::MeanNode;
/// Generate code and states from `.onnx` files and save them to the `out_dir`.
#[derive(Debug, Default)]
@ -268,6 +269,7 @@ impl ParsedOnnxGraph {
NodeType::Max => graph.register(Self::max_conversion(node)),
NodeType::MaxPool1d => graph.register(Self::max_pool1d_conversion(node)),
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
NodeType::Mean => graph.register(Self::mean_conversion(node)),
NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
NodeType::AveragePool1d => graph.register(Self::avg_pool_1d_conversion(node)),
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
@ -972,6 +974,13 @@ impl ParsedOnnxGraph {
MaxPool2dNode::new(name, input, output, config)
}
fn mean_conversion(node: Node) -> MeanNode {
let inputs = node.inputs.iter().map(TensorType::from).collect();
let output = TensorType::from(node.outputs.first().unwrap());
MeanNode::new(inputs, output)
}
fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());