Add shape ONNX op support (#1639)

* Add shape onnx op support

* Remove cast node from onnx graph

* Fix shape implementation

* Fix shape config error message

* Fix typo

* Fix clippy type complexity for generated code
This commit is contained in:
Guillaume Lagrange 2024-04-16 09:28:21 -04:00 committed by GitHub
parent 6d96e8d808
commit 35b36bbe62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 199 additions and 7 deletions

View File

@ -157,7 +157,7 @@ Here's how powf was added to burn fusion:
The way wgpu handles tensor-scalar operations is by transforming both into a sequence of vectorized The way wgpu handles tensor-scalar operations is by transforming both into a sequence of vectorized
scalar operations. Since powf already existed in burn-wgpu, it was pretty easy to reuse the existing scalar operations. Since powf already existed in burn-wgpu, it was pretty easy to reuse the existing
implementation for the situation where both sides of the operation were tensors. The `burn-wgpu` implementation for the situation where both sides of the operation were tensors. The `burn-wgpu`
crate is primarily concered with how the operation is compiled and executed by the gpu. The actual crate is primarily concerned with how the operation is compiled and executed by the gpu. The actual
implementation is defined in `burn-jit`. implementation is defined in `burn-jit`.
Here is where code was added for powf in `burn-jit` and `burn-wgpu`: Here is where code was added for powf in `burn-jit` and `burn-wgpu`:

View File

@ -164,7 +164,7 @@ represent the corresponding Burn Op.
| [SequenceInsert][157] | ❌ | ❌ | | [SequenceInsert][157] | ❌ | ❌ |
| [SequenceLength][158] | ❌ | ❌ | | [SequenceLength][158] | ❌ | ❌ |
| [SequenceMap][159] | ❌ | ❌ | | [SequenceMap][159] | ❌ | ❌ |
| [Shape][160] | | ✅ | | [Shape][160] | | ✅ |
| [Shrink][161] | ❌ | ❌ | | [Shrink][161] | ❌ | ❌ |
| [Sigmoid][162] | ✅ | ✅ | | [Sigmoid][162] | ✅ | ✅ |
| [Sign][163] | ❌ | ✅ | | [Sign][163] | ❌ | ✅ |

View File

@ -39,6 +39,7 @@ fn main() {
.input("tests/leaky_relu/leaky_relu.onnx") .input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/reduce_mean/reduce_mean.onnx") .input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reshape/reshape.onnx") .input("tests/reshape/reshape.onnx")
.input("tests/shape/shape.onnx")
.input("tests/sigmoid/sigmoid.onnx") .input("tests/sigmoid/sigmoid.onnx")
.input("tests/sin/sin.onnx") .input("tests/sin/sin.onnx")
.input("tests/softmax/softmax.onnx") .input("tests/softmax/softmax.onnx")

View File

@ -4,6 +4,8 @@
macro_rules! include_models { macro_rules! include_models {
($($model:ident),*) => { ($($model:ident),*) => {
$( $(
// Allow type complexity for generated code
#[allow(clippy::type_complexity)]
pub mod $model { pub mod $model {
include!(concat!(env!("OUT_DIR"), concat!("/model/", stringify!($model), ".rs"))); include!(concat!(env!("OUT_DIR"), concat!("/model/", stringify!($model), ".rs")));
} }
@ -46,6 +48,7 @@ include_models!(
reduce_mean, reduce_mean,
relu, relu,
reshape, reshape,
shape,
sigmoid, sigmoid,
sin, sin,
softmax, softmax,
@ -476,6 +479,19 @@ mod tests {
assert_eq!(output.to_data(), expected); assert_eq!(output.to_data(), expected);
} }
#[test]
fn shape() {
let device = Default::default();
let model: shape::Model<Backend> = shape::Model::new(&device);
// Run the model
let input = Tensor::<Backend, 2>::ones([4, 2], &device);
let output = model.forward(input);
let expected = Data::from([4, 2]);
assert_eq!(output.to_data(), expected);
}
#[test] #[test]
fn flatten() { fn flatten() {
// Initialize the model without weights (because the exported file does not contain them) // Initialize the model without weights (because the exported file does not contain them)

View File

@ -0,0 +1,12 @@
pytorch2.1.2:K

x2Shape_0"Shape
main_graphZ
x

b
b
2

B

View File

@ -0,0 +1,66 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/shape/shape.onnx
import onnx
import torch
import torch.nn as nn
# Trace with TorchScript to return the shape tensor (otherwise, would gather the shape
# of each dim as a scalar)
@torch.jit.script
def shape(x):
return torch.tensor(x.shape)
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return shape(x)
def main():
# Set seed for reproducibility
torch.manual_seed(42)
torch.set_printoptions(precision=8)
# Export to onnx
device = torch.device("cpu")
model = Model()
model.eval()
test_input = torch.ones(4, 2, device=device)
file_name = "shape.onnx"
torch.onnx.export(
model,
test_input,
file_name,
input_names=["x"],
dynamic_axes={"x": {0: "b"}},
verbose=False,
opset_version=16,
)
m = onnx.load(file_name)
# Remove cast node
m.graph.node.pop(1)
m.graph.node[0].output[0] = m.graph.output[0].name
onnx.save(m, file_name)
print(f"Finished exporting model to {file_name}")
# Output some test data for use in the test
print(f"Test input data: {test_input}")
print(f"Test input data shape: {test_input.shape}")
output = model.forward(test_input)
# print(f"Test output data shape: {output.shape}")
print(f"Test output: {output}")
if __name__ == "__main__":
main()

View File

@ -35,6 +35,7 @@ pub enum UnaryNodeKind {
Reciprocal, Reciprocal,
LeakyRelu, LeakyRelu,
Relu, Relu,
Shape,
Sigmoid, Sigmoid,
Sin, Sin,
Softmax, Softmax,
@ -60,6 +61,7 @@ impl UnaryNodeKind {
Self::Reciprocal => "reciprocal", Self::Reciprocal => "reciprocal",
Self::LeakyRelu => "leaky_relu", Self::LeakyRelu => "leaky_relu",
Self::Relu => "relu", Self::Relu => "relu",
Self::Shape => "shape",
Self::Sigmoid => "sigmoid", Self::Sigmoid => "sigmoid",
Self::Sin => "sin", Self::Sin => "sin",
Self::Softmax => "softmax", Self::Softmax => "softmax",
@ -123,6 +125,9 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
UnaryNodeKind::Neg => { UnaryNodeKind::Neg => {
imports.register("core::ops::Neg"); imports.register("core::ops::Neg");
} }
UnaryNodeKind::Shape => {
imports.register("burn::tensor::Int");
}
UnaryNodeKind::Not => { UnaryNodeKind::Not => {
imports.register("burn::tensor::Bool"); imports.register("burn::tensor::Bool");
} }
@ -314,6 +319,22 @@ impl UnaryNode {
panic!("ReduceMean only supports tensor output"); panic!("ReduceMean only supports tensor output");
} }
} }
pub(crate) fn shape(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self {
// Shape as defined by the ONNX op should return a tensor because other ops
// (e.g., Gather) will be used on a tensor
let function = move |input| {
quote! {
Tensor::<B, 1, Int>::from_data(
burn::tensor::Data::from(&#input.dims()[#start_dim..#end_dim])
.from_usize::<i64>()
.convert::<burn::tensor::ops::IntElem<B>>(),
&#input.device(),
)
}
};
Self::new(input, output, UnaryNodeKind::Shape, Rc::new(function))
}
} }
#[cfg(test)] #[cfg(test)]
@ -784,4 +805,30 @@ mod tests {
vec!["tensor2".to_string()], vec!["tensor2".to_string()],
); );
} }
#[test]
fn test_unary_codegen_shape() {
one_node_graph(
UnaryNode::shape(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_int("tensor2", 1)),
1,
3,
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1, Int> {
let tensor2 = Tensor::<B, 1, Int>::from_data(
burn::tensor::Data::from(&tensor1.dims()[1usize..3usize])
.from_usize::<i64>()
.convert::<burn::tensor::ops::IntElem<B>>(),
&tensor1.device(),
);
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
} }

View File

@ -366,14 +366,17 @@ fn equal_update_outputs(node: &mut Node) {
fn shape_update_outputs(node: &mut Node) { fn shape_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 { if node.inputs.len() != 1 {
panic!("Gather: multiple inputs are not supported: {:?}", node); panic!("Shape: multiple inputs are not supported: {:?}", node);
} }
// Extract the configuration of the linear layer (inputs are known)
let node_input = &mut node.inputs[0]; let node_input = &mut node.inputs[0];
if let ArgType::Tensor(tensor) = node_input.clone().ty { if let ArgType::Tensor(_tensor) = node_input.clone().ty {
// Update the output tensor // Output tensor is 1D int64
node.outputs[0].ty = ArgType::Shape(tensor.dim); node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: ElementType::Int64,
dim: 1,
..Default::default()
});
} else { } else {
panic!("Only tensor input is valid"); panic!("Only tensor input is valid");
} }

View File

@ -704,3 +704,41 @@ pub fn reduce_mean_config(node: &Node) -> Option<usize> {
Some(dim as usize) Some(dim as usize)
} }
} }
pub fn shape_config(curr: &Node) -> (usize, usize) {
if curr.inputs.len() != 1 {
panic!(
"Shape: multiple inputs are not supported (got {:?})",
curr.inputs.len()
);
}
// Extract the shape of the input tensor
let tensor = match curr.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};
// Default: all axes up to the last one (included)
let mut start_dim: i64 = 0;
let mut end_dim: i64 = tensor.dim as i64;
// Extract the attributes
for (key, value) in curr.attrs.iter() {
match key.as_str() {
"start" => start_dim = value.clone().into_i64(),
"end" => end_dim = value.clone().into_i64(),
_ => {}
}
}
// If dim is negative, it is counted from the end
if start_dim < 0 {
start_dim += tensor.dim as i64;
}
if end_dim < 0 {
end_dim += tensor.dim as i64;
}
(start_dim as usize, end_dim as usize)
}

View File

@ -256,6 +256,7 @@ impl OnnxGraph {
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)), NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)), NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)), NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
NodeType::Shape => graph.register(Self::shape_conversion(node)),
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)), NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
NodeType::Sin => graph.register(Self::sin_conversion(node)), NodeType::Sin => graph.register(Self::sin_conversion(node)),
NodeType::Transpose => graph.register(Self::transpose_conversion(node)), NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
@ -474,6 +475,14 @@ impl OnnxGraph {
UnaryNode::reduce_mean(input, output, dim) UnaryNode::reduce_mean(input, output, dim)
} }
fn shape_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let (start_dim, end_dim) = shape_config(&node);
UnaryNode::shape(input, output, start_dim, end_dim)
}
fn unsqueeze_conversion(node: Node) -> UnsqueezeNode { fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
let input = node.inputs.first().unwrap().to_tensor_type(); let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type(); let output = node.outputs.first().unwrap().to_tensor_type();