mirror of https://github.com/tracel-ai/burn.git
feat: added slice onnx import (#1856)
* feat: added slice onnx import * fix: axes, steps handling
This commit is contained in:
parent
dd60446946
commit
671ec8c679
|
@ -171,7 +171,7 @@ represent the corresponding Burn Op.
|
|||
| [Sin][164] | ✅ | ✅ |
|
||||
| [Sinh][165] | ❌ | ❌ |
|
||||
| [Size][166] | ❌ | ❌ |
|
||||
| [Slice][167] | ❌ | ✅ |
|
||||
| [Slice][167] | ✅ | ✅ |
|
||||
| [Softmax][168] | ✅ | ✅ |
|
||||
| [SoftmaxCrossEntropyLoss][169] | ❌ | ❌ |
|
||||
| [Softplus][170] | ❌ | ❌ |
|
||||
|
|
|
@ -69,6 +69,7 @@ fn main() {
|
|||
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
|
||||
.input("tests/pow/pow.onnx")
|
||||
.input("tests/pow/pow_int.onnx")
|
||||
.input("tests/slice/slice.onnx")
|
||||
.input("tests/sum/sum.onnx")
|
||||
.input("tests/sum/sum_int.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze.onnx")
|
||||
|
|
|
@ -71,6 +71,7 @@ include_models!(
|
|||
sigmoid,
|
||||
sign,
|
||||
sin,
|
||||
slice,
|
||||
softmax,
|
||||
sqrt,
|
||||
sub_int,
|
||||
|
@ -459,6 +460,24 @@ mod tests {
|
|||
assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slice() {
|
||||
let model: slice::Model<Backend> = slice::Model::default();
|
||||
let device = Default::default();
|
||||
|
||||
let input = Tensor::<Backend, 2>::from_floats(
|
||||
[
|
||||
[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
|
||||
[11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let output = model.forward(input);
|
||||
let expected = Data::from([[1., 2., 3., 4., 5.]]);
|
||||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softmax() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,101 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/slice/slice.onnx
|
||||
|
||||
import onnx
|
||||
from onnx import helper, TensorProto
|
||||
|
||||
def main() -> None:
|
||||
# Starts
|
||||
starts_val = [0,0] # Example shape value
|
||||
starts_tensor = helper.make_tensor(
|
||||
name="starts",
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[len(starts_val)],
|
||||
vals=starts_val,
|
||||
)
|
||||
starts_node = helper.make_node(
|
||||
"Constant",
|
||||
name="starts_constant",
|
||||
inputs=[],
|
||||
outputs=["starts"],
|
||||
value=starts_tensor,
|
||||
)
|
||||
|
||||
# Ends
|
||||
ends_val = [1,5] # Example shape value
|
||||
ends_tensor = helper.make_tensor(
|
||||
name="ends",
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[len(ends_val)],
|
||||
vals=ends_val,
|
||||
)
|
||||
ends_node = helper.make_node(
|
||||
"Constant",
|
||||
name="ends_constant",
|
||||
inputs=[],
|
||||
outputs=["ends"],
|
||||
value=ends_tensor,
|
||||
)
|
||||
|
||||
# Axes
|
||||
axes_val = [0,1] # Example shape value
|
||||
axes_tensor = helper.make_tensor(
|
||||
name="axes",
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[len(axes_val)],
|
||||
vals=axes_val,
|
||||
)
|
||||
axes_node = helper.make_node(
|
||||
"Constant",
|
||||
name="axes_constant",
|
||||
inputs=[],
|
||||
outputs=["axes"],
|
||||
value=axes_tensor,
|
||||
)
|
||||
|
||||
# Steps
|
||||
steps_val = [1, 1] # Example shape value
|
||||
steps_tensor = helper.make_tensor(
|
||||
name="steps",
|
||||
data_type=TensorProto.INT64,
|
||||
dims=[len(steps_val)],
|
||||
vals=steps_val,
|
||||
)
|
||||
steps_node = helper.make_node(
|
||||
"Constant",
|
||||
name="steps_constant",
|
||||
inputs=[],
|
||||
outputs=["steps"],
|
||||
value=steps_tensor,
|
||||
)
|
||||
|
||||
# Define the Slice node that uses the outputs from the constant nodes
|
||||
slice_node = helper.make_node(
|
||||
"Slice",
|
||||
name="slice_node",
|
||||
inputs=["input_tensor", "starts", "ends", "axes", "steps"],
|
||||
outputs=["output"],
|
||||
)
|
||||
|
||||
# Create the graph
|
||||
graph_def = helper.make_graph(
|
||||
nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
|
||||
name="SliceGraph",
|
||||
inputs=[
|
||||
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]),
|
||||
],
|
||||
outputs=[
|
||||
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5])
|
||||
],
|
||||
)
|
||||
|
||||
# Create the model
|
||||
model_def = helper.make_model(graph_def, producer_name="slice")
|
||||
|
||||
# Save the model to a file
|
||||
onnx.save(model_def, "slice.onnx")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -7,7 +7,7 @@ use super::{
|
|||
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
|
||||
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
|
||||
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
|
||||
reshape::ReshapeNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
|
||||
reshape::ReshapeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
|
||||
unsqueeze::UnsqueezeNode,
|
||||
};
|
||||
use crate::burn::{BurnImports, Scope, Type};
|
||||
|
@ -102,6 +102,7 @@ pub enum Node<PS: PrecisionSettings> {
|
|||
MaxPool2d(MaxPool2dNode),
|
||||
Range(RangeNode),
|
||||
Reshape(ReshapeNode),
|
||||
Slice(SliceNode),
|
||||
Squeeze(SqueezeNode),
|
||||
Sum(SumNode),
|
||||
Unary(UnaryNode),
|
||||
|
@ -139,6 +140,7 @@ macro_rules! match_all {
|
|||
Node::MaxPool2d(node) => $func(node),
|
||||
Node::Range(node) => $func(node),
|
||||
Node::Reshape(node) => $func(node),
|
||||
Node::Slice(node) => $func(node),
|
||||
Node::Squeeze(node) => $func(node),
|
||||
Node::Sum(node) => $func(node),
|
||||
Node::Unary(node) => $func(node),
|
||||
|
@ -186,6 +188,7 @@ impl<PS: PrecisionSettings> Node<PS> {
|
|||
Node::MaxPool2d(_) => "max_pool2d",
|
||||
Node::Range(_) => "range",
|
||||
Node::Reshape(_) => "reshape",
|
||||
Node::Slice(_) => "slice",
|
||||
Node::Squeeze(_) => "squeeze",
|
||||
Node::Sum(_) => "add",
|
||||
Node::Unary(unary) => unary.kind.as_str(),
|
||||
|
|
|
@ -27,6 +27,7 @@ pub(crate) mod random_normal;
|
|||
pub(crate) mod random_uniform;
|
||||
pub(crate) mod range;
|
||||
pub(crate) mod reshape;
|
||||
pub(crate) mod slice;
|
||||
pub(crate) mod squeeze;
|
||||
pub(crate) mod sum;
|
||||
pub(crate) mod unary;
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{Scope, TensorType, Type};
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
#[derive(Debug, Clone, new)]
|
||||
pub struct SliceNode {
|
||||
pub input: TensorType,
|
||||
pub output: TensorType,
|
||||
pub starts: Vec<usize>,
|
||||
pub ends: Vec<usize>,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
let input = scope.tensor_use_owned(&self.input, node_position);
|
||||
let output = &self.output.name;
|
||||
let starts = &self.starts;
|
||||
let ends = &self.ends;
|
||||
|
||||
quote! {
|
||||
let #output = #input.slice([#(#starts..#ends),*]);
|
||||
}
|
||||
}
|
||||
fn into_node(self) -> Node<PS> {
|
||||
Node::Slice(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn::record::FullPrecisionSettings;
|
||||
|
||||
use super::*;
|
||||
use crate::burn::{
|
||||
graph::BurnGraph,
|
||||
node::{slice::SliceNode, test::assert_tokens},
|
||||
TensorType,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_codegen_slice() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
graph.register(SliceNode::new(
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
vec![0, 0, 0, 0],
|
||||
vec![1, 1, 1, 1],
|
||||
));
|
||||
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor2 = tensor1.slice([0usize..1usize,0usize..1usize,0usize..1usize,0usize..1usize]);
|
||||
|
||||
tensor2
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
|
@ -63,6 +63,7 @@ pub fn dim_inference(node: &mut Node) {
|
|||
NodeType::Sigmoid => same_as_input(node),
|
||||
NodeType::Sign => same_as_input(node),
|
||||
NodeType::Sin => same_as_input(node),
|
||||
NodeType::Slice => slice_update_outputs(node),
|
||||
NodeType::Softmax => same_as_input(node),
|
||||
NodeType::Sqrt => same_as_input(node),
|
||||
NodeType::Sub => same_as_input(node),
|
||||
|
@ -423,6 +424,33 @@ fn squeeze_update_output(node: &mut Node) {
|
|||
});
|
||||
}
|
||||
|
||||
fn slice_update_outputs(node: &mut Node) {
|
||||
let shape = match &node.inputs[1].value {
|
||||
Some(value) => match value {
|
||||
Data::Int64s(shape) => Some(shape.clone()),
|
||||
_ => panic!("Slice: invalid input types"),
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
if shape.is_none() {
|
||||
panic!("Slice: invalid shape");
|
||||
}
|
||||
|
||||
let output = match &node.outputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
_ => panic!("Slice: invalid output types"),
|
||||
};
|
||||
|
||||
if let Some(shape) = shape {
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
dim: shape.len(),
|
||||
shape: None, // shape is calculated at runtime
|
||||
..output
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the output tensor dimension based on the "axes" attribute or the second input
|
||||
fn unsqueeze_update_output(node: &mut Node) {
|
||||
let axes = if node.inputs.len() == 2 {
|
||||
|
|
|
@ -18,7 +18,7 @@ use super::ir::{ArgType, Argument, Node, NodeType};
|
|||
|
||||
use protobuf::Message;
|
||||
|
||||
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
|
||||
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 11] = [
|
||||
NodeType::BatchNormalization,
|
||||
NodeType::Clip,
|
||||
NodeType::Conv1d,
|
||||
|
@ -28,6 +28,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
|
|||
NodeType::Reshape,
|
||||
NodeType::Unsqueeze,
|
||||
NodeType::ReduceSum,
|
||||
NodeType::Slice,
|
||||
NodeType::Squeeze,
|
||||
];
|
||||
|
||||
|
|
|
@ -1015,6 +1015,67 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {
|
|||
(start_dim as usize, end_dim as usize)
|
||||
}
|
||||
|
||||
pub fn slice_config(node: &Node) -> (Vec<usize>, Vec<usize>) {
|
||||
let start_value = &node.inputs[1].value;
|
||||
let end_value = &node.inputs[2].value;
|
||||
|
||||
let starts = match &node.inputs[1].ty {
|
||||
ArgType::Tensor(tensor) => {
|
||||
assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
|
||||
if let Some(Data::Int64s(shape)) = start_value.as_ref() {
|
||||
shape
|
||||
.iter()
|
||||
.map(|x| {
|
||||
assert!(*x >= 0, "Slice: start must be positive");
|
||||
*x as usize
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
panic!("Tensor data type must be int64")
|
||||
}
|
||||
}
|
||||
_ => panic!("Only tensor input is valid for shape"),
|
||||
};
|
||||
|
||||
let ends = match &node.inputs[2].ty {
|
||||
ArgType::Tensor(tensor) => {
|
||||
assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D");
|
||||
if let Some(Data::Int64s(shape)) = end_value.as_ref() {
|
||||
shape
|
||||
.iter()
|
||||
.map(|x| {
|
||||
assert!(*x >= 0, "Slice: end must be positive");
|
||||
*x as usize
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
panic!("Tensor data type must be int64")
|
||||
}
|
||||
}
|
||||
_ => panic!("Only tensor input is valid for shape"),
|
||||
};
|
||||
|
||||
for (key, value) in node.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"axes" => {
|
||||
let mut i = 0;
|
||||
value.clone().into_i64s().iter().for_each(|x| {
|
||||
assert_eq!(*x, i, "Slice: axes must be consecutive");
|
||||
i += 1;
|
||||
})
|
||||
}
|
||||
"steps" => value.clone().into_i64s().into_iter().for_each(|x| {
|
||||
if x != 1 {
|
||||
panic!("Slice: steps other than 1 are not supported");
|
||||
}
|
||||
}),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
(starts, ends)
|
||||
}
|
||||
|
||||
pub fn transpose_config(curr: &Node) -> Vec<i64> {
|
||||
if curr.inputs.len() != 1 {
|
||||
panic!(
|
||||
|
|
|
@ -42,6 +42,7 @@ use crate::{
|
|||
random_uniform::RandomUniformNode,
|
||||
range::RangeNode,
|
||||
reshape::ReshapeNode,
|
||||
slice::SliceNode,
|
||||
squeeze::SqueezeNode,
|
||||
sum::SumNode,
|
||||
unary::UnaryNode,
|
||||
|
@ -294,6 +295,7 @@ impl OnnxGraph {
|
|||
NodeType::Shape => graph.register(Self::shape_conversion(node)),
|
||||
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
|
||||
NodeType::Sin => graph.register(Self::sin_conversion(node)),
|
||||
NodeType::Slice => graph.register(Self::slice_conversion(node)),
|
||||
NodeType::Sum => graph.register(Self::sum_conversion(node)),
|
||||
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
|
||||
NodeType::Concat => graph.register(Self::concat_conversion(node)),
|
||||
|
@ -686,6 +688,14 @@ impl OnnxGraph {
|
|||
UnaryNode::sin(input, output)
|
||||
}
|
||||
|
||||
fn slice_conversion(node: Node) -> SliceNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let (starts, ends) = slice_config(&node);
|
||||
|
||||
SliceNode::new(input, output, starts, ends)
|
||||
}
|
||||
|
||||
fn sum_conversion(node: Node) -> SumNode {
|
||||
let inputs = node
|
||||
.inputs
|
||||
|
|
Loading…
Reference in New Issue