Feat: Allow onnx-import expand op with non-const shapes (#2189)

* Feat: Allow onnx-import expand op with non-const shapes

* Generalize ONNX Expand across IntElem
Adrian Müller 2024-08-29 19:15:44 +02:00 committed by GitHub
parent 7baa33bdaa
commit 2f4c5ac0a1
12 changed files with 359 additions and 25 deletions
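In short: the importer previously required the Expand op's shape input to be a compile-time constant; after this change it also accepts a shape that is only known at runtime, which becomes an extra input of the generated model. A minimal sketch of the two resulting forward flavors, mirroring the codegen tests below (function names are illustrative, not from this PR):

use burn::tensor::{backend::Backend, Tensor};

// Static shape: the target shape was a constant in the ONNX graph and is
// baked into the generated code at import time.
fn forward_static<B: Backend>(tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
    tensor1.expand([4, 4, 4, 4])
}

// Runtime shape: the target shape arrives as an extra input of the generated
// model (here the Shape case, already materialized as a [usize; 4]).
fn forward_runtime<B: Backend>(tensor1: Tensor<B, 4>, shape1: [usize; 4]) -> Tensor<B, 4> {
    tensor1.expand(shape1)
}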


@@ -31,6 +31,8 @@ fn main() {
.input("tests/erf/erf.onnx")
.input("tests/exp/exp.onnx")
.input("tests/expand/expand.onnx")
.input("tests/expand/expand_tensor.onnx")
.input("tests/expand/expand_shape.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/gather/gather_1d_idx.onnx")
.input("tests/gather/gather_2d_idx.onnx")


@@ -46,6 +46,10 @@ def main() -> None:
# Create the model
model_def = helper.make_model(graph_def, producer_name='expand')
# Ensure valid ONNX:
onnx.checker.check_model(model_def)
# Save the model to a file
onnx.save(model_def, 'expand.onnx')


(binary file: expand_shape.onnx, a PyTorch 2.3.0 export — graph main_graph takes inputs inp and shape_src, and feeds a Shape node's output into the Expand node's shape input)

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/expand/expand_shape.onnx
import onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, inp: torch.Tensor, shape_src: torch.Tensor) -> torch.Tensor:
return inp.expand_as(shape_src)
def main():
# Set seed for reproducibility
torch.manual_seed(42)
torch.set_printoptions(precision=8)
# Export to onnx
device = torch.device("cpu")
model = Model()
model.eval()
test_input = torch.ones(4, 1, device=device)
test_shape_src = torch.ones(4, 4, device=device)
file_name = "expand_shape.onnx"
torch.onnx.export(
model,
(test_input, test_shape_src),
file_name,
input_names=["inp", "shape_src"],
verbose=False,
opset_version=16,
)
print(f"Finished exporting model to {file_name}")
# Output some test data for use in the test
print(f"Test input data: {test_input}")
print(f"Test input data shape: {test_input.shape}")
print(f"Test shape source tensor shape: {test_input.shape}")
output = model.forward(test_input, test_shape_src)
print(f"Test output data shape: {output.shape}")
print(f"Test output: {output}")
if __name__ == "__main__":
main()


(binary file: expand_tensor.onnx — graph ExpandGraph with a single Expand node taking input_tensor and the runtime shape input, producing output)

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/expand/expand_tensor.onnx
import onnx
from onnx import helper, TensorProto
def main() -> None:
# Define the Expand node; its shape comes from a graph input rather than a constant
expand_node = helper.make_node(
'Expand',
name='/Expand',
inputs=['input_tensor', 'shape'],
outputs=['output']
)
# Create the graph
graph_def = helper.make_graph(
nodes=[expand_node],
name='ExpandGraph',
inputs=[
helper.make_tensor_value_info('input_tensor', TensorProto.FLOAT, [2, 1]),
helper.make_tensor_value_info('shape', TensorProto.INT64, [2]),
],
outputs=[
helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2])
],
)
# Create the model
model_def = helper.make_model(graph_def, producer_name='expand')
# Ensure valid ONNX:
onnx.checker.check_model(model_def)
# Save the model to a file
onnx.save(model_def, 'expand_tensor.onnx')
if __name__ == '__main__':
main()


@@ -40,6 +40,8 @@ include_models!(
erf,
exp,
expand,
expand_tensor,
expand_shape,
flatten,
gather_1d_idx,
gather_2d_idx,
@@ -1584,6 +1586,34 @@ mod tests {
assert_eq!(output.shape(), expected_shape);
}
#[test]
fn expand_tensor() {
let device = Default::default();
let model: expand_tensor::Model<Backend> = expand_tensor::Model::new(&device);
let input1 = Tensor::<Backend, 2>::from_floats([[-1.0], [1.0]], &device);
let input2 = Tensor::<Backend, 1, Int>::from_ints([2, 2], &device);
let output = model.forward(input1, input2);
let expected_shape = Shape::from([2, 2]);
assert_eq!(output.shape(), expected_shape);
}
#[test]
fn expand_shape() {
let device = Default::default();
let model: expand_shape::Model<Backend> = expand_shape::Model::new(&device);
let input1 = Tensor::<Backend, 2>::from_floats([[-1.0], [1.0], [1.0], [1.0]], &device);
let input2 = Tensor::<Backend, 2>::zeros([4, 4], &device);
let output = model.forward(input1, input2);
let expected_shape = Shape::from([4, 4]);
assert_eq!(output.shape(), expected_shape);
}
#[test]
fn gelu() {
let device = Default::default();


@@ -555,15 +555,19 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
// Get the input and output types of the graph using passed in names
input_names.iter().for_each(|input| {
self.graph_input_types
.push(inputs.get(&TensorType::format_name(input)).unwrap().clone());
self.graph_input_types.push(
inputs
.get(&TensorType::format_name(input))
.unwrap_or_else(|| panic!("Input type not found for {input}"))
.clone(),
);
});
output_names.iter().for_each(|output| {
self.graph_output_types.push(
outputs
.get(&TensorType::format_name(output))
.unwrap_or_else(|| panic!("Output type is not found for {output}"))
.unwrap_or_else(|| panic!("Output type not found for {output}"))
.clone(),
);
});


@@ -8,7 +8,13 @@ use quote::quote;
pub struct ExpandNode {
pub input: TensorType,
pub output: TensorType,
pub shape: Vec<i64>,
pub shape: ExpandShape,
}
#[derive(Debug, Clone)]
pub enum ExpandShape {
Static(Vec<i64>),
Runtime(Type),
}
impl<PS: PrecisionSettings> NodeCodegen<PS> for ExpandNode {
@@ -17,14 +23,40 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ExpandNode {
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
let input = Type::Tensor(self.input.clone());
// If the shape is static, the input tensor is our only input;
// if it is determined at runtime, the shape becomes our second input:
match &self.shape {
ExpandShape::Static(_) => vec![input],
ExpandShape::Runtime(rt_type) => vec![input, rt_type.clone()],
}
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let shape = &self.shape.to_tokens();
let output = &self.output.name;
let shape = match &self.shape {
ExpandShape::Static(static_shape) => static_shape.to_tokens(),
ExpandShape::Runtime(Type::Tensor(shape_tensor)) => {
// since we don't take ownership of the shape_tensor, we don't need `tensor_use_owned` here:
let tensor_name = &shape_tensor.name;
let dim = shape_tensor.shape.as_ref().unwrap()[0];
// The shape tensor has already been statically validated to be rank one when parsing the input.
// We need to download it from device to CPU for the expand operation,
// and convert it to an array, which implements BroadcastArgs.
quote! {
TryInto::<[B::IntElem; #dim]>::try_into(#tensor_name.to_data().as_slice::<B::IntElem>().unwrap()).unwrap()
}
}
ExpandShape::Runtime(Type::Shape(shape)) => {
// Shape implements BroadcastArgs, so it can be passed to expand directly
let shape_name = &shape.name;
quote! { #shape_name }
}
_ => panic!("Invalid shape source {:?}", self.shape),
};
quote! {
let #output = #input.expand(#shape);
}
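For the runtime tensor case, the generated expression has to move the shape values from the device to the host before they can serve as broadcast arguments. A standalone sketch of what that boils down to (expand_with_runtime_shape is an illustrative helper, not code from this PR):

use burn::tensor::{backend::Backend, Int, Tensor};

// Expand `input` to a target shape known only at runtime, held in a rank-1 Int tensor.
fn expand_with_runtime_shape<B: Backend>(
    input: Tensor<B, 2>,
    shape: Tensor<B, 1, Int>,
) -> Tensor<B, 2> {
    // Download the shape tensor to the host and convert it into a fixed-size
    // array; after this change, [B::IntElem; 2] implements BroadcastArgs.
    let dims: [B::IntElem; 2] = shape
        .to_data()
        .as_slice::<B::IntElem>()
        .unwrap()
        .try_into()
        .unwrap();
    input.expand(dims)
}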
@@ -43,17 +75,17 @@ mod tests {
use crate::burn::{
graph::BurnGraph,
node::{expand::ExpandNode, test::assert_tokens},
TensorType,
ShapeType, TensorType,
};
#[test]
fn test_codegen_nodes() {
fn test_codegen_expand_static() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(ExpandNode::new(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
[4, 4, 4, 4].into(),
ExpandShape::Static([4, 4, 4, 4].into()),
));
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
@@ -89,4 +121,112 @@ mod tests {
assert_tokens(graph.codegen(), expected);
}
#[test]
fn test_codegen_expand_shape() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(ExpandNode::new(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
ExpandShape::Runtime(Type::Shape(ShapeType::new("shape1", 4))),
));
graph.register_input_output(
vec!["tensor1".to_string(), "shape1".to_string()],
vec!["tensor2".to_string()],
);
let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
shape1: [usize; 4],
) -> Tensor<B, 4> {
let tensor2 = tensor1.expand(shape1);
tensor2
}
}
};
assert_tokens(graph.codegen(), expected);
}
#[test]
fn test_codegen_expand_tensor() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
let mut shape_tensor_type = TensorType::new_int("tensor3", 4);
shape_tensor_type.shape = Some(vec![4]);
graph.register(ExpandNode::new(
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
ExpandShape::Runtime(Type::Tensor(shape_tensor_type)),
));
graph.register_input_output(
vec!["tensor1".to_string(), "tensor3".to_string()],
vec!["tensor2".to_string()],
);
let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
tensor3: Tensor<B, 4, Int>,
) -> Tensor<B, 4> {
let tensor2 = tensor1.expand(
TryInto::<[B::IntElem; 4usize]>::try_into(tensor3.to_data().as_slice::<B::IntElem>().unwrap())
.unwrap(),
);
tensor2
}
}
};
assert_tokens(graph.codegen(), expected);
}
}


@@ -7,8 +7,8 @@ use burn::nn::{
PaddingConfig2d, PaddingConfig3d,
};
use crate::burn::node::{pad::PadConfig, tile::TileConfig};
use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig};
use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node};
/// Create a Conv1dConfig from the attributes of the node
pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
@@ -382,19 +382,34 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
.with_count_include_pad(count_include_pad == 1)
}
pub fn expand_config(node: &Node) -> Vec<i64> {
pub fn expand_config(node: &Node) -> ExpandShape {
let input_value = &node.inputs[1].value;
match &node.inputs[1].ty {
ArgType::Tensor(tensor) => {
assert_eq!(tensor.dim, 1, "Expand: shape tensor must be 1D");
if let Some(Data::Int64s(shape)) = input_value.as_ref() {
shape.clone()
} else {
panic!("Tensor data type must be int64")
}
assert!(
tensor.shape.is_some(),
"Expand: shape tensor shape must be known!"
);
assert!(
matches!(tensor.elem_type, ElementType::Int64),
"Expand: shape tensor must have element type int64"
);
}
ArgType::Shape(_) => {
// Shapes are always 1-D int64 data, so nothing to assert here
}
_ => panic!("Only tensor input is valid for shape"),
}
match input_value.as_ref() {
Some(Data::Int64s(shape)) => ExpandShape::Static(shape.clone()),
None => {
// we were unable to statically determine the input value, so we'll need to fetch it at runtime
ExpandShape::Runtime(crate::burn::Type::from(&node.inputs[1]))
}
_ => panic!("Shape data type must be int64, is {:?}", input_value),
}
}
/// Create a FlattenConfig from the attributes of the node


@@ -30,7 +30,7 @@ use crate::{
conv_transpose_2d::ConvTranspose2dNode,
conv_transpose_3d::ConvTranspose3dNode,
dropout::DropoutNode,
expand::ExpandNode,
expand::{ExpandNode, ExpandShape},
gather::GatherNode,
gather_elements::GatherElementsNode,
global_avg_pool::GlobalAvgPoolNode,
@@ -1084,8 +1084,14 @@ impl ParsedOnnxGraph {
fn expand_conversion(node: Node) -> ExpandNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let mut output = TensorType::from(node.outputs.first().unwrap());
let shape = expand_config(&node);
output.dim = match &shape {
ExpandShape::Static(s) => s.len(),
ExpandShape::Runtime(Type::Shape(s)) => s.dim,
ExpandShape::Runtime(Type::Tensor(t)) => t.shape.as_ref().unwrap()[0],
_ => panic!("Invalid ExpandShape {shape:?}!"),
};
ExpandNode::new(input, output, shape)
}


@@ -2365,23 +2365,26 @@ impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
}
}
impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [i32; D2] {
impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
// Passing -1 as the size for a dimension means not changing the size of that dimension.
fn into_shape(self, shape: &Shape<D1>) -> Shape<D2> {
if self.len() < shape.dims.len() {
panic!("Broadcast arguments must be greater than the number of dimensions");
}
if self.iter().any(|&x| x < -1 || x == 0) {
panic!("Broadcast arguments must be positive or -1");
}
// Zip the two shapes in reverse order and replace -1 with the actual dimension value.
let new_shape: Vec<_> = self
.iter()
.rev()
.map(|x| {
let primitive = x.to_i64();
if primitive < -1 || primitive == 0 {
panic!("Broadcast arguments must be positive or -1");
}
primitive
})
.zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
.map(|(&x, &y)| if x == -1 { y } else { x as usize })
.map(|(x, &y)| if x == -1 { y } else { x as usize })
.collect::<Vec<_>>()
.into_iter()
.rev()
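To illustrate the -1 convention this generalized impl preserves: expanding a [4, 1] tensor with arguments [-1, 4] keeps the first dimension and broadcasts the second, yielding [4, 4]. A small sketch (broadcast_demo is a hypothetical demo function, assuming the burn tensor API):

use burn::tensor::{backend::Backend, Tensor};

fn broadcast_demo<B: Backend>(device: &B::Device) {
    let t = Tensor::<B, 2>::ones([4, 1], device);
    // -1 keeps the size of the matching input dimension, so [4, 1] expands to [4, 4].
    // With the generalized impl, the arguments may be any Element type, not just i32.
    let expanded = t.expand([-1, 4]);
    assert_eq!(expanded.shape().dims, [4, 4]);
}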