mirror of https://github.com/tracel-ai/burn.git
ONNX Tile operation (#2092)
* renaming repeat to repeat_dim * implementing repeat function * renaming repeat files to repeat_dim * renaming part 2 * renaming part 3 * renaming part 4 * renaming part 5 * adding test file * adding unit test * adding rust book documentation * adding function args doc * fixing tests * changing repeat api to match pytorch equivalent * fixing clippy error * implementing tile onnx file * temp * working implementation and test * working e2e test * adding new supported onnx operation to the md file
This commit is contained in:
parent
6b61ad5a61
commit
d770b1f470
|
@ -191,7 +191,7 @@ represent the corresponding Burn Op.
|
|||
| [Tanh][182] | ✅ | ✅ |
|
||||
| [TfIdfVectorizer][183] | ❌ | ❌ |
|
||||
| [ThresholdedRelu][184] | ❌ | ❌ |
|
||||
| [Tile][185] | ❌ | ✅ |
|
||||
| [Tile][185] | ✅ | ✅ |
|
||||
| [TopK][186] | ❌ | ✅ |
|
||||
| [Transpose][187] | ✅ | ✅ |
|
||||
| [Trilu][188] | ❌ | ✅ |
|
||||
|
|
|
@ -93,6 +93,7 @@ fn main() {
|
|||
.input("tests/sum/sum.onnx")
|
||||
.input("tests/sum/sum_int.onnx")
|
||||
.input("tests/tanh/tanh.onnx")
|
||||
.input("tests/tile/tile.onnx")
|
||||
.input("tests/transpose/transpose.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
|
||||
|
|
|
@ -102,6 +102,7 @@ include_models!(
|
|||
sum,
|
||||
sum_int,
|
||||
tanh,
|
||||
tile,
|
||||
transpose,
|
||||
unsqueeze,
|
||||
unsqueeze_opset11,
|
||||
|
@ -1712,6 +1713,23 @@ mod tests {
|
|||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tile() {
|
||||
let device = Default::default();
|
||||
let model: tile::Model<Backend> = tile::Model::new(&device);
|
||||
|
||||
let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.]], &device);
|
||||
let output = model.forward(input).to_data();
|
||||
let expected = TensorData::from([
|
||||
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
|
||||
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
|
||||
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
|
||||
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
|
||||
]);
|
||||
|
||||
output.assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unsqueeze() {
|
||||
let device = Default::default();
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,67 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import onnx
|
||||
import onnx.helper
|
||||
import onnx.checker
|
||||
|
||||
|
||||
def build_model():
|
||||
# Define the input tensor as a graph input
|
||||
input_tensor = onnx.helper.make_tensor_value_info(
|
||||
name="input_tensor",
|
||||
elem_type=onnx.TensorProto.FLOAT,
|
||||
shape=[2, 2]
|
||||
)
|
||||
|
||||
output_tensor = onnx.helper.make_tensor_value_info(
|
||||
name="output_tensor",
|
||||
elem_type=onnx.TensorProto.FLOAT,
|
||||
shape=[4, 4]
|
||||
)
|
||||
|
||||
# Define the shape tensor for tiling as an initializer
|
||||
shape_tensor = onnx.helper.make_tensor(
|
||||
name="shape_tensor",
|
||||
data_type=onnx.TensorProto.INT64,
|
||||
dims=[2],
|
||||
vals=[2, 2]
|
||||
)
|
||||
# Create the Tile node
|
||||
tile_node = onnx.helper.make_node(
|
||||
"Tile",
|
||||
inputs=["input_tensor", "shape_tensor"],
|
||||
outputs=["output_tensor"]
|
||||
)
|
||||
|
||||
# Build the graph
|
||||
graph = onnx.helper.make_graph(
|
||||
nodes=[tile_node],
|
||||
name="main_graph",
|
||||
inputs=[input_tensor],
|
||||
outputs=[output_tensor],
|
||||
initializer=[shape_tensor]
|
||||
)
|
||||
|
||||
# Build the model
|
||||
model = onnx.helper.make_model(
|
||||
graph,
|
||||
ir_version=8,
|
||||
opset_imports=[onnx.helper.make_operatorsetid("", 16)]
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
onnx_model = build_model()
|
||||
|
||||
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
|
||||
|
||||
file_name = "tile.onnx"
|
||||
onnx.save(onnx_model, file_name)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
print(f"ONNX model saved as {file_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -11,7 +11,7 @@ use super::{
|
|||
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
|
||||
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
|
||||
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
|
||||
squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
|
||||
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
|
||||
};
|
||||
use crate::burn::{BurnImports, Scope, Type};
|
||||
use burn::backend::NdArray;
|
||||
|
@ -113,6 +113,7 @@ pub enum Node<PS: PrecisionSettings> {
|
|||
Slice(SliceNode),
|
||||
Squeeze(SqueezeNode),
|
||||
Sum(SumNode),
|
||||
Tile(TileNode),
|
||||
Unary(UnaryNode),
|
||||
Unsqueeze(UnsqueezeNode),
|
||||
Where(WhereNode),
|
||||
|
@ -160,6 +161,7 @@ macro_rules! match_all {
|
|||
Node::Slice(node) => $func(node),
|
||||
Node::Squeeze(node) => $func(node),
|
||||
Node::Sum(node) => $func(node),
|
||||
Node::Tile(node) => $func(node),
|
||||
Node::Unary(node) => $func(node),
|
||||
Node::Unsqueeze(node) => $func(node),
|
||||
Node::Where(node) => $func(node),
|
||||
|
@ -215,6 +217,7 @@ impl<PS: PrecisionSettings> Node<PS> {
|
|||
Node::Slice(_) => "slice",
|
||||
Node::Squeeze(_) => "squeeze",
|
||||
Node::Sum(_) => "add",
|
||||
Node::Tile(_) => "tile",
|
||||
Node::Unary(unary) => unary.kind.as_str(),
|
||||
Node::Unsqueeze(_) => "unsqueeze",
|
||||
Node::Where(_) => "where",
|
||||
|
|
|
@ -36,6 +36,7 @@ pub(crate) mod resize;
|
|||
pub(crate) mod slice;
|
||||
pub(crate) mod squeeze;
|
||||
pub(crate) mod sum;
|
||||
pub(crate) mod tile;
|
||||
pub(crate) mod unary;
|
||||
pub(crate) mod unsqueeze;
|
||||
pub(crate) use base::*;
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{Scope, TensorType, ToTokens, Type};
|
||||
use burn::config::Config;
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
pub struct TileConfig {
|
||||
pub repeats: Vec<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, new)]
|
||||
pub struct TileNode {
|
||||
pub input: TensorType,
|
||||
pub output: TensorType,
|
||||
pub config: TileConfig,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for TileNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
let input = scope.tensor_use_owned(&self.input, node_position);
|
||||
let output = &self.output.name;
|
||||
|
||||
let repeats = self.config.repeats.iter().map(|r| r.to_tokens());
|
||||
|
||||
quote! {
|
||||
let #output = #input.repeat(&[#(#repeats),*]);
|
||||
}
|
||||
}
|
||||
|
||||
fn into_node(self) -> Node<PS> {
|
||||
Node::Tile(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use burn::record::FullPrecisionSettings;
|
||||
|
||||
use super::*;
|
||||
use crate::burn::{
|
||||
graph::BurnGraph,
|
||||
node::{test::assert_tokens, tile::TileConfig, tile::TileNode},
|
||||
TensorType,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_codegen_tile() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
let config = TileConfig::new(vec![2, 3, 4]);
|
||||
graph.register(TileNode::new(
|
||||
TensorType::new_float("input", 3),
|
||||
TensorType::new_float("output", 3),
|
||||
config,
|
||||
));
|
||||
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let output = input.repeat(&[2, 3, 4]);
|
||||
output
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
|
@ -7,7 +7,7 @@ use burn::nn::{
|
|||
PaddingConfig2d, PaddingConfig3d,
|
||||
};
|
||||
|
||||
use crate::burn::node::pad::PadConfig;
|
||||
use crate::burn::node::{pad::PadConfig, tile::TileConfig};
|
||||
use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
|
||||
|
||||
/// Create a Conv1dConfig from the attributes of the node
|
||||
|
@ -745,6 +745,26 @@ pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {
|
|||
)
|
||||
}
|
||||
|
||||
/// Create a TileConfig from the attributes of the node
|
||||
pub fn tile_config(node: &Node) -> TileConfig {
|
||||
let repeat = node
|
||||
.inputs
|
||||
.get(1)
|
||||
.map(|input| {
|
||||
if let Some(data) = &input.value {
|
||||
data.clone()
|
||||
.into_i64s()
|
||||
.iter()
|
||||
.map(|&x| x as usize)
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
TileConfig::new(repeat)
|
||||
}
|
||||
|
||||
/// Create a PadConfig from the attributes of the node
|
||||
pub fn pad_config(node: &Node) -> PadConfig {
|
||||
fn get_pads(node: &Node) -> Vec<usize> {
|
||||
|
|
|
@ -50,6 +50,7 @@ use crate::{
|
|||
slice::SliceNode,
|
||||
squeeze::SqueezeNode,
|
||||
sum::SumNode,
|
||||
tile::TileNode,
|
||||
unary::UnaryNode,
|
||||
unsqueeze::UnsqueezeNode,
|
||||
},
|
||||
|
@ -66,7 +67,8 @@ use super::op_configuration::{
|
|||
hard_sigmoid_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config,
|
||||
max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config,
|
||||
reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config,
|
||||
shape_config, slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config,
|
||||
shape_config, slice_config, softmax_config, squeeze_config, tile_config, transpose_config,
|
||||
unsqueeze_config,
|
||||
};
|
||||
use onnx_ir::{
|
||||
convert_constant_value,
|
||||
|
@ -335,6 +337,7 @@ impl ParsedOnnxGraph {
|
|||
NodeType::Sign => graph.register(Self::sign_conversion(node)),
|
||||
NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
|
||||
NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)),
|
||||
NodeType::Tile => graph.register(Self::tile_conversion(node)),
|
||||
NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)),
|
||||
NodeType::ConstantOfShape => {
|
||||
graph.register(Self::constant_of_shape_conversion(node))
|
||||
|
@ -1167,6 +1170,14 @@ impl ParsedOnnxGraph {
|
|||
|
||||
SqueezeNode::new(input, output, axes)
|
||||
}
|
||||
|
||||
fn tile_conversion(node: Node) -> TileNode {
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let config = tile_config(&node);
|
||||
|
||||
TileNode::new(input, output, config)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract data from node states and convert it to `TensorData`.
|
||||
|
|
Loading…
Reference in New Issue