mirror of https://github.com/tracel-ai/burn.git
Add onnx mean (#2119)
* make contacts deterministic across Worlds * add top k acc * add onnx mean * fix * push fix * format --------- Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
This commit is contained in:
parent
cd848b1c94
commit
dad85e0709
|
@ -108,7 +108,7 @@ represent the corresponding Burn Op.
|
|||
| [MaxPool2d][98] | ✅ | ✅ |
|
||||
| [MaxRoiPool][99] | ❌ | ❌ |
|
||||
| [MaxUnpool][100] | ❌ | ❌ |
|
||||
| [Mean][101] | ❌ | ✅ |
|
||||
| [Mean][101] | ✅ | ✅ |
|
||||
| [MeanVarianceNormalization][102] | ❌ | ❌ |
|
||||
| [MelWeightMatrix][103] | ❌ | ❌ |
|
||||
| [Min][104] | ✅ | ✅ |
|
||||
|
|
|
@ -51,6 +51,7 @@ fn main() {
|
|||
.input("tests/maxpool1d/maxpool1d.onnx")
|
||||
.input("tests/maxpool2d/maxpool2d.onnx")
|
||||
.input("tests/min/min.onnx")
|
||||
.input("tests/mean/mean.onnx")
|
||||
.input("tests/mul/mul.onnx")
|
||||
.input("tests/neg/neg.onnx")
|
||||
.input("tests/not/not.onnx")
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
|
||||
|
||||
mean-model:‹
|
||||
&
|
||||
input1
|
||||
input2
|
||||
input3output"Mean MeanGraphZ
|
||||
input1
|
||||
|
||||
|
||||
Z
|
||||
input2
|
||||
|
||||
|
||||
Z
|
||||
input3
|
||||
|
||||
|
||||
b
|
||||
output
|
||||
|
||||
|
||||
B
|
|
@ -0,0 +1,41 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/mean/mean.onnx
|
||||
|
||||
import onnx
|
||||
import onnx.helper
|
||||
import onnx.checker
|
||||
import numpy as np
|
||||
|
||||
# Create input tensors
|
||||
input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3])
|
||||
input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3])
|
||||
input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3])
|
||||
|
||||
# Create output tensor
|
||||
output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3])
|
||||
|
||||
# Create the Mean node
|
||||
mean_node = onnx.helper.make_node(
|
||||
'Mean',
|
||||
inputs=['input1', 'input2', 'input3'],
|
||||
outputs=['output']
|
||||
)
|
||||
|
||||
# Create the graph (GraphProto)
|
||||
graph_def = onnx.helper.make_graph(
|
||||
nodes=[mean_node],
|
||||
name='MeanGraph',
|
||||
inputs=[input1, input2, input3],
|
||||
outputs=[output]
|
||||
)
|
||||
|
||||
# Create the model (ModelProto)
|
||||
model_def = onnx.helper.make_model(graph_def, producer_name='mean-model')
|
||||
onnx.checker.check_model(model_def)
|
||||
|
||||
# Save the ONNX model
|
||||
onnx.save(model_def, 'mean.onnx')
|
||||
|
||||
print("ONNX model 'mean.onnx' generated successfully.")
|
||||
|
|
@ -60,6 +60,7 @@ include_models!(
|
|||
maxpool1d,
|
||||
maxpool2d,
|
||||
min,
|
||||
mean,
|
||||
mul,
|
||||
neg,
|
||||
not,
|
||||
|
@ -208,6 +209,21 @@ mod tests {
|
|||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mean_tensor_and_tensor() {
|
||||
let device = Default::default();
|
||||
let model: mean::Model<Backend> = mean::Model::default();
|
||||
|
||||
let input1 = Tensor::<Backend, 1>::from_floats([1., 2., 3., 4.], &device);
|
||||
let input2 = Tensor::<Backend, 1>::from_floats([2., 2., 4., 0.], &device);
|
||||
let input3 = Tensor::<Backend, 1>::from_floats([3., 2., 5., -4.], &device);
|
||||
|
||||
let output = model.forward(input1, input2, input3);
|
||||
let expected = TensorData::from([2.0f32, 2., 4., 0.]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
|
||||
// Initialize the model with weights (loaded from the exported file)
|
||||
|
|
|
@ -8,10 +8,10 @@ use super::{
|
|||
conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
|
||||
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
|
||||
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
|
||||
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode,
|
||||
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
|
||||
reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
|
||||
unary::UnaryNode, unsqueeze::UnsqueezeNode,
|
||||
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
|
||||
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
|
||||
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
|
||||
squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
|
||||
};
|
||||
use crate::burn::{BurnImports, Scope, Type};
|
||||
use burn::backend::NdArray;
|
||||
|
@ -105,6 +105,7 @@ pub enum Node<PS: PrecisionSettings> {
|
|||
Matmul(MatmulNode),
|
||||
MaxPool1d(MaxPool1dNode),
|
||||
MaxPool2d(MaxPool2dNode),
|
||||
Mean(MeanNode),
|
||||
Pad(PadNode),
|
||||
Range(RangeNode),
|
||||
Reshape(ReshapeNode),
|
||||
|
@ -151,6 +152,7 @@ macro_rules! match_all {
|
|||
Node::Matmul(node) => $func(node),
|
||||
Node::MaxPool1d(node) => $func(node),
|
||||
Node::MaxPool2d(node) => $func(node),
|
||||
Node::Mean(node) => $func(node),
|
||||
Node::Pad(node) => $func(node),
|
||||
Node::Range(node) => $func(node),
|
||||
Node::Reshape(node) => $func(node),
|
||||
|
@ -205,6 +207,7 @@ impl<PS: PrecisionSettings> Node<PS> {
|
|||
Node::Matmul(_) => "matmul",
|
||||
Node::MaxPool1d(_) => "max_pool1d",
|
||||
Node::MaxPool2d(_) => "max_pool2d",
|
||||
Node::Mean(_) => "mean",
|
||||
Node::Pad(_) => "pad",
|
||||
Node::Range(_) => "range",
|
||||
Node::Reshape(_) => "reshape",
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{Scope, TensorType, Type};
|
||||
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
#[derive(Debug, Clone, new)]
|
||||
pub struct MeanNode {
|
||||
pub inputs: Vec<TensorType>,
|
||||
pub output: TensorType,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for MeanNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
self.inputs
|
||||
.iter()
|
||||
.map(|t| Type::Tensor(t.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
let inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|t| scope.tensor_use_owned(t, node_position));
|
||||
|
||||
let output = &self.output.name;
|
||||
let inputs_len = self.inputs.len() as u32;
|
||||
|
||||
quote! {
|
||||
let #output = (#(#inputs)+*) / #inputs_len;
|
||||
}
|
||||
}
|
||||
|
||||
fn into_node(self) -> Node<PS> {
|
||||
Node::Mean(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn::record::FullPrecisionSettings;
|
||||
|
||||
use super::*;
|
||||
use crate::burn::{
|
||||
graph::BurnGraph,
|
||||
node::{mean::MeanNode, test::assert_tokens},
|
||||
TensorType,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_codegen_mean() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(MeanNode::new(
|
||||
vec![
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
],
|
||||
TensorType::new_float("tensor3", 4),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["tensor3".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
tensor1: Tensor<B, 4>,
|
||||
tensor2: Tensor<B, 4>
|
||||
) -> Tensor<B, 4> {
|
||||
let tensor3 = (tensor1 + tensor2) / 2u32;
|
||||
|
||||
tensor3
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
|
@ -25,6 +25,7 @@ pub(crate) mod mask_where;
|
|||
pub(crate) mod matmul;
|
||||
pub(crate) mod max_pool1d;
|
||||
pub(crate) mod max_pool2d;
|
||||
pub(crate) mod mean;
|
||||
pub(crate) mod pad;
|
||||
pub(crate) mod prelu;
|
||||
pub(crate) mod random_normal;
|
||||
|
|
|
@ -78,6 +78,7 @@ use onnx_ir::{
|
|||
};
|
||||
|
||||
pub use crate::burn::graph::RecordType;
|
||||
use crate::burn::node::mean::MeanNode;
|
||||
|
||||
/// Generate code and states from `.onnx` files and save them to the `out_dir`.
|
||||
#[derive(Debug, Default)]
|
||||
|
@ -268,6 +269,7 @@ impl ParsedOnnxGraph {
|
|||
NodeType::Max => graph.register(Self::max_conversion(node)),
|
||||
NodeType::MaxPool1d => graph.register(Self::max_pool1d_conversion(node)),
|
||||
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
|
||||
NodeType::Mean => graph.register(Self::mean_conversion(node)),
|
||||
NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
|
||||
NodeType::AveragePool1d => graph.register(Self::avg_pool_1d_conversion(node)),
|
||||
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
|
||||
|
@ -972,6 +974,13 @@ impl ParsedOnnxGraph {
|
|||
MaxPool2dNode::new(name, input, output, config)
|
||||
}
|
||||
|
||||
fn mean_conversion(node: Node) -> MeanNode {
|
||||
let inputs = node.inputs.iter().map(TensorType::from).collect();
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
|
||||
MeanNode::new(inputs, output)
|
||||
}
|
||||
|
||||
fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
|
|
Loading…
Reference in New Issue