mirror of https://github.com/tracel-ai/burn.git
Global avg pool (#611)
This commit is contained in:
parent cb283a9e5b
commit 894783f08d
@@ -73,7 +73,7 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
 - [ ] GatherND
 - [ ] Gelu
 - [x] Gemm (Linear Layer)
-- [ ] GlobalAveragePool
+- [x] GlobalAveragePool
 - [ ] GlobalLpPool
 - [ ] GlobalMaxPool
 - [ ] Greater
@@ -13,6 +13,7 @@ fn main() {
         .input("tests/concat/concat.onnx")
         .input("tests/conv2d/conv2d.onnx")
         .input("tests/dropout/dropout.onnx")
+        .input("tests/global_avr_pool/global_avr_pool.onnx")
         .out_dir("model/")
         .run_from_script();
 
@@ -0,0 +1,30 @@
[binary file: global_avr_pool.onnx, an ONNX model exported by PyTorch 2.0.1 containing two GlobalAveragePool nodes (1d and 2d); binary content not shown]
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+
+# used to generate model: global_avr_pool.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.pool1 = nn.AdaptiveAvgPool1d(1)
+        self.pool2 = nn.AdaptiveAvgPool2d((1, 1))
+
+    def forward(self, x_1d, x_2d):
+        y_1d = self.pool1(x_1d)
+        y_2d = self.pool2(x_2d)
+        return y_1d, y_2d
+
+
+def main():
+
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+
+    file_name = "global_avr_pool.onnx"
+    input1 = torch.ones(2, 4, 10, device=device)
+    input2 = torch.ones(3, 10, 3, 15, device=device)
+    torch.onnx.export(model, (input1, input2), file_name,
+                      verbose=False, opset_version=16)
+
+    print("Finished exporting model to {}".format(file_name))
+
+    # Output some test data for use in the test
+    print("Test input data shapes of ones: {}, {}".format(
+        input1.shape, input2.shape))
+    y_1d, y_2d = model.forward(input1, input2)
+    print("Test output data shapes: {}, {}".format(y_1d.shape, y_2d.shape))
+
+    sum1 = y_1d.sum().item()
+    sum2 = y_2d.sum().item()
+
+    print("Test output sums: {}, {}".format(sum1, sum2))
+
+
+if __name__ == '__main__':
+    main()
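The expected sums printed by this script follow directly from the all-ones inputs: global average pooling an all-ones tensor yields all ones, so each output sums to the product of its leading dimensions (2 x 4 = 8 and 3 x 10 = 30). A quick PyTorch sanity check, not part of the commit, also confirms that AdaptiveAvgPool with output size 1 matches an explicit global mean:

import torch
import torch.nn as nn

x_2d = torch.ones(3, 10, 3, 15)
pooled = nn.AdaptiveAvgPool2d((1, 1))(x_2d)   # shape (3, 10, 1, 1)
manual = x_2d.mean(dim=(2, 3), keepdim=True)  # explicit global average
assert torch.allclose(pooled, manual)
assert pooled.sum().item() == 30.0            # matches expected_sum_2d below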
@@ -10,7 +10,7 @@ macro_rules! include_models {
 }
 
 // ATTENTION: Modify this macro to include all models in the `model` directory.
-include_models!(add, sub, mul, div, concat, conv2d, dropout);
+include_models!(add, sub, mul, div, concat, conv2d, dropout, global_avr_pool);
 
 #[cfg(test)]
 mod tests {
@@ -134,4 +134,30 @@ mod tests {
 
         assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
     }
+
+    #[test]
+    fn globalavrpool_1d_2d() {
+        // The model contains 1d and 2d global average pooling nodes
+        let model: global_avr_pool::Model<Backend> = global_avr_pool::Model::default();
+
+        // Run the model with ones as input for easier testing
+        let input_1d = Tensor::<Backend, 3>::ones([2, 4, 10]);
+        let input_2d = Tensor::<Backend, 4>::ones([3, 10, 3, 15]);
+
+        let (output_1d, output_2d) = model.forward(input_1d, input_2d);
+
+        let expected_shape_1d = Shape::from([2, 4, 1]);
+        let expected_shape_2d = Shape::from([3, 10, 1, 1]);
+        assert_eq!(output_1d.shape(), expected_shape_1d);
+        assert_eq!(output_2d.shape(), expected_shape_2d);
+
+        let output_sum_1d = output_1d.sum().into_scalar();
+        let output_sum_2d = output_2d.sum().into_scalar();
+
+        let expected_sum_1d = 8.0; // from pytorch
+        let expected_sum_2d = 30.0; // from pytorch
+
+        assert!(expected_sum_1d.approx_eq(output_sum_1d, (1.0e-4, 2)));
+        assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2)));
+    }
 }
@@ -1,7 +1,8 @@
 use super::{
     batch_norm::BatchNormNode, binary::BinaryNode, concat::ConcatNode, constant::ConstantNode,
-    conv2d::Conv2dNode, dropout::DropoutNode, linear::LinearNode, matmul::MatmulNode,
-    max_pool2d::MaxPool2dNode, reshape::ReshapeNode, unary::UnaryNode,
+    conv2d::Conv2dNode, dropout::DropoutNode, global_avg_pool::GlobalAvgPoolNode,
+    linear::LinearNode, matmul::MatmulNode, max_pool2d::MaxPool2dNode, reshape::ReshapeNode,
+    unary::UnaryNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
 use burn::record::PrecisionSettings;
@@ -82,6 +83,7 @@ pub enum Node<PS: PrecisionSettings> {
     Reshape(ReshapeNode),
     Concat(ConcatNode),
     Dropout(DropoutNode),
+    GlobalAvgPool(GlobalAvgPoolNode),
 }
 
 macro_rules! match_all {
@@ -98,6 +100,7 @@ macro_rules! match_all {
             Node::Dropout(node) => $func(node),
             Node::Unary(node) => $func(node),
             Node::Binary(node) => $func(node),
+            Node::GlobalAvgPool(node) => $func(node),
         }
     }};
 }
@@ -123,6 +126,7 @@ impl<PS: PrecisionSettings> Node<PS> {
             Node::BatchNorm(_) => "batch_norm",
             Node::Reshape(_) => "reshape",
             Node::Dropout(_) => "dropout",
+            Node::GlobalAvgPool(_) => "global_avg_pool",
             Node::Unary(unary) => unary.kind.as_str(),
             Node::Binary(binary) => binary.binary_type.as_str(),
         }
@@ -0,0 +1,215 @@
+use proc_macro2::TokenStream;
+use quote::quote;
+
+use burn::record::PrecisionSettings;
+
+use super::{Node, NodeCodegen};
+use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type};
+
+/// GlobalAvgPoolNode is a node that performs a global average pooling operation.
+///
+/// The node is implemented using the AdaptiveAvgPool1d or AdaptiveAvgPool2d module
+/// depending on the input dimension. AdaptiveAvgPool with output size 1 or size (1,1)
+/// is equivalent to global average pooling.
+#[derive(Debug, Clone)]
+pub struct GlobalAvgPoolNode {
+    pub field: OtherType,
+    pub input: TensorType,
+    pub output: TensorType,
+}
+
+impl GlobalAvgPoolNode {
+    pub fn new<S: AsRef<str>>(name: S, input: TensorType, output: TensorType) -> Self {
+        // Depending on the input dimension, we need to use a different type nn module
+        let field_type = match input.dim {
+            3 => quote! {
+                AdaptiveAvgPool1d
+            },
+            4 => quote! {
+                AdaptiveAvgPool2d
+            },
+            dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"),
+        };
+
+        Self {
+            field: OtherType::new(name, field_type),
+            input,
+            output,
+        }
+    }
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for GlobalAvgPoolNode {
+    fn input_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.input.clone())]
+    }
+    fn output_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.output.clone())]
+    }
+    fn field_type(&self) -> Option<Type> {
+        Some(Type::Other(self.field.clone()))
+    }
+
+    fn field_init(&self, _with_record: bool) -> Option<TokenStream> {
+        let name = &self.field.name;
+
+        let tokens = match self.input.dim {
+            3 => {
+                quote! {
+                    let #name = AdaptiveAvgPool1dConfig::new(1)
+                        .init();
+                }
+            }
+            4 => {
+                quote! {
+                    let #name = AdaptiveAvgPool2dConfig::new([1, 1])
+                        .init();
+                }
+            }
+            dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"),
+        };
+
+        Some(tokens)
+    }
+
+    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
+        let input = scope.tensor_use_owned(&self.input, node_position);
+        let output = &self.output.name;
+        let field = &self.field.name;
+
+        quote! {
+            let #output = self.#field.forward(#input);
+        }
+    }
+
+    fn register_imports(&self, imports: &mut BurnImports) {
+        match self.input.dim {
+            3 => {
+                imports.register("burn::nn::pool::AdaptiveAvgPool1d");
+                imports.register("burn::nn::pool::AdaptiveAvgPool1dConfig");
+            }
+            4 => {
+                imports.register("burn::nn::pool::AdaptiveAvgPool2d");
+                imports.register("burn::nn::pool::AdaptiveAvgPool2dConfig");
+            }
+            dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"),
+        }
+    }
+
+    fn into_node(self) -> Node<PS> {
+        Node::GlobalAvgPool(self)
+    }
+
+    fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        S::serialize_none(serializer)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::burn::{
+        graph::BurnGraph,
+        node::{global_avg_pool::GlobalAvgPoolNode, test::assert_tokens},
+        TensorType,
+    };
+    use burn::record::FullPrecisionSettings;
+
+    #[test]
+    fn test_codegen_2d() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(GlobalAvgPoolNode::new(
+            "global_avg_pool1",
+            TensorType::new_float("input", 4),
+            TensorType::new_float("output", 4),
+        ));
+
+        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
+
+        let expected = quote! {
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+            use burn::nn::pool::AdaptiveAvgPool2d;
+            use burn::nn::pool::AdaptiveAvgPool2dConfig;
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                global_avg_pool1: AdaptiveAvgPool2d,
+                phantom: core::marker::PhantomData<B>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new_with(record: ModelRecord<B>) -> Self {
+                    let global_avg_pool1 = AdaptiveAvgPool2dConfig::new([1, 1])
+                        .init();
+
+                    Self {
+                        global_avg_pool1,
+                        phantom: core::marker::PhantomData,
+                    }
+                }
+                #[allow(clippy::let_and_return)]
+                pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+                    let output = self.global_avg_pool1.forward(input);
+
+                    output
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+
+    #[test]
+    fn test_codegen_1d() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(GlobalAvgPoolNode::new(
+            "global_avg_pool1",
+            TensorType::new_float("input", 3),
+            TensorType::new_float("output", 3),
+        ));
+
+        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
+
+        let expected = quote! {
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+            use burn::nn::pool::AdaptiveAvgPool1d;
+            use burn::nn::pool::AdaptiveAvgPool1dConfig;
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                global_avg_pool1: AdaptiveAvgPool1d,
+                phantom: core::marker::PhantomData<B>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new_with(record: ModelRecord<B>) -> Self {
+                    let global_avg_pool1 = AdaptiveAvgPool1dConfig::new(1)
+                        .init();
+
+                    Self {
+                        global_avg_pool1,
+                        phantom: core::marker::PhantomData,
+                    }
+                }
+                #[allow(clippy::let_and_return)]
+                pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+                    let output = self.global_avg_pool1.forward(input);
+
+                    output
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+}
@@ -6,6 +6,7 @@ pub(crate) mod concat;
 pub(crate) mod constant;
 pub(crate) mod conv2d;
 pub(crate) mod dropout;
+pub(crate) mod global_avg_pool;
 pub(crate) mod linear;
 pub(crate) mod matmul;
 pub(crate) mod max_pool2d;
@@ -90,6 +90,8 @@ pub fn dim_inference(
         NodeType::Concat => concat_update_outputs(node),
         NodeType::Reshape => reshape_update_outputs(node),
         NodeType::Dropout => same_as_input(node),
 
+        //FIXME use correct output for GAP (@antimora 8/1/2023)
+        NodeType::GlobalAveragePool => same_as_input(node),
         _ => todo!(
             "shape inference for {:?} is not implemented",
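The FIXME above exists because same_as_input only reproduces the input's shape; per the ONNX spec, GlobalAveragePool keeps the rank plus the batch and channel dimensions but collapses every spatial dimension to 1. A minimal sketch of the intended rule (hypothetical helper, not code from this commit):

def global_avg_pool_output_shape(input_shape):
    # [N, C, d1, ..., dn] -> [N, C, 1, ..., 1]
    n, c, *spatial = input_shape
    return [n, c] + [1] * len(spatial)

assert global_avg_pool_output_shape([3, 10, 3, 15]) == [3, 10, 1, 1]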
@@ -566,21 +566,16 @@ fn rename_inputs(
         counter += 1;
     }
 
-    let mut counter: HashMap<String, usize> = HashMap::new();
-
     for node in nodes.iter_mut() {
-        // keep track of the number of nodes of each type
-        counter
-            .entry(node.name.clone())
-            .and_modify(|e| *e += 1)
-            .or_insert(1);
+        let mut counter = 1;
 
         // loop through node outputs and rename them and store the new name <-> old name mapping
        for output in node.outputs.iter_mut() {
             let old_name = output.name.clone();
-            let new_name = format!("{}_out{}", node.name, counter[&node.name]);
+            let new_name = format!("{}_out{}", node.name, counter);
             output.name = new_name.clone();
             old_names.insert(old_name, new_name);
+            counter += 1;
         }
     }
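In short, each node now numbers its own outputs locally instead of keying a global HashMap by node name. A hypothetical Python mirror of the new logic (illustration only, not the crate's code):

def rename_outputs(nodes, old_names):
    for node in nodes:
        counter = 1
        for i, old_name in enumerate(node["outputs"]):
            new_name = "{}_out{}".format(node["name"], counter)
            node["outputs"][i] = new_name
            old_names[old_name] = new_name
            counter += 1

old_names = {}
rename_outputs([{"name": "conv1", "outputs": ["a", "b"]}], old_names)
assert old_names == {"a": "conv1_out1", "b": "conv1_out2"}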
@@ -600,7 +595,6 @@ fn rename_inputs(
         if let Some(new_name) = old_names.get(&output.name) {
             output.name = new_name.clone();
         } else {
-            println!("{:#?}", old_names);
             panic!("Output {} not found in old_names", output.name);
         }
     }
@@ -19,6 +19,7 @@ use crate::{
     constant::{ConstantNode, ConstantValue, TensorValue},
     conv2d::Conv2dNode,
     dropout::DropoutNode,
+    global_avg_pool::GlobalAvgPoolNode,
     linear::LinearNode,
     matmul::MatmulNode,
     max_pool2d::MaxPool2dNode,
@@ -194,6 +195,9 @@ impl ONNXGraph {
             NodeType::Concat => graph.register(Self::concat_conversion(node)),
             NodeType::Cast => graph.register(Self::cast_conversion(node)),
             NodeType::Dropout => graph.register(Self::dropout_conversion(node)),
+            NodeType::GlobalAveragePool => {
+                graph.register(Self::global_avg_pool_conversion(node))
+            }
             _ => panic!("Unsupported node conversion {}", node.node_type),
         }
     }
@@ -464,6 +468,15 @@ impl ONNXGraph {
         let name = &node.name;
         MaxPool2dNode::new(name, input, output, config)
     }
+
+    fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode {
+        let input = node.inputs.get(0).unwrap().to_tensor_type();
+        let output = node.outputs.get(0).unwrap().to_tensor_type();
+
+        let name = &node.name;
+
+        GlobalAvgPoolNode::new(name, input, output)
+    }
 }
 
 fn extract_next_data_serialize<E: Element>(node: &mut Node) -> Option<DataSerialize<E>> {