Implement ONNX Pad Operator (#2007)

* Implement ONNX pad

* ONNX pad arguments fix

pad now requires 2 or more arguments
if the third argument is not given, it will default to 0

* fixing bug in input len fix

* change panic comment

Change panic comment from needing two inputs. This comes from the fact that the ONNX spec requires two necessary inputs but could have more two more optional argument.

---------

Co-authored-by: JC <you@example.com>
Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
This commit is contained in:
johnhuichen 2024-07-23 17:50:20 +00:00 committed by GitHub
parent 53c77ae646
commit 4a3fc9d4a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 369 additions and 4 deletions

View File

@ -43,6 +43,7 @@ fn main() {
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
.input("tests/pad/pad.onnx")
.input("tests/expand/expand.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")

View File

@ -55,6 +55,7 @@ include_models!(
mul,
neg,
not,
pad,
greater,
greater_or_equal,
less,
@ -1407,6 +1408,26 @@ mod tests {
output.assert_eq(&expected, true);
}
#[test]
fn pad() {
let device = Default::default();
let model: pad::Model<Backend> = pad::Model::new(&device);
let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.], [5., 6.]], &device);
let output = model.forward(input).to_data();
let expected = TensorData::from([
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
[0.0_f32, 0., 1., 2., 0., 0., 0., 0.],
[0.0_f32, 0., 3., 4., 0., 0., 0., 0.],
[0.0_f32, 0., 5., 6., 0., 0., 0., 0.],
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
]);
output.assert_eq(&expected, true);
}
#[test]
fn greater() {
let device = Default::default();

Binary file not shown.

View File

@ -0,0 +1,158 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/pad/pad.onnx
### Helper Functions ###
from pathlib import Path
from typing import Any
import numpy
from numpy.core.multiarray import dtype
import onnx
from onnx import ModelProto, TensorProto, ValueInfoProto
from onnx.reference import ReferenceEvaluator
from onnx.checker import check_model
from onnx.helper import (
make_model,
make_node,
make_graph,
)
def build_test_save(
name: str,
inputs: list[ValueInfoProto],
outputs: list[ValueInfoProto],
initializers: list[TensorProto] = [],
attributes: dict[str, Any] = {},
) -> None:
node_inputs = [input.name for input in inputs + initializers]
node_outputs = [output.name for output in outputs]
node = make_node(
name.capitalize(),
inputs=node_inputs,
outputs=node_outputs,
**attributes,
)
graph = make_graph(
nodes=[node],
name=f"{name.capitalize()}Graph",
inputs=inputs,
outputs=outputs,
initializer=initializers,
)
onnx_model = make_model(graph)
check_model(onnx_model)
run_tests(onnx_model)
onnx.save(onnx_model, Path(__file__).with_name(f"{name}.onnx"))
class TestCase:
def __init__(
self, name: str, feeds: dict[str, numpy.ndarray], expected: numpy.ndarray
):
self.name = name
self.feeds = feeds
self.expected = expected
def test_model(self, model: ModelProto):
sess = ReferenceEvaluator(model)
result = numpy.array(sess.run(None, self.feeds))
if not numpy.array_equal(result, self.expected):
print(
f"""{self.name}
Expected result: {self.expected}
Got: {result}"""
)
raise Exception("Test failed")
def test_positive_pads(model: ModelProto) -> None:
input_tensor = numpy.arange(1, 7, dtype="float32").reshape(3, 2)
pads = numpy.array([1, 2, 3, 4], dtype="int")
constant_value = 0.0
feeds = {
"input_tensor": input_tensor,
"pads": pads,
"constant_value": constant_value,
}
expected = numpy.array(
[
[
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]
]
)
TestCase("test_positive_constant_pads", feeds, expected).test_model(model)
def test_1d_input(model: ModelProto) -> None:
input_tensor = numpy.arange(1, 5, dtype="float32")
pads = numpy.array([1, 2], dtype="int")
constant_value = 0.0
feeds = {
"input_tensor": input_tensor,
"pads": pads,
"constant_value": constant_value,
}
expected = numpy.array([[0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0]])
TestCase("test_1d_input", feeds, expected).test_model(model)
def run_tests(model: ModelProto) -> None:
test_positive_pads(model)
test_1d_input(model)
# TODO: test_negative_pads
# TODO: support other modes: reflect, edge, wrap
### Helper Functions End ###
import numpy
from onnx import TensorProto, numpy_helper
from onnx.helper import make_tensor_value_info
def get_initializers() -> list[TensorProto]:
pads = numpy_helper.from_array(
numpy.array([1, 2, 3, 4]).astype(numpy.int64), name="pads"
)
constant_value = numpy_helper.from_array(
numpy.array([0.0]).astype(numpy.float32), name="constant_value"
)
return [pads, constant_value]
def main() -> None:
name = "pad"
inputs = [make_tensor_value_info("input_tensor", TensorProto.FLOAT, [None, None])]
outputs = [make_tensor_value_info("output", TensorProto.FLOAT, [None, None])]
initializers = get_initializers()
build_test_save(
name=name,
inputs=inputs,
outputs=outputs,
initializers=initializers,
attributes={"mode": "constant"},
)
if __name__ == "__main__":
main()

View File

@ -8,7 +8,7 @@ use super::{
conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
@ -105,6 +105,7 @@ pub enum Node<PS: PrecisionSettings> {
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Pad(PadNode),
Range(RangeNode),
Reshape(ReshapeNode),
Resize(ResizeNode),
@ -150,6 +151,7 @@ macro_rules! match_all {
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Pad(node) => $func(node),
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Resize(node) => $func(node),
@ -203,6 +205,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Pad(_) => "pad",
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Node::Resize(_) => "resize",

View File

@ -25,6 +25,7 @@ pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod pad;
pub(crate) mod prelu;
pub(crate) mod random_normal;
pub(crate) mod random_uniform;

View File

@ -0,0 +1,104 @@
use std::str::FromStr;
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
#[derive(Config, Debug)]
pub struct PadConfig {
pub pads: Vec<usize>,
pub constant_value: f32,
}
#[derive(Debug, Clone, new)]
pub struct PadNode {
pub input: TensorType,
pub output: TensorType,
pub config: PadConfig,
}
impl<PS: PrecisionSettings> NodeCodegen<PS> for PadNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let pads = self.config.pads.iter().map(|p| p.to_tokens());
let constant_value_string = format!("{}_f32.elem()", self.config.constant_value);
let constant_value = TokenStream::from_str(&constant_value_string).unwrap();
quote! {
let #output = #input.pad((#(#pads),*), #constant_value);
}
}
fn into_node(self) -> Node<PS> {
Node::Pad(self)
}
fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
imports.register("burn::tensor::ElementConversion");
}
}
#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;
use super::*;
use crate::burn::{
graph::BurnGraph,
node::{pad::PadNode, test::assert_tokens},
TensorType,
};
#[test]
fn test_codegen_pad() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
let config = PadConfig::new(vec![1, 2, 3, 4], -1.0);
graph.register(PadNode::new(
TensorType::new_float("input", 2),
TensorType::new_float("output", 2),
config,
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
let expected = quote! {
use burn::tensor::ElementConversion;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let output = input.pad((1, 2, 3, 4), -1_f32.elem());
output
}
}
};
assert_tokens(graph.codegen(), expected);
}
}

