mirror of https://github.com/tracel-ai/burn.git
Added ONNX AvgPool1d (#1744)
This commit is contained in:
parent
a6e3b4e81e
commit
5bbc5ea944
|
@ -8,7 +8,7 @@ use crate::tensor::Tensor;
|
|||
use burn_tensor::module::avg_pool1d;
|
||||
|
||||
/// Configuration to create a [1D avg pooling](AvgPool1d) layer.
|
||||
#[derive(Config)]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct AvgPool1dConfig {
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: usize,
|
||||
|
@ -20,7 +20,7 @@ pub struct AvgPool1dConfig {
|
|||
pub padding: PaddingConfig1d,
|
||||
/// If the padding is counted in the denominator when computing the average.
|
||||
#[config(default = "true")]
|
||||
count_include_pad: bool,
|
||||
pub count_include_pad: bool,
|
||||
}
|
||||
|
||||
/// Applies a 1D avg pooling over input tensors.
|
||||
|
|
|
@ -17,7 +17,7 @@ represent the corresponding Burn Op.
|
|||
| [Asinh][9] | ❌ | ❌ |
|
||||
| [Atan][10] | ❌ | ❌ |
|
||||
| [Atanh][11] | ❌ | ❌ |
|
||||
| [AveragePool1d][12] | ❌ | ✅ |
|
||||
| [AveragePool1d][12] | ✅ | ✅ |
|
||||
| [AveragePool2d][12] | ✅ | ✅ |
|
||||
| [BatchNormalization][14] | ✅ | ✅ |
|
||||
| [Bernoulli][15] | ❌ | ❌ |
|
||||
|
|
|
@ -8,6 +8,7 @@ fn main() {
|
|||
ModelGen::new()
|
||||
.input("tests/add/add_int.onnx")
|
||||
.input("tests/add/add.onnx")
|
||||
.input("tests/avg_pool1d/avg_pool1d.onnx")
|
||||
.input("tests/avg_pool2d/avg_pool2d.onnx")
|
||||
.input("tests/batch_norm/batch_norm.onnx")
|
||||
.input("tests/cast/cast.onnx")
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,58 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: avg_pool1d.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
self.pool1 = nn.AvgPool1d(4, stride=2)
|
||||
|
||||
self.pool2 = nn.AvgPool1d(4, stride=2, padding=2, count_include_pad=True)
|
||||
|
||||
self.pool3 = nn.AvgPool1d(4, stride=2, padding=2, count_include_pad=False)
|
||||
|
||||
def forward(self, x1, x2, x3):
|
||||
y1 = self.pool1(x1)
|
||||
y2 = self.pool2(x2)
|
||||
y3 = self.pool3(x3)
|
||||
return y1, y2, y3
|
||||
|
||||
|
||||
def main():
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(1)
|
||||
|
||||
# Print options
|
||||
torch.set_printoptions(precision=3)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
file_name = "avg_pool1d.onnx"
|
||||
input1 = torch.randn(1, 5, 5, device=device)
|
||||
torch.onnx.export(model, (input1, input1, input1), file_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(file_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
print("Test input data shape: {}".format(input1.shape))
|
||||
print("Test input data: {}".format(input1))
|
||||
output1, output2, output3 = model.forward(input1, input1, input1)
|
||||
print("Test output1 data shape: {}".format(output1.shape))
|
||||
print("Test output2 data shape: {}".format(output2.shape))
|
||||
print("Test output3 data shape: {}".format(output3.shape))
|
||||
print("Test output1: {}".format(output1))
|
||||
print("Test output2: {}".format(output2))
|
||||
print("Test output3: {}".format(output3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -18,6 +18,7 @@ include_models!(
|
|||
add_int,
|
||||
add,
|
||||
avg_pool2d,
|
||||
avg_pool1d,
|
||||
batch_norm,
|
||||
cast,
|
||||
clip_opset16,
|
||||
|
@ -498,6 +499,53 @@ mod tests {
|
|||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn avg_pool1d() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let device = Default::default();
|
||||
let model: avg_pool1d::Model<Backend> = avg_pool1d::Model::new(&device);
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 3>::from_floats(
|
||||
[[
|
||||
[-1.526, -0.750, -0.654, -1.609, -0.100],
|
||||
[-0.609, -0.980, -1.609, -0.712, 1.171],
|
||||
[1.767, -0.095, 0.139, -1.579, -0.321],
|
||||
[-0.299, 1.879, 0.336, 0.275, 1.716],
|
||||
[-0.056, 0.911, -1.392, 2.689, -0.111],
|
||||
]],
|
||||
&device,
|
||||
);
|
||||
let (output1, output2, output3) = model.forward(input.clone(), input.clone(), input);
|
||||
let expected1 = Data::from([[[-1.135], [-0.978], [0.058], [0.548], [0.538]]]);
|
||||
let expected2 = Data::from([[
|
||||
[-0.569, -1.135, -0.591],
|
||||
[-0.397, -0.978, -0.288],
|
||||
[0.418, 0.058, -0.440],
|
||||
[0.395, 0.548, 0.582],
|
||||
[0.214, 0.538, 0.296],
|
||||
]]);
|
||||
let expected3 = Data::from([[
|
||||
[-1.138, -1.135, -0.788],
|
||||
[-0.794, -0.978, -0.383],
|
||||
[0.836, 0.058, -0.587],
|
||||
[0.790, 0.548, 0.776],
|
||||
[0.427, 0.538, 0.395],
|
||||
]]);
|
||||
|
||||
let expected_shape1 = Shape::from([1, 5, 1]);
|
||||
let expected_shape2 = Shape::from([1, 5, 3]);
|
||||
let expected_shape3 = Shape::from([1, 5, 3]);
|
||||
|
||||
assert_eq!(output1.shape(), expected_shape1);
|
||||
assert_eq!(output2.shape(), expected_shape2);
|
||||
assert_eq!(output3.shape(), expected_shape3);
|
||||
|
||||
output1.to_data().assert_approx_eq(&expected1, 3);
|
||||
output2.to_data().assert_approx_eq(&expected2, 3);
|
||||
output3.to_data().assert_approx_eq(&expected3, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn avg_pool2d() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
use burn::{nn::pool::AvgPool1dConfig, record::PrecisionSettings};
|
||||
|
||||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AvgPool1dNode {
|
||||
pub field: OtherType,
|
||||
pub input: TensorType,
|
||||
pub output: TensorType,
|
||||
pub config: AvgPool1dConfig,
|
||||
}
|
||||
|
||||
impl AvgPool1dNode {
|
||||
pub fn new<S: AsRef<str>>(
|
||||
name: S,
|
||||
input: TensorType,
|
||||
output: TensorType,
|
||||
config: AvgPool1dConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
field: OtherType::new(
|
||||
name,
|
||||
quote! {
|
||||
AvgPool1d
|
||||
},
|
||||
),
|
||||
input,
|
||||
output,
|
||||
config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool1dNode {
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
fn field_type(&self) -> Option<Type> {
|
||||
Some(Type::Other(self.field.clone()))
|
||||
}
|
||||
|
||||
fn field_init(&self) -> Option<TokenStream> {
|
||||
let name = &self.field.name;
|
||||
let kernel_size = self.config.kernel_size.to_tokens();
|
||||
let strides = self.config.stride.to_tokens();
|
||||
let padding = self.config.padding.to_tokens();
|
||||
let count_include_pad = self.config.count_include_pad;
|
||||
|
||||
let tokens = quote! {
|
||||
let #name = AvgPool1dConfig::new(#kernel_size)
|
||||
.with_stride(#strides)
|
||||
.with_padding(#padding)
|
||||
.with_count_include_pad(#count_include_pad)
|
||||
.init();
|
||||
};
|
||||
|
||||
Some(tokens)
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
let input = scope.tensor_use_owned(&self.input, node_position);
|
||||
let output = &self.output.name;
|
||||
let field = &self.field.name;
|
||||
|
||||
quote! {
|
||||
let #output = self.#field.forward(#input);
|
||||
}
|
||||
}
|
||||
|
||||
fn register_imports(&self, imports: &mut BurnImports) {
|
||||
imports.register("burn::nn::PaddingConfig1d");
|
||||
imports.register("burn::nn::pool::AvgPool1d");
|
||||
imports.register("burn::nn::pool::AvgPool1dConfig");
|
||||
}
|
||||
|
||||
fn into_node(self) -> Node<PS> {
|
||||
Node::AvgPool1d(self)
|
||||
}
|
||||
|
||||
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
S::serialize_none(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
|
||||
use burn::{nn::PaddingConfig1d, record::FullPrecisionSettings};
|
||||
|
||||
#[test]
|
||||
fn test_codegen() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(AvgPool1dNode::new(
|
||||
"avg_pool1d",
|
||||
TensorType::new_float("input", 3),
|
||||
TensorType::new_float("output", 3),
|
||||
AvgPool1dConfig::new(3)
|
||||
.with_stride(1)
|
||||
.with_padding(PaddingConfig1d::Valid),
|
||||
));
|
||||
|
||||
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
use burn::nn::PaddingConfig1d;
|
||||
use burn::nn::pool::AvgPool1d;
|
||||
use burn::nn::pool::AvgPool1dConfig;
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model <B: Backend> {
|
||||
avg_pool1d: AvgPool1d,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
let avg_pool1d = AvgPool1dConfig::new(3)
|
||||
.with_stride(1)
|
||||
.with_padding(PaddingConfig1d::Valid)
|
||||
.with_count_include_pad(true)
|
||||
.init();
|
||||
|
||||
Self {
|
||||
avg_pool1d,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let output = self.avg_pool1d.forward(input);
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
|
@ -1,11 +1,11 @@
|
|||
use super::{
|
||||
avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
|
||||
concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
|
||||
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
|
||||
global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
|
||||
mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
|
||||
max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, unary::UnaryNode,
|
||||
unsqueeze::UnsqueezeNode,
|
||||
avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
|
||||
binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
|
||||
conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
|
||||
dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
|
||||
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
|
||||
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
|
||||
unary::UnaryNode, unsqueeze::UnsqueezeNode,
|
||||
};
|
||||
use crate::burn::{BurnImports, Scope, Type};
|
||||
use burn::backend::NdArray;
|
||||
|
@ -75,6 +75,7 @@ pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {
|
|||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Node<PS: PrecisionSettings> {
|
||||
AvgPool1d(AvgPool1dNode),
|
||||
AvgPool2d(AvgPool2dNode),
|
||||
BatchNorm(BatchNormNode<PS>),
|
||||
Binary(BinaryNode),
|
||||
|
@ -103,6 +104,7 @@ macro_rules! match_all {
|
|||
($self:expr, $func:expr) => {{
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
match $self {
|
||||
Node::AvgPool1d(node) => $func(node),
|
||||
Node::AvgPool2d(node) => $func(node),
|
||||
Node::BatchNorm(node) => $func(node),
|
||||
Node::Binary(node) => $func(node),
|
||||
|
@ -141,6 +143,7 @@ impl<PS: PrecisionSettings> Serialize for Node<PS> {
|
|||
impl<PS: PrecisionSettings> Node<PS> {
|
||||
pub fn name(&self) -> &str {
|
||||
match self {
|
||||
Node::AvgPool1d(_) => "avg_pool1d",
|
||||
Node::AvgPool2d(_) => "avg_pool2d",
|
||||
Node::BatchNorm(_) => "batch_norm",
|
||||
Node::Binary(binary) => binary.binary_type.as_str(),
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
mod base;
|
||||
|
||||
pub(crate) mod avg_pool1d;
|
||||
pub(crate) mod avg_pool2d;
|
||||
pub(crate) mod batch_norm;
|
||||
pub(crate) mod binary;
|
||||
|
|
|
@ -14,6 +14,7 @@ use super::{
|
|||
pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
|
||||
match node.node_type {
|
||||
NodeType::Add => same_as_input(node),
|
||||
NodeType::AveragePool1d => same_as_input(node),
|
||||
NodeType::AveragePool2d => same_as_input(node),
|
||||
NodeType::BatchNormalization => same_as_input(node),
|
||||
NodeType::Cast => cast_update_outputs(node),
|
||||
|
@ -38,6 +39,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
|
|||
NodeType::Log => same_as_input(node),
|
||||
NodeType::LogSoftmax => same_as_input(node),
|
||||
NodeType::MatMul => matmul_update_outputs(node),
|
||||
NodeType::MaxPool1d => same_as_input(node),
|
||||
NodeType::MaxPool2d => same_as_input(node),
|
||||
NodeType::Mul => same_as_input(node),
|
||||
NodeType::Neg => same_as_input(node),
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use burn::nn::{
|
||||
conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
|
||||
pool::{AvgPool2dConfig, MaxPool1dConfig, MaxPool2dConfig},
|
||||
pool::{AvgPool1dConfig, AvgPool2dConfig, MaxPool1dConfig, MaxPool2dConfig},
|
||||
BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
|
||||
PaddingConfig2d,
|
||||
};
|
||||
|
@ -200,6 +200,37 @@ pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
|
|||
.with_bias(bias)
|
||||
}
|
||||
|
||||
pub fn avg_pool1d_config(curr: &Node) -> AvgPool1dConfig {
|
||||
let mut kernel_shape = Vec::new();
|
||||
let mut strides = vec![1];
|
||||
let mut pads = vec![0, 0];
|
||||
let mut count_include_pad: i64 = 0;
|
||||
let mut ceil_mode: i64 = 0;
|
||||
|
||||
for (key, value) in curr.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"kernel_shape" => kernel_shape = value.clone().into_i64s(),
|
||||
"strides" => strides = value.clone().into_i64s(),
|
||||
"pads" => pads = value.clone().into_i64s(),
|
||||
"count_include_pad" => count_include_pad = value.clone().into_i64(),
|
||||
"ceil_mode" => ceil_mode = value.clone().into_i64(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
assert_eq!(kernel_shape.len(), 1);
|
||||
assert_eq!(strides.len(), 1);
|
||||
|
||||
if ceil_mode == 1 {
|
||||
panic!("ceil_mode is not supported");
|
||||
}
|
||||
|
||||
let padding = padding_config_1d(&pads);
|
||||
|
||||
AvgPool1dConfig::new(kernel_shape[0] as usize)
|
||||
.with_stride(strides[0] as usize)
|
||||
.with_padding(padding)
|
||||
.with_count_include_pad(count_include_pad == 1)
|
||||
}
|
||||
/// Create a AvgPool2dConfig from the attributes of the node
|
||||
pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
|
||||
let mut kernel_shape = Vec::new();
|
||||
|
|
|
@ -14,6 +14,7 @@ use crate::{
|
|||
burn::{
|
||||
graph::BurnGraph,
|
||||
node::{
|
||||
avg_pool1d::AvgPool1dNode,
|
||||
avg_pool2d::AvgPool2dNode,
|
||||
batch_norm::BatchNormNode,
|
||||
binary::BinaryNode,
|
||||
|
@ -243,6 +244,7 @@ impl OnnxGraph {
|
|||
NodeType::MaxPool1d => graph.register(Self::max_pool1d_conversion(node)),
|
||||
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
|
||||
NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
|
||||
NodeType::AveragePool1d => graph.register(Self::avg_pool_1d_conversion(node)),
|
||||
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
|
||||
NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
|
||||
NodeType::Neg => graph.register(Self::neg_conversion(node)),
|
||||
|
@ -746,6 +748,14 @@ impl OnnxGraph {
|
|||
let name = &node.name;
|
||||
ConvTranspose2dNode::<PS>::new(name, input, output, weight, bias, config)
|
||||
}
|
||||
fn avg_pool_1d_conversion(node: Node) -> AvgPool1dNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let config = avg_pool1d_config(&node);
|
||||
|
||||
let name = &node.name;
|
||||
AvgPool1dNode::new(name, input, output, config)
|
||||
}
|
||||
|
||||
fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
|
|
Loading…
Reference in New Issue