mirror of https://github.com/tracel-ai/burn.git
Add shape ONNX op support (#1639)
* Add shape onnx op support
* Remove cast node from onnx graph
* Fix shape implementation
* Fix shape config error message
* Fix typo
* Fix clippy type complexity for generated code
This commit is contained in:
parent 6d96e8d808
commit 35b36bbe62
@@ -157,7 +157,7 @@ Here's how powf was added to burn fusion:
 The way wgpu handles tensor-scalar operations is by transforming both into a sequence of vectorized
 scalar operations. Since powf already existed in burn-wgpu, it was pretty easy to reuse the existing
 implementation for the situation where both sides of the operation were tensors. The `burn-wgpu`
-crate is primarily concered with how the operation is compiled and executed by the gpu. The actual
+crate is primarily concerned with how the operation is compiled and executed by the gpu. The actual
 implementation is defined in `burn-jit`.
 
 Here is where code was added for powf in `burn-jit` and `burn-wgpu`:
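For readers following along, the user-facing API this passage refers to exposes both the tensor-tensor and tensor-scalar forms of powf. A minimal sketch, assuming the `ndarray` backend feature and a burn version where `powf` takes a tensor and `powf_scalar` takes a float:

use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    // Illustrative backend; any burn backend behaves the same for this API.
    let device = Default::default();
    let lhs = Tensor::<NdArray, 2>::ones([2, 3], &device);
    let rhs = Tensor::<NdArray, 2>::ones([2, 3], &device);

    // Tensor-tensor form: element-wise exponentiation, the case the guide walks through.
    let _pow_tensor = lhs.clone().powf(rhs);
    // Tensor-scalar form: the scalar exponent is applied to every element.
    let _pow_scalar = lhs.powf_scalar(2.0);
}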
@@ -164,7 +164,7 @@ represent the corresponding Burn Op.
 | [SequenceInsert][157] | ❌ | ❌ |
 | [SequenceLength][158] | ❌ | ❌ |
 | [SequenceMap][159] | ❌ | ❌ |
-| [Shape][160] | ❌ | ✅ |
+| [Shape][160] | ✅ | ✅ |
 | [Shrink][161] | ❌ | ❌ |
 | [Sigmoid][162] | ✅ | ✅ |
 | [Sign][163] | ❌ | ✅ |
@@ -39,6 +39,7 @@ fn main() {
         .input("tests/leaky_relu/leaky_relu.onnx")
         .input("tests/reduce_mean/reduce_mean.onnx")
         .input("tests/reshape/reshape.onnx")
+        .input("tests/shape/shape.onnx")
         .input("tests/sigmoid/sigmoid.onnx")
         .input("tests/sin/sin.onnx")
         .input("tests/softmax/softmax.onnx")
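For context, these `.input(...)` calls sit inside burn-import's `ModelGen` builder, which converts each registered ONNX file into Rust source at build time. A minimal build-script sketch; the `out_dir` value and the reduced input list are assumptions for illustration, not part of the diff:

// build.rs (sketch)
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        // Each registered ONNX file becomes a generated Rust module.
        .input("tests/shape/shape.onnx")
        // Generated sources land under OUT_DIR, where the include_models! macro picks them up.
        .out_dir("model/")
        .run_from_script();
}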
@@ -4,6 +4,8 @@
 macro_rules! include_models {
     ($($model:ident),*) => {
         $(
+            // Allow type complexity for generated code
+            #[allow(clippy::type_complexity)]
             pub mod $model {
                 include!(concat!(env!("OUT_DIR"), concat!("/model/", stringify!($model), ".rs")));
             }
@@ -46,6 +48,7 @@ include_models!(
     reduce_mean,
     relu,
     reshape,
+    shape,
     sigmoid,
     sin,
     softmax,
@@ -476,6 +479,19 @@ mod tests {
         assert_eq!(output.to_data(), expected);
     }
 
+    #[test]
+    fn shape() {
+        let device = Default::default();
+        let model: shape::Model<Backend> = shape::Model::new(&device);
+
+        // Run the model
+        let input = Tensor::<Backend, 2>::ones([4, 2], &device);
+        let output = model.forward(input);
+        let expected = Data::from([4, 2]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
     #[test]
     fn flatten() {
         // Initialize the model without weights (because the exported file does not contain them)
@@ -0,0 +1,12 @@
(new binary file: the exported shape.onnx ONNX model used by the test above and generated by the script below; protobuf contents are not human-readable and are omitted here)
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+
+# used to generate model: onnx-tests/tests/shape/shape.onnx
+
+import onnx
+import torch
+import torch.nn as nn
+
+
+# Trace with TorchScript to return the shape tensor (otherwise, would gather the shape
+# of each dim as a scalar)
+@torch.jit.script
+def shape(x):
+    return torch.tensor(x.shape)
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x):
+        return shape(x)
+
+
+def main():
+    # Set seed for reproducibility
+    torch.manual_seed(42)
+
+    torch.set_printoptions(precision=8)
+
+    # Export to onnx
+    device = torch.device("cpu")
+    model = Model()
+    model.eval()
+    test_input = torch.ones(4, 2, device=device)
+    file_name = "shape.onnx"
+
+    torch.onnx.export(
+        model,
+        test_input,
+        file_name,
+        input_names=["x"],
+        dynamic_axes={"x": {0: "b"}},
+        verbose=False,
+        opset_version=16,
+    )
+
+    m = onnx.load(file_name)
+    # Remove cast node
+    m.graph.node.pop(1)
+    m.graph.node[0].output[0] = m.graph.output[0].name
+    onnx.save(m, file_name)
+
+    print(f"Finished exporting model to {file_name}")
+
+    # Output some test data for use in the test
+    print(f"Test input data: {test_input}")
+    print(f"Test input data shape: {test_input.shape}")
+    output = model.forward(test_input)
+    # print(f"Test output data shape: {output.shape}")
+
+    print(f"Test output: {output}")
+
+
+if __name__ == "__main__":
+    main()
@@ -35,6 +35,7 @@ pub enum UnaryNodeKind {
     Reciprocal,
     LeakyRelu,
     Relu,
+    Shape,
     Sigmoid,
     Sin,
     Softmax,
@@ -60,6 +61,7 @@ impl UnaryNodeKind {
            Self::Reciprocal => "reciprocal",
            Self::LeakyRelu => "leaky_relu",
            Self::Relu => "relu",
+           Self::Shape => "shape",
            Self::Sigmoid => "sigmoid",
            Self::Sin => "sin",
            Self::Softmax => "softmax",
@@ -123,6 +125,9 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
            UnaryNodeKind::Neg => {
                imports.register("core::ops::Neg");
            }
+           UnaryNodeKind::Shape => {
+               imports.register("burn::tensor::Int");
+           }
            UnaryNodeKind::Not => {
                imports.register("burn::tensor::Bool");
            }
@@ -314,6 +319,22 @@ impl UnaryNode {
             panic!("ReduceMean only supports tensor output");
         }
     }
+
+    pub(crate) fn shape(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self {
+        // Shape as defined by the ONNX op should return a tensor because other ops
+        // (e.g., Gather) will be used on a tensor
+        let function = move |input| {
+            quote! {
+                Tensor::<B, 1, Int>::from_data(
+                    burn::tensor::Data::from(&#input.dims()[#start_dim..#end_dim])
+                        .from_usize::<i64>()
+                        .convert::<burn::tensor::ops::IntElem<B>>(),
+                    &#input.device(),
+                )
+            }
+        };
+        Self::new(input, output, UnaryNodeKind::Shape, Rc::new(function))
+    }
 }
 
 #[cfg(test)]
@@ -784,4 +805,30 @@ mod tests {
             vec!["tensor2".to_string()],
         );
     }
+
+    #[test]
+    fn test_unary_codegen_shape() {
+        one_node_graph(
+            UnaryNode::shape(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_int("tensor2", 1)),
+                1,
+                3,
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1, Int> {
+                    let tensor2 = Tensor::<B, 1, Int>::from_data(
+                        burn::tensor::Data::from(&tensor1.dims()[1usize..3usize])
+                            .from_usize::<i64>()
+                            .convert::<burn::tensor::ops::IntElem<B>>(),
+                        &tensor1.device(),
+                    );
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+    }
 }
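To make the generated expression concrete, here is a standalone sketch of what it evaluates to at runtime, reusing the same calls the codegen above emits; the backend and the [2, 3, 4, 5] shape are illustrative assumptions:

use burn::backend::NdArray;
use burn::tensor::{Data, Int, Tensor};

type B = NdArray;

fn main() {
    // Assumed backend and shape, purely for illustration.
    let device = Default::default();
    let tensor1 = Tensor::<B, 4>::ones([2, 3, 4, 5], &device);

    // Same expression the codegen emits for start_dim = 1 and end_dim = 3.
    let tensor2 = Tensor::<B, 1, Int>::from_data(
        Data::from(&tensor1.dims()[1..3])
            .from_usize::<i64>()
            .convert::<burn::tensor::ops::IntElem<B>>(),
        &tensor1.device(),
    );

    // dims() of a [2, 3, 4, 5] tensor sliced with 1..3 yields [3, 4].
    assert_eq!(tensor2.to_data(), Data::from([3, 4]));
}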
@@ -366,14 +366,17 @@ fn equal_update_outputs(node: &mut Node) {
 
 fn shape_update_outputs(node: &mut Node) {
     if node.inputs.len() != 1 {
-        panic!("Gather: multiple inputs are not supported: {:?}", node);
+        panic!("Shape: multiple inputs are not supported: {:?}", node);
     }
 
-    // Extract the configuration of the linear layer (inputs are known)
     let node_input = &mut node.inputs[0];
-    if let ArgType::Tensor(tensor) = node_input.clone().ty {
-        // Update the output tensor
-        node.outputs[0].ty = ArgType::Shape(tensor.dim);
+    if let ArgType::Tensor(_tensor) = node_input.clone().ty {
+        // Output tensor is 1D int64
+        node.outputs[0].ty = ArgType::Tensor(TensorType {
+            elem_type: ElementType::Int64,
+            dim: 1,
+            ..Default::default()
+        });
     } else {
         panic!("Only tensor input is valid");
     }
@@ -704,3 +704,41 @@ pub fn reduce_mean_config(node: &Node) -> Option<usize> {
         Some(dim as usize)
     }
 }
+
+pub fn shape_config(curr: &Node) -> (usize, usize) {
+    if curr.inputs.len() != 1 {
+        panic!(
+            "Shape: multiple inputs are not supported (got {:?})",
+            curr.inputs.len()
+        );
+    }
+
+    // Extract the shape of the input tensor
+    let tensor = match curr.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Default: all axes up to the last one (included)
+    let mut start_dim: i64 = 0;
+    let mut end_dim: i64 = tensor.dim as i64;
+
+    // Extract the attributes
+    for (key, value) in curr.attrs.iter() {
+        match key.as_str() {
+            "start" => start_dim = value.clone().into_i64(),
+            "end" => end_dim = value.clone().into_i64(),
+            _ => {}
+        }
+    }
+
+    // If dim is negative, it is counted from the end
+    if start_dim < 0 {
+        start_dim += tensor.dim as i64;
+    }
+    if end_dim < 0 {
+        end_dim += tensor.dim as i64;
+    }
+
+    (start_dim as usize, end_dim as usize)
+}
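As a worked example of the normalization above, negative start/end values wrap around by the input rank, in line with the ONNX Shape semantics. A small standalone sketch, not the actual function:

// Standalone sketch of the start/end handling in shape_config (illustrative only).
fn normalize(start: i64, end: i64, rank: usize) -> (usize, usize) {
    let rank = rank as i64;
    let start = if start < 0 { start + rank } else { start };
    let end = if end < 0 { end + rank } else { end };
    (start as usize, end as usize)
}

fn main() {
    // Defaults on a rank-4 input: the whole shape, dims 0..4.
    assert_eq!(normalize(0, 4, 4), (0, 4));
    // start = 1, end = -1 on a rank-4 input keeps dims 1..3 (the last dim is excluded).
    assert_eq!(normalize(1, -1, 4), (1, 3));
}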
@@ -256,6 +256,7 @@ impl OnnxGraph {
            NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
            NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
            NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
+           NodeType::Shape => graph.register(Self::shape_conversion(node)),
            NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
            NodeType::Sin => graph.register(Self::sin_conversion(node)),
            NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
@@ -474,6 +475,14 @@ impl OnnxGraph {
        UnaryNode::reduce_mean(input, output, dim)
    }
+
+   fn shape_conversion(node: Node) -> UnaryNode {
+       let input = node.inputs.first().unwrap().to_type();
+       let output = node.outputs.first().unwrap().to_type();
+       let (start_dim, end_dim) = shape_config(&node);
+
+       UnaryNode::shape(input, output, start_dim, end_dim)
+   }
 
    fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
        let input = node.inputs.first().unwrap().to_tensor_type();
        let output = node.outputs.first().unwrap().to_tensor_type();