View File

@ -7,7 +7,7 @@ use burn::nn::{
PaddingConfig2d, PaddingConfig3d,
};
use crate::burn::node::resize::ResizeMode;
use crate::burn::node::{pad::PadConfig, resize::ResizeMode};
use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
/// Create a Conv1dConfig from the attributes of the node
@ -745,6 +745,72 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {
)
}
/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
fn get_pads(node: &Node) -> Vec<usize> {
if node.inputs.len() < 2 {
panic!("Pad: must provide at least two inputs")
}
let input_dim = match &node.inputs.first().unwrap().ty {
ArgType::Tensor(tensor) => tensor.dim,
_ => panic!("Pad: Only tensor input is valid"),
};
let pads: Vec<usize> = match &node.inputs[1].value {
Some(Data::Int64s(shape)) => shape
.iter()
.map(|&x| {
if x < 0 {
// TODO: support negative pads
panic!("Pad: Negative pad is not supported");
}
x as usize
})
.collect(),
_ => panic!("Pad: pads data type must be int64"),
};
if pads.len() != input_dim * 2 {
panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]");
}
// TODO: Burn's pad should support 1D tensor
if input_dim < 2 {
panic!("Pad: input tensor should be rank 2 or higher");
}
let left_index = input_dim - 1;
let top_index = input_dim - 2;
let right_index = pads.len() - 1;
let bottom_index = pads.len() - 2;
let index_list = [left_index, top_index, right_index, bottom_index];
for (index, &item) in pads.iter().enumerate() {
if !index_list.contains(&index) && item != 0 {
panic!("Pad: padding will only be applied to the last two dimensions but found non zero padding for other dimensions");
}
}
let left = pads[left_index];
let top = pads[top_index];
let right = pads[right_index];
let bottom = pads[bottom_index];
vec![left, right, top, bottom]
}
fn get_constant_value(node: &Node) -> f32 {
// TODO: support int, boolean
match &node.inputs[2].value {
Some(Data::Float32s(shape)) => shape.first().unwrap().to_owned(),
_ => 0.0,
}
}
let pads = get_pads(node);
let constant_value = get_constant_value(node);
PadConfig::new(pads, constant_value)
}
/// Calculate the padding configuration for a 1D operations such as Convolution and Pooling.
///
/// # Arguments

