Global avg pool (#611)

Dilshod Tadjibaev 2023-08-09 15:15:33 -05:00 committed by GitHub
parent cb283a9e5b
commit 894783f08d
11 changed files with 346 additions and 15 deletions

View File

@ -73,7 +73,7 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
- [ ] GatherND
- [ ] Gelu
- [x] Gemm (Linear Layer)
- [ ] GlobalAveragePool
- [x] GlobalAveragePool
- [ ] GlobalLpPool
- [ ] GlobalMaxPool
- [ ] Greater

View File

@ -13,6 +13,7 @@ fn main() {
.input("tests/concat/concat.onnx")
.input("tests/conv2d/conv2d.onnx")
.input("tests/dropout/dropout.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.out_dir("model/")
.run_from_script();
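
For context, this builder chain lives in the onnx-tests build script. Below is a minimal sketch of the whole build.rs, assuming burn-import's ModelGen entry point and showing only the newly added input:

// Sketch only (assumption: burn-import's ModelGen builder; other test models omitted).
use burn_import::onnx::ModelGen;

fn main() {
    // Generate Rust model sources from the listed ONNX files into the model/ output directory.
    ModelGen::new()
        .input("tests/global_avr_pool/global_avr_pool.onnx")
        .out_dir("model/")
        .run_from_script();
}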

View File

@ -0,0 +1,30 @@
(binary ONNX protobuf: global_avr_pool.onnx, exported from PyTorch 2.0.1, containing the 1d and 2d GlobalAveragePool nodes; raw bytes omitted)
View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3

# used to generate model: global_avr_pool.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.pool1 = nn.AdaptiveAvgPool1d(1)
        self.pool2 = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x_1d, x_2d):
        y_1d = self.pool1(x_1d)
        y_2d = self.pool2(x_2d)
        return y_1d, y_2d


def main():
    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    file_name = "global_avr_pool.onnx"

    input1 = torch.ones(2, 4, 10, device=device)
    input2 = torch.ones(3, 10, 3, 15, device=device)
    torch.onnx.export(model, (input1, input2), file_name,
                      verbose=False, opset_version=16)
    print("Finished exporting model to {}".format(file_name))

    # Output some test data for use in the test
    print("Test input data shapes of ones: {}, {}".format(
        input1.shape, input2.shape))
    y_1d, y_2d = model.forward(input1, input2)
    print("Test output data shapes: {}, {}".format(y_1d.shape, y_2d.shape))

    sum1 = y_1d.sum().item()
    sum2 = y_2d.sum().item()
    print("Test output sums: {}, {}".format(sum1, sum2))


if __name__ == '__main__':
    main()

View File

@ -10,7 +10,7 @@ macro_rules! include_models {
}
// ATTENTION: Modify this macro to include all models in the `model` directory.
include_models!(add, sub, mul, div, concat, conv2d, dropout);
include_models!(add, sub, mul, div, concat, conv2d, dropout, global_avr_pool);
#[cfg(test)]
mod tests {
@ -134,4 +134,30 @@ mod tests {
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
}
#[test]
fn globalavrpool_1d_2d() {
// The model contains 1d and 2d global average pooling nodes
let model: global_avr_pool::Model<Backend> = global_avr_pool::Model::default();
// Run the model with ones as input for easier testing
let input_1d = Tensor::<Backend, 3>::ones([2, 4, 10]);
let input_2d = Tensor::<Backend, 4>::ones([3, 10, 3, 15]);
let (output_1d, output_2d) = model.forward(input_1d, input_2d);
let expected_shape_1d = Shape::from([2, 4, 1]);
let expected_shape_2d = Shape::from([3, 10, 1, 1]);
assert_eq!(output_1d.shape(), expected_shape_1d);
assert_eq!(output_2d.shape(), expected_shape_2d);
let output_sum_1d = output_1d.sum().into_scalar();
let output_sum_2d = output_2d.sum().into_scalar();
let expected_sum_1d = 8.0; // from pytorch: averaging all-ones leaves ones, so the sum equals the 2 * 4 * 1 output elements
let expected_sum_2d = 30.0; // from pytorch: likewise 3 * 10 * 1 * 1 = 30 output elements
assert!(expected_sum_1d.approx_eq(output_sum_1d, (1.0e-4, 2)));
assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2)));
}
}

View File

@ -1,7 +1,8 @@
use super::{
batch_norm::BatchNormNode, binary::BinaryNode, concat::ConcatNode, constant::ConstantNode,
conv2d::Conv2dNode, dropout::DropoutNode, linear::LinearNode, matmul::MatmulNode,
max_pool2d::MaxPool2dNode, reshape::ReshapeNode, unary::UnaryNode,
conv2d::Conv2dNode, dropout::DropoutNode, global_avg_pool::GlobalAvgPoolNode,
linear::LinearNode, matmul::MatmulNode, max_pool2d::MaxPool2dNode, reshape::ReshapeNode,
unary::UnaryNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::record::PrecisionSettings;
@ -82,6 +83,7 @@ pub enum Node<PS: PrecisionSettings> {
Reshape(ReshapeNode),
Concat(ConcatNode),
Dropout(DropoutNode),
GlobalAvgPool(GlobalAvgPoolNode),
}
macro_rules! match_all {
@ -98,6 +100,7 @@ macro_rules! match_all {
Node::Dropout(node) => $func(node),
Node::Unary(node) => $func(node),
Node::Binary(node) => $func(node),
Node::GlobalAvgPool(node) => $func(node),
}
}};
}
@ -123,6 +126,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::BatchNorm(_) => "batch_norm",
Node::Reshape(_) => "reshape",
Node::Dropout(_) => "dropout",
Node::GlobalAvgPool(_) => "global_avg_pool",
Node::Unary(unary) => unary.kind.as_str(),
Node::Binary(binary) => binary.binary_type.as_str(),
}

View File

@ -0,0 +1,215 @@
use proc_macro2::TokenStream;
use quote::quote;
use burn::record::PrecisionSettings;
use super::{Node, NodeCodegen};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type};
/// GlobalAvgPoolNode is a node that performs a global average pooling operation.
///
/// The node is implemented using the AdaptiveAvgPool1d or AdaptiveAvgPool2d module
/// depending on the input dimension. AdaptiveAvgPool with output size 1 or size (1,1)
/// is equivalent to global average pooling.
#[derive(Debug, Clone)]
pub struct GlobalAvgPoolNode {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
}
impl GlobalAvgPoolNode {
pub fn new<S: AsRef<str>>(name: S, input: TensorType, output: TensorType) -> Self {
// Depending on the input dimension, we need to use a different nn module type
let field_type = match input.dim {
3 => quote! {
AdaptiveAvgPool1d
},
4 => quote! {
AdaptiveAvgPool2d
},
dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"),
};
Self {
field: OtherType::new(name, field_type),
input,
output,
}
}
}
impl<PS: PrecisionSettings> NodeCodegen<PS> for GlobalAvgPoolNode {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}
fn field_init(&self, _with_record: bool) -> Option<TokenStream> {
let name = &self.field.name;
let tokens = match self.input.dim {
3 => {
quote! {
let #name = AdaptiveAvgPool1dConfig::new(1)
.init();
}
}
4 => {
quote! {
let #name = AdaptiveAvgPool2dConfig::new([1,1])
.init();
}
}
dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"),
};
Some(tokens)
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let field = &self.field.name;
quote! {
let #output = self.#field.forward(#input);
}
}
fn register_imports(&self, imports: &mut BurnImports) {
match self.input.dim {
3 => {
imports.register("burn::nn::pool::AdaptiveAvgPool1d");
imports.register("burn::nn::pool::AdaptiveAvgPool1dConfig");
}
4 => {
imports.register("burn::nn::pool::AdaptiveAvgPool2d");
imports.register("burn::nn::pool::AdaptiveAvgPool2dConfig");
}
dim => panic!("Unsupported input dim ({dim}) for GlobalAvgPoolNode"),
}
}
fn into_node(self) -> Node<PS> {
Node::GlobalAvgPool(self)
}
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
S::serialize_none(serializer)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{
graph::BurnGraph,
node::{global_avg_pool::GlobalAvgPoolNode, test::assert_tokens},
TensorType,
};
use burn::record::FullPrecisionSettings;
#[test]
fn test_codegen_2d() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(GlobalAvgPoolNode::new(
"global_avg_pool1",
TensorType::new_float("input", 4),
TensorType::new_float("output", 4),
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
use burn::nn::pool::AdaptiveAvgPool2d;
use burn::nn::pool::AdaptiveAvgPool2dConfig;
#[derive(Module, Debug)]
pub struct Model <B: Backend> {
global_avg_pool1: AdaptiveAvgPool2d,
phantom: core::marker::PhantomData<B>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new_with(record: ModelRecord<B>) -> Self {
let global_avg_pool1 = AdaptiveAvgPool2dConfig::new([1, 1])
.init();
Self {
global_avg_pool1,
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.global_avg_pool1.forward(input);
output
}
}
};
assert_tokens(graph.codegen(), expected);
}
#[test]
fn test_codegen_1d() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(GlobalAvgPoolNode::new(
"global_avg_pool1",
TensorType::new_float("input", 3),
TensorType::new_float("output", 3),
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
use burn::nn::pool::AdaptiveAvgPool1d;
use burn::nn::pool::AdaptiveAvgPool1dConfig;
#[derive(Module, Debug)]
pub struct Model <B: Backend> {
global_avg_pool1: AdaptiveAvgPool1d,
phantom: core::marker::PhantomData<B>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new_with(record: ModelRecord<B>) -> Self {
let global_avg_pool1 = AdaptiveAvgPool1dConfig::new(1)
.init();
Self {
global_avg_pool1,
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let output = self.global_avg_pool1.forward(input);
output
}
}
};
assert_tokens(graph.codegen(), expected);
}
}
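
As a side note on the doc comment above: a minimal sketch, not part of this commit, of why AdaptiveAvgPool2d with output size [1, 1] acts as a global average pool in Burn (backend and shapes are illustrative):

use burn::nn::pool::AdaptiveAvgPool2dConfig;
use burn::tensor::{backend::Backend, Tensor};

// Collapses the spatial dimensions by averaging: [batch, channels, h, w] -> [batch, channels, 1, 1].
fn global_avg_pool_2d<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
    AdaptiveAvgPool2dConfig::new([1, 1]).init().forward(x)
}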

View File

@ -6,6 +6,7 @@ pub(crate) mod concat;
pub(crate) mod constant;
pub(crate) mod conv2d;
pub(crate) mod dropout;
pub(crate) mod global_avg_pool;
pub(crate) mod linear;
pub(crate) mod matmul;
pub(crate) mod max_pool2d;

View File

@ -90,8 +90,6 @@ pub fn dim_inference(
NodeType::Concat => concat_update_outputs(node),
NodeType::Reshape => reshape_update_outputs(node),
NodeType::Dropout => same_as_input(node),
//FIXME use correct output for GAP (@antimora 8/1/2023)
NodeType::GlobalAveragePool => same_as_input(node),
_ => todo!(
"shape inference for {:?} is not implemented",

View File

@ -566,21 +566,16 @@ fn rename_inputs(
counter += 1;
}
let mut counter: HashMap<String, usize> = HashMap::new();
for node in nodes.iter_mut() {
// keep track of the number of nodes of each type
counter
.entry(node.name.clone())
.and_modify(|e| *e += 1)
.or_insert(1);
let mut counter = 1;
// loop through node outputs and rename them and store the new name <-> old name mapping
for output in node.outputs.iter_mut() {
let old_name = output.name.clone();
let new_name = format!("{}_out{}", node.name, counter[&node.name]);
let new_name = format!("{}_out{}", node.name, counter);
output.name = new_name.clone();
old_names.insert(old_name, new_name);
counter += 1;
}
}
@ -600,7 +595,6 @@ fn rename_inputs(
if let Some(new_name) = old_names.get(&output.name) {
output.name = new_name.clone();
} else {
println!("{:#?}", old_names);
panic!("Output {} not found in old_names", output.name);
}
}
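
The change above gives each node its own output counter, so outputs are renamed to {node_name}_out{i} with i restarting at 1 per node. A standalone sketch of that scheme (hypothetical helper, not part of the commit):

use std::collections::HashMap;

// Mirrors the per-node renaming loop: numbers a node's outputs 1..N and records the old -> new mapping.
fn rename_node_outputs(
    node_name: &str,
    outputs: &mut [String],
    old_names: &mut HashMap<String, String>,
) {
    let mut counter = 1;
    for output in outputs.iter_mut() {
        let new_name = format!("{}_out{}", node_name, counter);
        old_names.insert(output.clone(), new_name.clone());
        *output = new_name;
        counter += 1;
    }
}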

View File

@ -19,6 +19,7 @@ use crate::{
constant::{ConstantNode, ConstantValue, TensorValue},
conv2d::Conv2dNode,
dropout::DropoutNode,
global_avg_pool::GlobalAvgPoolNode,
linear::LinearNode,
matmul::MatmulNode,
max_pool2d::MaxPool2dNode,
@ -194,6 +195,9 @@ impl ONNXGraph {
NodeType::Concat => graph.register(Self::concat_conversion(node)),
NodeType::Cast => graph.register(Self::cast_conversion(node)),
NodeType::Dropout => graph.register(Self::dropout_conversion(node)),
NodeType::GlobalAveragePool => {
graph.register(Self::global_avg_pool_conversion(node))
}
_ => panic!("Unsupported node conversion {}", node.node_type),
}
}
@ -464,6 +468,15 @@ impl ONNXGraph {
let name = &node.name;
MaxPool2dNode::new(name, input, output, config)
}
fn global_avg_pool_conversion(node: Node) -> GlobalAvgPoolNode {
let input = node.inputs.get(0).unwrap().to_tensor_type();
let output = node.outputs.get(0).unwrap().to_tensor_type();
let name = &node.name;
GlobalAvgPoolNode::new(name, input, output)
}
}
fn extract_next_data_serialize<E: Element>(node: &mut Node) -> Option<DataSerialize<E>> {