mirror of https://github.com/tracel-ai/burn.git
Feat: Allow onnx-import expand op with non-const shapes (#2189)
* Feat: Allow onnx-import expand op with non-const shapes
* Generalize ONNX Expand across IntElem
parent 7baa33bdaa
commit 2f4c5ac0a1
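For context, here is a minimal sketch of the kind of graph this change lets burn-import handle: an Expand node whose shape arrives as a runtime graph input rather than a constant initializer, so the target shape cannot be baked into the generated code. It mirrors the expand_tensor.py generator added in this commit; the graph and value names are illustrative only.

import onnx
from onnx import helper, TensorProto

# 'shape' is a graph input rather than an initializer, so its value is only
# known at runtime; previously burn-import required a constant here.
expand = helper.make_node('Expand', inputs=['x', 'shape'], outputs=['y'])
graph = helper.make_graph(
    nodes=[expand],
    name='ExpandRuntimeShape',
    inputs=[
        helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 1]),
        helper.make_tensor_value_info('shape', TensorProto.INT64, [2]),
    ],
    outputs=[helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2])],
)
model = helper.make_model(graph, producer_name='expand')
onnx.checker.check_model(model)  # valid ONNX even with a non-const shape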
@@ -31,6 +31,8 @@ fn main() {
         .input("tests/erf/erf.onnx")
         .input("tests/exp/exp.onnx")
         .input("tests/expand/expand.onnx")
+        .input("tests/expand/expand_tensor.onnx")
+        .input("tests/expand/expand_shape.onnx")
         .input("tests/flatten/flatten.onnx")
         .input("tests/gather/gather_1d_idx.onnx")
         .input("tests/gather/gather_2d_idx.onnx")
@@ -46,6 +46,10 @@ def main() -> None:
     # Create the model
     model_def = helper.make_model(graph_def, producer_name='expand')
 
+
+    # Ensure valid ONNX:
+    onnx.checker.check_model(model_def)
+
     # Save the model to a file
     onnx.save(model_def, 'expand.onnx')
 
@@ -0,0 +1,19 @@
[Binary ONNX model added (expand_shape.onnx): pytorch 2.3.0 export, graph `main_graph` with inputs `inp` and `shape_src` feeding `Shape` and `Expand` nodes; raw protobuf bytes omitted.]
@@ -0,0 +1,53 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/expand/expand_shape.onnx

import onnx
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, inp: torch.Tensor, shape_src: torch.Tensor) -> torch.Tensor:
        return inp.expand_as(shape_src)


def main():
    # Set seed for reproducibility
    torch.manual_seed(42)

    torch.set_printoptions(precision=8)

    # Export to onnx
    device = torch.device("cpu")
    model = Model()
    model.eval()
    test_input = torch.ones(4, 1, device=device)
    test_shape_src = torch.ones(4, 4, device=device)
    file_name = "expand_shape.onnx"

    torch.onnx.export(
        model,
        (test_input, test_shape_src),
        file_name,
        input_names=["inp", "shape_src"],
        verbose=False,
        opset_version=16,
    )

    print(f"Finished exporting model to {file_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    print(f"Test input data shape: {test_input.shape}")
    print(f"Test shape source tensor shape: {test_shape_src.shape}")
    output = model.forward(test_input, test_shape_src)
    print(f"Test output data shape: {output.shape}")

    print(f"Test output: {output}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,17 @@
[Binary ONNX model added (expand_tensor.onnx): producer `expand`, graph `ExpandGraph` with inputs `input_tensor` and `shape` feeding an `Expand` node that produces `output`; raw protobuf bytes omitted.]
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/expand/expand_tensor.onnx

import onnx
from onnx import helper, TensorProto

def main() -> None:
    # Define the Expand node; its shape comes from a graph input rather than a constant
    expand_node = helper.make_node(
        'Expand',
        name='/Expand',
        inputs=['input_tensor', 'shape'],
        outputs=['output']
    )

    # Create the graph
    graph_def = helper.make_graph(
        nodes=[expand_node],
        name='ExpandGraph',
        inputs=[
            helper.make_tensor_value_info('input_tensor', TensorProto.FLOAT, [2, 1]),
            helper.make_tensor_value_info('shape', TensorProto.INT64, [2]),
        ],
        outputs=[
            helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2])
        ],
    )

    # Create the model
    model_def = helper.make_model(graph_def, producer_name='expand')

    # Ensure valid ONNX:
    onnx.checker.check_model(model_def)

    # Save the model to a file
    onnx.save(model_def, 'expand_tensor.onnx')

if __name__ == '__main__':
    main()
@@ -40,6 +40,8 @@ include_models!(
     erf,
     exp,
     expand,
+    expand_tensor,
+    expand_shape,
     flatten,
     gather_1d_idx,
     gather_2d_idx,
@@ -1584,6 +1586,34 @@ mod tests {
         assert_eq!(output.shape(), expected_shape);
     }
 
+    #[test]
+    fn expand_tensor() {
+        let device = Default::default();
+        let model: expand_tensor::Model<Backend> = expand_tensor::Model::new(&device);
+
+        let input1 = Tensor::<Backend, 2>::from_floats([[-1.0], [1.0]], &device);
+        let input2 = Tensor::<Backend, 1, Int>::from_ints([2, 2], &device);
+
+        let output = model.forward(input1, input2);
+        let expected_shape = Shape::from([2, 2]);
+
+        assert_eq!(output.shape(), expected_shape);
+    }
+
+    #[test]
+    fn expand_shape() {
+        let device = Default::default();
+        let model: expand_shape::Model<Backend> = expand_shape::Model::new(&device);
+
+        let input1 = Tensor::<Backend, 2>::from_floats([[-1.0], [1.0], [1.0], [1.0]], &device);
+        let input2 = Tensor::<Backend, 2>::zeros([4, 4], &device);
+
+        let output = model.forward(input1, input2);
+        let expected_shape = Shape::from([4, 4]);
+
+        assert_eq!(output.shape(), expected_shape);
+    }
+
     #[test]
     fn gelu() {
         let device = Default::default();
@@ -555,15 +555,19 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
 
         // Get the input and output types of the graph using passed in names
         input_names.iter().for_each(|input| {
-            self.graph_input_types
-                .push(inputs.get(&TensorType::format_name(input)).unwrap().clone());
+            self.graph_input_types.push(
+                inputs
+                    .get(&TensorType::format_name(input))
+                    .unwrap_or_else(|| panic!("Input type not found for {input}"))
+                    .clone(),
+            );
         });
 
         output_names.iter().for_each(|output| {
             self.graph_output_types.push(
                 outputs
                     .get(&TensorType::format_name(output))
-                    .unwrap_or_else(|| panic!("Output type is not found for {output}"))
+                    .unwrap_or_else(|| panic!("Output type not found for {output}"))
                     .clone(),
             );
         });
@@ -8,7 +8,13 @@ use quote::quote;
 pub struct ExpandNode {
     pub input: TensorType,
     pub output: TensorType,
-    pub shape: Vec<i64>,
+    pub shape: ExpandShape,
+}
+
+#[derive(Debug, Clone)]
+pub enum ExpandShape {
+    Static(Vec<i64>),
+    Runtime(Type),
 }
 
 impl<PS: PrecisionSettings> NodeCodegen<PS> for ExpandNode {
@@ -17,14 +23,40 @@
     }
 
     fn input_types(&self) -> Vec<Type> {
-        vec![Type::Tensor(self.input.clone())]
+        let input = Type::Tensor(self.input.clone());
+        // If the shape is static, we only have the input tensor as an input,
+        // if it is dynamic, the shape will be our 2nd:
+        match &self.shape {
+            ExpandShape::Static(_) => vec![input],
+            ExpandShape::Runtime(rt_type) => vec![input, rt_type.clone()],
+        }
     }
 
     fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
         let input = scope.tensor_use_owned(&self.input, node_position);
-        let shape = &self.shape.to_tokens();
         let output = &self.output.name;
 
+        let shape = match &self.shape {
+            ExpandShape::Static(static_shape) => static_shape.to_tokens(),
+            ExpandShape::Runtime(Type::Tensor(shape_tensor)) => {
+                // since we don't take ownership of the shape_tensor, we don't need `tensor_use_owned` here:
+                let tensor_name = &shape_tensor.name;
+                let dim = shape_tensor.shape.as_ref().unwrap()[0];
+                // the shape of the tensor is already validated statically to be rank one when parsing the input
+                // we'll need to download the Tensor from device to cpu for expand operation.
+                // Also, we'll need to convert it to an array for conversion into BroadcastArgs
+                quote! {
+                    TryInto::<[B::IntElem; #dim]>::try_into(#tensor_name.to_data().as_slice::<B::IntElem>().unwrap()).unwrap()
+                }
+            }
+            ExpandShape::Runtime(Type::Shape(shape)) => {
+                // Shape implements BroadcastArgs, so it can be passed to expand directly
+                let shape_name = &shape.name;
+                quote! { #shape_name }
+            }
+            _ => panic!("Invalid shape source {:?}", self.shape),
+        };
+
         quote! {
             let #output = #input.expand(#shape);
         }
@@ -43,17 +75,17 @@ mod tests {
     use crate::burn::{
         graph::BurnGraph,
        node::{expand::ExpandNode, test::assert_tokens},
-        TensorType,
+        ShapeType, TensorType,
     };
 
     #[test]
-    fn test_codegen_nodes() {
+    fn test_codegen_expand_static() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();
 
         graph.register(ExpandNode::new(
             TensorType::new_float("tensor1", 4),
             TensorType::new_float("tensor2", 4),
-            [4, 4, 4, 4].into(),
+            ExpandShape::Static([4, 4, 4, 4].into()),
         ));
 
         graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
@@ -89,4 +121,112 @@ mod tests {
 
         assert_tokens(graph.codegen(), expected);
     }
+
+    #[test]
+    fn test_codegen_expand_shape() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(ExpandNode::new(
+            TensorType::new_float("tensor1", 4),
+            TensorType::new_float("tensor2", 4),
+            ExpandShape::Runtime(Type::Shape(ShapeType::new("shape1", 4))),
+        ));
+
+        graph.register_input_output(
+            vec!["tensor1".to_string(), "shape1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+
+        let expected = quote! {
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+
+            impl<B: Backend> Model <B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(
+                    &self,
+                    tensor1: Tensor<B, 4>,
+                    shape1: [usize; 4],
+                ) -> Tensor<B, 4> {
+                    let tensor2 = tensor1.expand(shape1);
+
+                    tensor2
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+
+    #[test]
+    fn test_codegen_expand_tensor() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        let mut shape_tensor_type = TensorType::new_int("tensor3", 4);
+        shape_tensor_type.shape = Some(vec![4]);
+
+        graph.register(ExpandNode::new(
+            TensorType::new_float("tensor1", 4),
+            TensorType::new_float("tensor2", 4),
+            ExpandShape::Runtime(Type::Tensor(shape_tensor_type)),
+        ));
+
+        graph.register_input_output(
+            vec!["tensor1".to_string(), "tensor3".to_string()],
+            vec!["tensor2".to_string()],
+        );
+
+        let expected = quote! {
+            use burn::tensor::Int;
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+
+            impl<B: Backend> Model <B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(
+                    &self,
+                    tensor1: Tensor<B, 4>,
+                    tensor3: Tensor<B, 4, Int>,
+                ) -> Tensor<B, 4> {
+                    let tensor2 = tensor1.expand(
+                        TryInto::<[B::IntElem; 4usize]>::try_into(tensor3.to_data().as_slice::<B::IntElem>().unwrap())
+                            .unwrap(),
+                    );
+                    tensor2
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
 }
@@ -7,8 +7,8 @@ use burn::nn::{
     PaddingConfig2d, PaddingConfig3d,
 };
 
-use crate::burn::node::{pad::PadConfig, tile::TileConfig};
-use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
+use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig};
+use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node};
 
 /// Create a Conv1dConfig from the attributes of the node
 pub fn conv1d_config(curr: &Node) -> Conv1dConfig {
@@ -382,19 +382,34 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig {
         .with_count_include_pad(count_include_pad == 1)
 }
 
-pub fn expand_config(node: &Node) -> Vec<i64> {
+pub fn expand_config(node: &Node) -> ExpandShape {
     let input_value = &node.inputs[1].value;
     match &node.inputs[1].ty {
         ArgType::Tensor(tensor) => {
             assert_eq!(tensor.dim, 1, "Expand: shape tensor must be 1D");
-            if let Some(Data::Int64s(shape)) = input_value.as_ref() {
-                shape.clone()
-            } else {
-                panic!("Tensor data type must be int64")
-            }
+            assert!(
+                tensor.shape.is_some(),
+                "Expand: shape tensor shape must be known!"
+            );
+            assert!(
+                matches!(tensor.elem_type, ElementType::Int64),
+                "Expand: shape tensor must have element type int64"
+            );
+        }
+        ArgType::Shape(_) => {
+            // Shapes are always 1-D int64 data, so nothing to assert here
         }
         _ => panic!("Only tensor input is valid for shape"),
     }
+
+    match input_value.as_ref() {
+        Some(Data::Int64s(shape)) => ExpandShape::Static(shape.clone()),
+        None => {
+            // we were unable to statically determine the input value, so we'll need to fetch it at runtime
+            ExpandShape::Runtime(crate::burn::Type::from(&node.inputs[1]))
+        }
+        _ => panic!("Shape data type must be int64, is {:?}", input_value),
+    }
 }
 
 /// Create a FlattenConfig from the attributes of the node
@@ -30,7 +30,7 @@ use crate::{
         conv_transpose_2d::ConvTranspose2dNode,
         conv_transpose_3d::ConvTranspose3dNode,
         dropout::DropoutNode,
-        expand::ExpandNode,
+        expand::{ExpandNode, ExpandShape},
         gather::GatherNode,
         gather_elements::GatherElementsNode,
         global_avg_pool::GlobalAvgPoolNode,
@@ -1084,8 +1084,14 @@ impl ParsedOnnxGraph {
 
     fn expand_conversion(node: Node) -> ExpandNode {
         let input = TensorType::from(node.inputs.first().unwrap());
-        let output = TensorType::from(node.outputs.first().unwrap());
+        let mut output = TensorType::from(node.outputs.first().unwrap());
         let shape = expand_config(&node);
+        output.dim = match &shape {
+            ExpandShape::Static(s) => s.len(),
+            ExpandShape::Runtime(Type::Shape(s)) => s.dim,
+            ExpandShape::Runtime(Type::Tensor(t)) => t.shape.as_ref().unwrap()[0],
+            _ => panic!("Invalid ExpandShape {shape:?}!"),
+        };
 
         ExpandNode::new(input, output, shape)
     }
@@ -2365,23 +2365,26 @@ impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
     }
 }
 
-impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [i32; D2] {
+impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
     // Passing -1 as the size for a dimension means not changing the size of that dimension.
     fn into_shape(self, shape: &Shape<D1>) -> Shape<D2> {
         if self.len() < shape.dims.len() {
             panic!("Broadcast arguments must be greater than the number of dimensions");
         }
 
-        if self.iter().any(|&x| x < -1 || x == 0) {
-            panic!("Broadcast arguments must be positive or -1");
-        }
-
         // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
         let new_shape: Vec<_> = self
             .iter()
             .rev()
+            .map(|x| {
+                let primitive = x.to_i64();
+                if primitive < -1 || primitive == 0 {
+                    panic!("Broadcast arguments must be positive or -1");
+                }
+                primitive
+            })
             .zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
-            .map(|(&x, &y)| if x == -1 { y } else { x as usize })
+            .map(|(x, &y)| if x == -1 { y } else { x as usize })
             .collect::<Vec<_>>()
             .into_iter()
             .rev()