View File

@ -40,6 +40,7 @@ use crate::{
matmul::MatmulNode,
max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode,
pad::PadNode,
prelu::PReluNode,
random_normal::RandomNormalNode,
random_uniform::RandomUniformNode,
@ -63,7 +64,7 @@ use super::op_configuration::{
concat_config, conv1d_config, conv2d_config, conv3d_config, conv_transpose2d_config,
conv_transpose3d_config, dropout_config, expand_config, flatten_config, gather_config,
layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool1d_config,
max_pool2d_config, reduce_max_config, reduce_mean_config, reduce_min_config,
max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config,
reduce_prod_config, reduce_sum_config, reshape_config, resize_config, shape_config,
slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config,
};
@ -324,6 +325,7 @@ impl ParsedOnnxGraph {
NodeType::ConvTranspose3d => {
graph.register(Self::conv_transpose3d_conversion::<PS>(node))
}
NodeType::Pad => graph.register(Self::pad_conversion(node)),
NodeType::Pow => graph.register(Self::pow_conversion(node)),
NodeType::Unsqueeze => graph.register(Self::unsqueeze_conversion(node)),
NodeType::Where => graph.register(Self::where_conversion(node)),
@ -1108,6 +1110,14 @@ impl ParsedOnnxGraph {
BinaryNode::lower_equal(lhs, rhs, output)
}
fn pad_conversion(node: Node) -> PadNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let config = pad_config(&node);
PadNode::new(input, output, config)
}
fn pow_conversion(node: Node) -> BinaryNode {
let lhs = Type::from(node.inputs.first().unwrap());
let rhs = Type::from(node.inputs.get(1).unwrap());

View File

@ -736,7 +736,7 @@ where
)
}
/// Pad the tensor with the given value on the last two dimensions.
/// Pad the tensor of rank two or higher with the given value on the last two dimensions.
///
/// # Arguments
///

View File

@ -48,6 +48,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::Mul => same_as_input(node),
NodeType::Neg => same_as_input(node),
NodeType::Not => same_as_input(node),
NodeType::Pad => same_as_input(node),
NodeType::Greater => greater_update_outputs(node),
NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node),
NodeType::Less => less_update_outputs(node),