From fed618b265699ba3ba360062b0fb1d5972444e28 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 29 Nov 2023 14:57:52 -0600 Subject: [PATCH] Add Cos, Exp, Gelu, Log and Neg unary ONNX ops (#1013) * Add ONNX Cos OP * Add ONNX Exp OP * Add ONNX Gelu OP * Add ONNX Log OP * Allow approx_constant clippy rule in generated model code * Add ONNX Neg OP * Fix tests with custom imports for unary nodes * Add scalar tests for Sqrt --- burn-import/SUPPORTED-ONNX-OPS.md | 10 +- burn-import/onnx-tests/build.rs | 5 + burn-import/onnx-tests/tests/cos/cos.onnx | 16 ++ burn-import/onnx-tests/tests/cos/cos.py | 40 ++++ burn-import/onnx-tests/tests/exp/exp.onnx | 16 ++ burn-import/onnx-tests/tests/exp/exp.py | 41 ++++ burn-import/onnx-tests/tests/gelu/gelu.onnx | Bin 0 -> 698 bytes burn-import/onnx-tests/tests/gelu/gelu.py | 44 +++++ burn-import/onnx-tests/tests/log/log.onnx | 16 ++ burn-import/onnx-tests/tests/log/log.py | 40 ++++ burn-import/onnx-tests/tests/neg/neg.onnx | Bin 0 -> 201 bytes burn-import/onnx-tests/tests/neg/neg.py | 42 ++++ burn-import/onnx-tests/tests/onnx_tests.rs | 81 +++++++- burn-import/onnx-tests/tests/sqrt/sqrt.onnx | Bin 137 -> 267 bytes burn-import/onnx-tests/tests/sqrt/sqrt.py | 21 +- burn-import/src/burn/graph.rs | 2 +- burn-import/src/burn/node/avg_pool2d.rs | 2 +- burn-import/src/burn/node/base.rs | 22 ++- burn-import/src/burn/node/batch_norm.rs | 2 +- burn-import/src/burn/node/binary.rs | 2 +- burn-import/src/burn/node/clip.rs | 6 +- burn-import/src/burn/node/concat.rs | 2 +- burn-import/src/burn/node/conv1d.rs | 2 +- burn-import/src/burn/node/conv2d.rs | 2 +- burn-import/src/burn/node/dropout.rs | 2 +- burn-import/src/burn/node/gather.rs | 2 +- burn-import/src/burn/node/global_avg_pool.rs | 4 +- burn-import/src/burn/node/linear.rs | 2 +- burn-import/src/burn/node/matmul.rs | 2 +- burn-import/src/burn/node/max_pool2d.rs | 2 +- burn-import/src/burn/node/reshape.rs | 2 +- burn-import/src/burn/node/unary.rs | 194 +++++++++++++++++-- burn-import/src/burn/ty.rs | 4 +- burn-import/src/onnx/dim_inference.rs | 59 +++--- burn-import/src/onnx/to_burn.rs | 39 ++++ 35 files changed, 637 insertions(+), 89 deletions(-) create mode 100644 burn-import/onnx-tests/tests/cos/cos.onnx create mode 100755 burn-import/onnx-tests/tests/cos/cos.py create mode 100644 burn-import/onnx-tests/tests/exp/exp.onnx create mode 100755 burn-import/onnx-tests/tests/exp/exp.py create mode 100644 burn-import/onnx-tests/tests/gelu/gelu.onnx create mode 100755 burn-import/onnx-tests/tests/gelu/gelu.py create mode 100644 burn-import/onnx-tests/tests/log/log.onnx create mode 100755 burn-import/onnx-tests/tests/log/log.py create mode 100644 burn-import/onnx-tests/tests/neg/neg.onnx create mode 100755 burn-import/onnx-tests/tests/neg/neg.py mode change 100644 => 100755 burn-import/onnx-tests/tests/sqrt/sqrt.py diff --git a/burn-import/SUPPORTED-ONNX-OPS.md b/burn-import/SUPPORTED-ONNX-OPS.md index 86cd2fcbc..89b234f1c 100644 --- a/burn-import/SUPPORTED-ONNX-OPS.md +++ b/burn-import/SUPPORTED-ONNX-OPS.md @@ -43,7 +43,7 @@ represent the corresponding Burn Op. | [Conv2d][34] | ✅ | ✅ | | [ConvInteger][37] | ❌ | ❌ | | [ConvTranspose][38] | ❌ | ✅ | -| [Cos][39] | ❌ | ✅ | +| [Cos][39] | ✅ | ✅ | | [Cosh][40] | ❌ | ❌ | | [CumSum][41] | ❌ | ❌ | | [DepthToSpace][42] | ❌ | ❌ | @@ -57,7 +57,7 @@ represent the corresponding Burn Op. 
| [Elu][50] | ❌ | ❌ | | [Equal][51] | ✅ | ✅ | | [Erf][52] | ✅ | ✅ | -| [Exp][53] | ❌ | ✅ | +| [Exp][53] | ✅ | ✅ | | [Expand][54] | ❌ | ❌ | | [EyeLike][55] | ❌ | ❌ | | [Flatten][56] | ✅ | ✅ | @@ -65,7 +65,7 @@ represent the corresponding Burn Op. | [Gather][58] | ✅ | ✅ | | [GatherElements][59] | ❌ | ❌ | | [GatherND][60] | ❌ | ❌ | -| [Gelu][61] | ❌ | ✅ | +| [Gelu][61] | ✅ | ✅ | | [Gemm][62] | ❌ | ❌ | | [GlobalAveragePool][63] | ✅ | ✅ | | [GlobalLpPool][64] | ❌ | ❌ | @@ -91,7 +91,7 @@ represent the corresponding Burn Op. | [Less][84] | ❌ | ✅ | | [LessOrEqual][85] | ❌ | ✅ | | Linear | ✅ | ✅ | -| [Log][87] | ❌ | ✅ | +| [Log][87] | ✅ | ✅ | | [LogSoftmax][88] | ✅ | ✅ | | [Loop][89] | ❌ | ❌ | | [LpNormalization][90] | ❌ | ❌ | @@ -113,7 +113,7 @@ represent the corresponding Burn Op. | [Mod][106] | ❌ | ❌ | | [Mul][107] | ❌ | ✅ | | [Multinomial][108] | ❌ | ❌ | -| [Neg][109] | ❌ | ✅ | +| [Neg][109] | ✅ | ✅ | | [NegativeLogLikelihoodLoss][110] | ❌ | ❌ | | [NonMaxSuppression][112] | ❌ | ❌ | | [NonZero][113] | ❌ | ❌ | diff --git a/burn-import/onnx-tests/build.rs b/burn-import/onnx-tests/build.rs index 92ab4ac80..0a3db38e5 100644 --- a/burn-import/onnx-tests/build.rs +++ b/burn-import/onnx-tests/build.rs @@ -15,18 +15,23 @@ fn main() { .input("tests/concat/concat.onnx") .input("tests/conv1d/conv1d.onnx") .input("tests/conv2d/conv2d.onnx") + .input("tests/cos/cos.onnx") .input("tests/div/div.onnx") .input("tests/dropout/dropout_opset16.onnx") .input("tests/dropout/dropout_opset7.onnx") .input("tests/equal/equal.onnx") .input("tests/erf/erf.onnx") + .input("tests/exp/exp.onnx") .input("tests/flatten/flatten.onnx") .input("tests/gather/gather.onnx") + .input("tests/gelu/gelu.onnx") .input("tests/global_avr_pool/global_avr_pool.onnx") .input("tests/linear/linear.onnx") .input("tests/log_softmax/log_softmax.onnx") + .input("tests/log/log.onnx") .input("tests/maxpool2d/maxpool2d.onnx") .input("tests/mul/mul.onnx") + .input("tests/neg/neg.onnx") .input("tests/recip/recip.onnx") .input("tests/relu/relu.onnx") .input("tests/reshape/reshape.onnx") diff --git a/burn-import/onnx-tests/tests/cos/cos.onnx b/burn-import/onnx-tests/tests/cos/cos.onnx new file mode 100644 index 000000000..be21e3066 --- /dev/null +++ b/burn-import/onnx-tests/tests/cos/cos.onnx @@ -0,0 +1,16 @@ +pytorch2.1.0:m + + onnx::Cos_01/Cos"Cos +main_graphZ% + onnx::Cos_0 + + + + +b +1 + + + + +B \ No newline at end of file diff --git a/burn-import/onnx-tests/tests/cos/cos.py b/burn-import/onnx-tests/tests/cos/cos.py new file mode 100755 index 000000000..9d81461fa --- /dev/null +++ b/burn-import/onnx-tests/tests/cos/cos.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/cos/cos.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + return torch.cos(x) + + +def main(): + # Set random seed for reproducibility + torch.manual_seed(0) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + onnx_name = "cos.onnx" + test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device) + + torch.onnx.export(model, (test_input), onnx_name, + verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + # Output some test data for use in the test + print("Test input data: {}".format(test_input)) + output = model.forward(test_input) + print("Test output data: {}".format(output)) + + +if __name__ == '__main__': + main() diff --git 
a/burn-import/onnx-tests/tests/exp/exp.onnx b/burn-import/onnx-tests/tests/exp/exp.onnx new file mode 100644 index 000000000..d66c39e67 --- /dev/null +++ b/burn-import/onnx-tests/tests/exp/exp.onnx @@ -0,0 +1,16 @@ +pytorch2.1.0:m + + onnx::Exp_01/Exp"Exp +main_graphZ% + onnx::Exp_0 + + + + +b +1 + + + + +B \ No newline at end of file diff --git a/burn-import/onnx-tests/tests/exp/exp.py b/burn-import/onnx-tests/tests/exp/exp.py new file mode 100755 index 000000000..81a96dca1 --- /dev/null +++ b/burn-import/onnx-tests/tests/exp/exp.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/exp/exp.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + return torch.exp(x) + + +def main(): + # Set random seed for reproducibility + torch.manual_seed(0) + import math + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + onnx_name = "exp.onnx" + test_input = torch.tensor([[[[0, math.log(2.)]]]], device=device) + + torch.onnx.export(model, (test_input), onnx_name, + verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + # Output some test data for use in the test + print("Test input data: {}".format(test_input)) + output = model.forward(test_input) + print("Test output data: {}".format(output)) + + +if __name__ == '__main__': + main() diff --git a/burn-import/onnx-tests/tests/gelu/gelu.onnx b/burn-import/onnx-tests/tests/gelu/gelu.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ca9f694c7c285020a7ed6b7dfd78d3b0c7b32022 GIT binary patch literal 698 zcmd&S zse;?Z2^Ujh21;;QaA8pe)9G52hM^NEro;>+xO{~qQQaMHh)^tmE``%m3=9omPX)sq zf@~=lZYyAJbWBM>SPS(XP)vy#NN~BLh7{DR$XZ~=_?G5i7y}ejVg?deylR9{jP6w< zJYHo0d({QzRb)#syo#_wh|xld8xo)(9fn{>3325nX6D7G7bO;CM5&_qQ%H9kYv0eSXziHH!(9WKD{WhAR|f@S*d}L7?%hKqmTd>69*#_W{DC-R%R%~ d#>K+HEyR!{&BbVhOHUFv7o#yur4y5Y006EkB*p*$ literal 0 HcmV?d00001 diff --git a/burn-import/onnx-tests/tests/neg/neg.py b/burn-import/onnx-tests/tests/neg/neg.py new file mode 100755 index 000000000..7dae1b807 --- /dev/null +++ b/burn-import/onnx-tests/tests/neg/neg.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/neg/neg.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + return torch.neg(x), -y + + +def main(): + # Set random seed for reproducibility + torch.manual_seed(0) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + onnx_name = "neg.onnx" + test_input1 = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device) + test_input2 = 99.0 + + torch.onnx.export(model, (test_input1, test_input2), onnx_name, + verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + # Output some test data for use in the test + print("Test input1: {}, input2: {}".format(test_input1, test_input2)) + output1, output2 = model.forward(test_input1, test_input2) + print("Test output1 data: {}".format(output1)) + print("Test output2 data: {}".format(output2)) + + +if __name__ == '__main__': + main() diff --git a/burn-import/onnx-tests/tests/onnx_tests.rs b/burn-import/onnx-tests/tests/onnx_tests.rs index 3eac50590..d02756b99 100644 --- a/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/burn-import/onnx-tests/tests/onnx_tests.rs @@ -22,18 +22,23 @@ include_models!( concat, conv1d, 
conv2d, + cos, div, dropout_opset16, dropout_opset7, equal, erf, + exp, flatten, gather, + gelu, global_avr_pool, linear, log_softmax, + log, maxpool2d, mul, + neg, recip, relu, reshape, @@ -331,12 +336,15 @@ mod tests { fn sqrt() { let model: sqrt::Model = sqrt::Model::new(); - let input = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]); + let input1 = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]); + let input2 = 36f64; - let output = model.forward(input); - let expected = Data::from([[[[1.0, 2.0, 3.0, 5.0]]]]); + let (output1, output2) = model.forward(input1, input2); + let expected1 = Data::from([[[[1.0, 2.0, 3.0, 5.0]]]]); + let expected2 = 6.0; - assert_eq!(output.to_data(), expected); + assert_eq!(output1.to_data(), expected1); + assert_eq!(output2, expected2); } #[test] @@ -641,4 +649,69 @@ mod tests { let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]); output.to_data().assert_approx_eq(&expected, 4); } + + #[test] + fn cos() { + let model: cos::Model = cos::Model::new(); + + let input = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]); + + let output = model.forward(input); + let expected = Data::from([[[[0.5403, -0.6536, -0.9111, 0.9912]]]]); + + output.to_data().assert_approx_eq(&expected, 4); + } + + #[test] + #[allow(clippy::approx_constant)] + fn exp() { + let model: exp::Model = exp::Model::new(); + + let input = Tensor::::from_floats([[[[0.0000, 0.6931]]]]); + + let output = model.forward(input); + let expected = Data::from([[[[1., 2.]]]]); + + output.to_data().assert_approx_eq(&expected, 2); + } + + #[test] + fn gelu() { + let model: gelu::Model = gelu::Model::new(); + + let input = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]); + + let output = model.forward(input); + let expected = Data::from([[[[0.8413, 3.9999, 9.0000, 25.0000]]]]); + + output.to_data().assert_approx_eq(&expected, 4); + } + + #[test] + fn log() { + let model: log::Model = log::Model::new(); + + let input = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]); + + let output = model.forward(input); + let expected = Data::from([[[[0.0000, 1.3863, 2.1972, 3.2189]]]]); + + output.to_data().assert_approx_eq(&expected, 4); + } + + #[test] + fn neg() { + let model: neg::Model = neg::Model::new(); + + let input1 = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]); + let input2 = 99f64; + + let (output1, output2) = model.forward(input1, input2); + let expected1 = Data::from([[[[-1.0, -4.0, -9.0, -25.0]]]]); + let expected2 = -99f64; + + output1.to_data().assert_approx_eq(&expected1, 4); + + assert_eq!(output2, expected2); + } } diff --git a/burn-import/onnx-tests/tests/sqrt/sqrt.onnx b/burn-import/onnx-tests/tests/sqrt/sqrt.onnx index 27753ebcbf81c97487d4321490fcad287493e514..99931914a4b26cbdd3edae806b0cb0544e091c5e 100644 GIT binary patch literal 267 zcmdQvWyZympO;r*W#yb$ToP|6#HSBtEdl$OLB zNU?(XN-Q8zElw_`l6(p71&mBwaxi_tg+)MhLX5^hr69Hv3y9=W;zBb?h|xrf9jqeW z5TZtiD>pGSFFw5}u^=N#4b40;E)fn!AptHX4n`!*5+w#x2X!wS7Yhfs5JQqQ7o)Kd S4oykiT#P0{P?b(h0s;W}Z#nz` delta 97 zcmeBX>SWa8;1FUjs4U4ZO3sjCHPSQGGqft?lH=mZ&&#W@vI;INDv37`Vl BurnGraph { // otherwise let_and_return error will be triggered by clippy. // For now, we just disable the warning. quote! 
{ - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, #input_def) -> #output_type_def { #body diff --git a/burn-import/src/burn/node/avg_pool2d.rs b/burn-import/src/burn/node/avg_pool2d.rs index f7be50f1e..6cecbef6f 100644 --- a/burn-import/src/burn/node/avg_pool2d.rs +++ b/burn-import/src/burn/node/avg_pool2d.rs @@ -147,7 +147,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { let output = self.avg_pool2d.forward(input); diff --git a/burn-import/src/burn/node/base.rs b/burn-import/src/burn/node/base.rs index 944156306..abc92d5f2 100644 --- a/burn-import/src/burn/node/base.rs +++ b/burn-import/src/burn/node/base.rs @@ -197,7 +197,7 @@ pub(crate) mod tests { use crate::burn::{ graph::BurnGraph, node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen}, - TensorType, + BurnImports, TensorType, }; use burn::{ nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data, @@ -205,7 +205,8 @@ pub(crate) mod tests { use proc_macro2::TokenStream; use quote::quote; - pub(crate) fn one_node_graph + 'static>( + #[track_caller] + pub(crate) fn one_node_graph + Clone + 'static>( node_gen: T, forward: TokenStream, input_names: Vec, @@ -213,15 +214,16 @@ pub(crate) mod tests { ) { let mut graph = BurnGraph::::default(); - graph.register(node_gen); + graph.register(node_gen.clone()); graph.register_input_output(input_names, output_names); + let mut imports = BurnImports::default(); + node_gen.register_imports(&mut imports); + let imports = imports.codegen(); + let expected = quote! { - use burn::{ - module::Module, - tensor::{backend::Backend, Tensor}, - }; + #imports #[derive(Module, Debug)] pub struct Model { @@ -236,7 +238,7 @@ pub(crate) mod tests { } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] #forward } }; @@ -298,7 +300,7 @@ pub(crate) mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { let tensor3 = tensor1.matmul(tensor2); let tensor4 = self.conv2d.forward(tensor3); @@ -370,7 +372,7 @@ pub(crate) mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { let tensor3 = tensor1.matmul(tensor2.clone()); let tensor4 = self.conv2d.forward(tensor2); diff --git a/burn-import/src/burn/node/batch_norm.rs b/burn-import/src/burn/node/batch_norm.rs index b706c47cc..7c20677a1 100644 --- a/burn-import/src/burn/node/batch_norm.rs +++ b/burn-import/src/burn/node/batch_norm.rs @@ -210,7 +210,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { let output = self.norm.forward(input); diff --git a/burn-import/src/burn/node/binary.rs b/burn-import/src/burn/node/binary.rs index 2e237339a..a8e0b3cc1 100644 --- a/burn-import/src/burn/node/binary.rs +++ b/burn-import/src/burn/node/binary.rs @@ -325,7 +325,7 @@ mod tests { } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: 
Tensor, tensor2: Tensor) -> Tensor { let tensor3 = tensor1.equal(tensor2); diff --git a/burn-import/src/burn/node/clip.rs b/burn-import/src/burn/node/clip.rs index 3156ab6d7..e4b41cca2 100644 --- a/burn-import/src/burn/node/clip.rs +++ b/burn-import/src/burn/node/clip.rs @@ -87,7 +87,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: Tensor) -> Tensor { let tensor2 = tensor1.clamp(0f64, 1f64); @@ -130,7 +130,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: Tensor) -> Tensor { let tensor2 = tensor1.clamp_min(0f64); @@ -173,7 +173,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: Tensor) -> Tensor { let tensor2 = tensor1.clamp_max(1f64); diff --git a/burn-import/src/burn/node/concat.rs b/burn-import/src/burn/node/concat.rs index a0cb6e189..3b8d4c821 100644 --- a/burn-import/src/burn/node/concat.rs +++ b/burn-import/src/burn/node/concat.rs @@ -92,7 +92,7 @@ mod tests { } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { let tensor3 = burn::tensor::Tensor::cat([tensor1, tensor2].into(), 1); diff --git a/burn-import/src/burn/node/conv1d.rs b/burn-import/src/burn/node/conv1d.rs index 0c3e56888..566190ac5 100644 --- a/burn-import/src/burn/node/conv1d.rs +++ b/burn-import/src/burn/node/conv1d.rs @@ -188,7 +188,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { let output = self.conv1d.forward(input); diff --git a/burn-import/src/burn/node/conv2d.rs b/burn-import/src/burn/node/conv2d.rs index 9b3c9f440..305a0fc09 100644 --- a/burn-import/src/burn/node/conv2d.rs +++ b/burn-import/src/burn/node/conv2d.rs @@ -187,7 +187,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { let output = self.conv2d.forward(input); diff --git a/burn-import/src/burn/node/dropout.rs b/burn-import/src/burn/node/dropout.rs index de61653e1..bcce0c05a 100644 --- a/burn-import/src/burn/node/dropout.rs +++ b/burn-import/src/burn/node/dropout.rs @@ -131,7 +131,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { let output = self.dropout.forward(input); diff --git a/burn-import/src/burn/node/gather.rs b/burn-import/src/burn/node/gather.rs index 2c0e6bc9e..13101d67a 100644 --- a/burn-import/src/burn/node/gather.rs +++ b/burn-import/src/burn/node/gather.rs @@ -92,7 +92,7 @@ mod tests { } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { let tensor3 = tensor1.gather(1, tensor2); diff --git a/burn-import/src/burn/node/global_avg_pool.rs b/burn-import/src/burn/node/global_avg_pool.rs index 80d6f3cec..392cc08c4 100644 --- a/burn-import/src/burn/node/global_avg_pool.rs +++ 
b/burn-import/src/burn/node/global_avg_pool.rs @@ -152,7 +152,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { let output = self.global_avg_pool1.forward(input); @@ -201,7 +201,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { let output = self.global_avg_pool1.forward(input); diff --git a/burn-import/src/burn/node/linear.rs b/burn-import/src/burn/node/linear.rs index b413c2c4a..4806841f2 100644 --- a/burn-import/src/burn/node/linear.rs +++ b/burn-import/src/burn/node/linear.rs @@ -164,7 +164,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { let output = self.linear.forward(input); diff --git a/burn-import/src/burn/node/matmul.rs b/burn-import/src/burn/node/matmul.rs index b7b1eea97..4bb71e6b3 100644 --- a/burn-import/src/burn/node/matmul.rs +++ b/burn-import/src/burn/node/matmul.rs @@ -84,7 +84,7 @@ mod tests { } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { let tensor3 = tensor1.matmul(tensor2); diff --git a/burn-import/src/burn/node/max_pool2d.rs b/burn-import/src/burn/node/max_pool2d.rs index 2bbf859bb..7ce309547 100644 --- a/burn-import/src/burn/node/max_pool2d.rs +++ b/burn-import/src/burn/node/max_pool2d.rs @@ -148,7 +148,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { let output = self.max_pool2d.forward(input); diff --git a/burn-import/src/burn/node/reshape.rs b/burn-import/src/burn/node/reshape.rs index df8959e90..0b4c308d4 100644 --- a/burn-import/src/burn/node/reshape.rs +++ b/burn-import/src/burn/node/reshape.rs @@ -76,7 +76,7 @@ mod tests { phantom: core::marker::PhantomData, } } - #[allow(clippy::let_and_return)] + #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, tensor1: Tensor) -> Tensor { let tensor2 = tensor1.reshape([4, 4, 4, 4]); diff --git a/burn-import/src/burn/node/unary.rs b/burn-import/src/burn/node/unary.rs index 47d2300b3..74098dc7f 100644 --- a/burn-import/src/burn/node/unary.rs +++ b/burn-import/src/burn/node/unary.rs @@ -1,5 +1,5 @@ use super::{Node, NodeCodegen}; -use crate::burn::{Scope, ToTokens, Type}; +use crate::burn::{BurnImports, Scope, ToTokens, Type}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; @@ -21,13 +21,18 @@ pub struct UnaryNode { #[derive(Clone)] pub enum UnaryNodeKind { Cast, + Cos, Erf, + Exp, Flatten, + Gelu, + Log, LogSoftmax, - Softmax, - Relu, + Neg, Reciprocal, + Relu, Sigmoid, + Softmax, Sqrt, Tanh, Transpose, @@ -37,13 +42,18 @@ impl UnaryNodeKind { pub fn as_str(&self) -> &str { match self { Self::Cast => "cast", + Self::Cos => "cos", Self::Erf => "erf", + Self::Exp => "exp", Self::Flatten => "flatten", + Self::Gelu => "gelu", + Self::Log => "log", Self::LogSoftmax => "log_softmax", - Self::Softmax => "softmax", - Self::Relu => "relu", + Self::Neg => "neg", Self::Reciprocal => "reciprocal", + Self::Relu => "relu", Self::Sigmoid => "sigmoid", + 
Self::Softmax => "softmax", Self::Sqrt => "sqrt", Self::Tanh => "tanh", Self::Transpose => "transpose", @@ -97,6 +107,16 @@ impl NodeCodegen for UnaryNode { fn into_node(self) -> Node { Node::Unary(self) } + + fn register_imports(&self, imports: &mut BurnImports) { + // Register the imports depending on the kind of the node. + match self.kind { + UnaryNodeKind::Neg => { + imports.register("core::ops::Neg"); + } + _ => {} + } + } } impl UnaryNode { @@ -155,6 +175,31 @@ impl UnaryNode { Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function)) } + pub(crate) fn cos(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.cos()}; + Self::new(input, output, UnaryNodeKind::Cos, Rc::new(function)) + } + + pub(crate) fn exp(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.exp()}; + Self::new(input, output, UnaryNodeKind::Exp, Rc::new(function)) + } + + pub(crate) fn gelu(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.gelu()}; + Self::new(input, output, UnaryNodeKind::Gelu, Rc::new(function)) + } + + pub(crate) fn log(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.log()}; + Self::new(input, output, UnaryNodeKind::Log, Rc::new(function)) + } + + pub(crate) fn neg(input: Type, output: Type) -> Self { + let function = move |input| quote! { #input.neg()}; + Self::new(input, output, UnaryNodeKind::Neg, Rc::new(function)) + } + /// Casts the input to the output type. /// /// Currently this function only supports the following conversions: @@ -166,12 +211,23 @@ impl UnaryNode { /// 4) tensor -> scalar /// 5) scalar -> tensor pub(crate) fn cast(input: Type, output: Type) -> Self { - let function = match output.clone() { - Type::Scalar(scalar) => { - let ty = scalar.ty(); - move |input| quote! { #input as #ty } + match (input.clone(), output.clone()) { + (Type::Scalar(input_scalar), Type::Scalar(output_scalar)) => { + if input_scalar.kind == output_scalar.kind { + // If the input and output types are the same, we don't need to cast. + Self::new(input, output, UnaryNodeKind::Cast, Rc::new(|input| input)) + } else { + // If the input and output types are different, we need to cast. + let ty = output_scalar.ty(); + Self::new( + input, + output, + UnaryNodeKind::Cast, + Rc::new(move |input| quote! { #input as #ty }), + ) + } } - Type::Tensor(_tensor) => { + (Type::Tensor(_input_tensor), Type::Tensor(_output_tensor)) => { // TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023) // TODO: If the input is scalar and the output type is a tensor, // we should generate another code block. (@antimora 8/4/2023) @@ -180,9 +236,7 @@ impl UnaryNode { } _ => panic!("output must be a tensor"), - }; - - Self::new(input, output, UnaryNodeKind::Cast, Rc::new(function)) + } } } @@ -400,4 +454,118 @@ mod tests { vec!["scalar2".to_string()], ); } + + #[test] + fn test_unary_codegen_cos() { + one_node_graph( + UnaryNode::cos( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.cos(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_exp() { + one_node_graph( + UnaryNode::exp( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! 
{ + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.exp(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_gelu() { + one_node_graph( + UnaryNode::gelu( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.gelu(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_codegen_log() { + one_node_graph( + UnaryNode::log( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.log(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + + #[test] + fn test_unary_neg_scalar() { + one_node_graph( + UnaryNode::neg( + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)), + Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)), + ), + quote! { + pub fn forward(&self, scalar1: f64) -> f64 { + let scalar2 = scalar1.neg(); + + scalar2 + } + }, + vec!["scalar1".to_string()], + vec!["scalar2".to_string()], + ); + } + + #[test] + fn test_unary_neg_tensor() { + one_node_graph( + UnaryNode::neg( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.neg(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } } diff --git a/burn-import/src/burn/ty.rs b/burn-import/src/burn/ty.rs index 292523ecc..f5fa41e77 100644 --- a/burn-import/src/burn/ty.rs +++ b/burn-import/src/burn/ty.rs @@ -20,7 +20,7 @@ pub enum TensorKind { Bool, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum ScalarKind { Int32, Int64, @@ -29,7 +29,7 @@ pub enum ScalarKind { Bool, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ScalarType { pub name: Ident, pub kind: ScalarKind, diff --git a/burn-import/src/onnx/dim_inference.rs b/burn-import/src/onnx/dim_inference.rs index ce2a9ce4f..867c1dbb0 100644 --- a/burn-import/src/onnx/dim_inference.rs +++ b/burn-import/src/onnx/dim_inference.rs @@ -62,38 +62,43 @@ pub fn dim_inference( updater.update_tensor_inputs(node); match node.node_type { + NodeType::Add => same_as_input(node), + NodeType::AveragePool2d => same_as_input(node), + NodeType::BatchNormalization => same_as_input(node), + NodeType::Cast => cast_update_outputs(node), + NodeType::Clip => same_as_input(node), + NodeType::Concat => concat_update_outputs(node), + NodeType::Constant => constant_update_outputs(node), NodeType::Conv1d => conv1d_update_outputs(node), NodeType::Conv2d => conv2d_update_outputs(node), - NodeType::MaxPool2d => same_as_input(node), - NodeType::Linear => linear_update_outputs(node), - NodeType::Flatten => flatten_update_outputs(node), - NodeType::GatherElements => same_as_input(node), - NodeType::Relu => same_as_input(node), - NodeType::LogSoftmax => same_as_input(node), - NodeType::BatchNormalization => same_as_input(node), - NodeType::Add => same_as_input(node), - NodeType::Sub => same_as_input(node), - NodeType::Mul => same_as_input(node), - NodeType::Cast => cast_update_outputs(node), + NodeType::Cos => same_as_input(node), NodeType::Div => same_as_input(node), - 
NodeType::Erf => same_as_input(node), - NodeType::Sqrt => same_as_input(node), - NodeType::Tanh => same_as_input(node), - NodeType::Reciprocal => same_as_input(node), - NodeType::Softmax => same_as_input(node), - NodeType::ReduceMean => mean_update_outputs(node), - NodeType::Constant => constant_update_outputs(node), - NodeType::Equal => equal_update_outputs(node), - NodeType::Shape => shape_update_outputs(node), - NodeType::Unsqueeze => unsqueeze_update_outputs(node), - NodeType::Sigmoid => same_as_input(node), - NodeType::Transpose => same_as_input(node), - NodeType::Concat => concat_update_outputs(node), - NodeType::Reshape => reshape_update_outputs(node), NodeType::Dropout => same_as_input(node), + NodeType::Equal => equal_update_outputs(node), + NodeType::Erf => same_as_input(node), + NodeType::Exp => same_as_input(node), + NodeType::Flatten => flatten_update_outputs(node), + NodeType::Gelu => same_as_input(node), + NodeType::GatherElements => same_as_input(node), NodeType::GlobalAveragePool => same_as_input(node), - NodeType::AveragePool2d => same_as_input(node), - NodeType::Clip => same_as_input(node), + NodeType::Linear => linear_update_outputs(node), + NodeType::Log => same_as_input(node), + NodeType::LogSoftmax => same_as_input(node), + NodeType::MaxPool2d => same_as_input(node), + NodeType::Mul => same_as_input(node), + NodeType::Neg => same_as_input(node), + NodeType::Reciprocal => same_as_input(node), + NodeType::ReduceMean => mean_update_outputs(node), + NodeType::Relu => same_as_input(node), + NodeType::Reshape => reshape_update_outputs(node), + NodeType::Shape => shape_update_outputs(node), + NodeType::Sigmoid => same_as_input(node), + NodeType::Softmax => same_as_input(node), + NodeType::Sqrt => same_as_input(node), + NodeType::Sub => same_as_input(node), + NodeType::Tanh => same_as_input(node), + NodeType::Transpose => same_as_input(node), + NodeType::Unsqueeze => unsqueeze_update_outputs(node), // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. 
_ => temporary_pass_through_stub(node), } diff --git a/burn-import/src/onnx/to_burn.rs b/burn-import/src/onnx/to_burn.rs index 8c1337a5d..66b6181aa 100644 --- a/burn-import/src/onnx/to_burn.rs +++ b/burn-import/src/onnx/to_burn.rs @@ -232,19 +232,24 @@ impl ONNXGraph { NodeType::Div => graph.register(Self::div_conversion(node)), NodeType::Equal => graph.register(Self::equal_conversion(node)), NodeType::Erf => graph.register(Self::erf_conversion(node)), + NodeType::Exp => graph.register(Self::exp_conversion(node)), NodeType::Clip => graph.register(Self::clip_conversion(node)), + NodeType::Cos => graph.register(Self::cos_conversion(node)), NodeType::Conv1d => graph.register(Self::conv1d_conversion::(node)), NodeType::Conv2d => graph.register(Self::conv2d_conversion::(node)), NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)), NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)), NodeType::MatMul => graph.register(Self::matmul_conversion(node)), + NodeType::Neg => graph.register(Self::neg_conversion(node)), NodeType::Linear => graph.register(Self::linear_conversion::(node)), NodeType::BatchNormalization => { graph.register(Self::batch_norm_conversion::(node)) } NodeType::Relu => graph.register(Self::relu_conversion(node)), + NodeType::Gelu => graph.register(Self::gelu_conversion(node)), NodeType::Flatten => graph.register(Self::flatten_conversion(node)), NodeType::GatherElements => graph.register(Self::gather_conversion(node)), + NodeType::Log => graph.register(Self::log_conversion(node)), NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)), NodeType::Softmax => graph.register(Self::softmax_conversion(node)), NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)), @@ -395,6 +400,20 @@ impl ONNXGraph { UnaryNode::relu(input, output) } + fn gelu_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::gelu(input, output) + } + + fn log_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::log(input, output) + } + fn flatten_conversion(node: Node) -> UnaryNode { let input = node.inputs.get(0).unwrap().to_type(); let output = node.outputs.get(0).unwrap().to_type(); @@ -607,6 +626,26 @@ impl ONNXGraph { GlobalAvgPoolNode::new(name, input, output) } + + fn cos_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::cos(input, output) + } + + fn exp_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + + UnaryNode::exp(input, output) + } + + fn neg_conversion(node: Node) -> UnaryNode { + let input = node.inputs.get(0).unwrap().to_type(); + let output = node.outputs.get(0).unwrap().to_type(); + UnaryNode::neg(input, output) + } } /// Extract data from node states and convert it to `DataSerialize`.
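
For reference, the end-to-end effect of these changes is that an ONNX model using one of the new unary ops imports as a plain Burn module whose forward method chains the matching tensor or scalar call. The sketch below is a hand-written approximation, not the literal burn-import output, of the code generated for the neg.onnx test model (one tensor input and one scalar input); the identifiers Model, input1 and input2 are illustrative. It also shows why UnaryNode::register_imports pulls in core::ops::Neg for the Neg kind: f64 only exposes a .neg() method through that trait.

```rust
// Hand-written approximation of the module burn-import would generate for
// neg.onnx; the real generated code may differ in naming and record handling.
use burn::{
    module::Module,
    tensor::{backend::Backend, Tensor},
};
use core::ops::Neg; // registered by UnaryNodeKind::Neg so that `input2.neg()` resolves

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    phantom: core::marker::PhantomData<B>,
}

impl<B: Backend> Model<B> {
    pub fn new() -> Self {
        Self {
            phantom: core::marker::PhantomData,
        }
    }

    #[allow(clippy::let_and_return, clippy::approx_constant)]
    pub fn forward(&self, input1: Tensor<B, 4>, input2: f64) -> (Tensor<B, 4>, f64) {
        // Tensor negation uses the tensor API; scalar negation needs core::ops::Neg in scope.
        let output1 = input1.neg();
        let output2 = input2.neg();
        (output1, output2)
    }
}
```

Cos, Exp, Gelu and Log follow the same pattern with `.cos()`, `.exp()`, `.gelu()` and `.log()` on the tensor and need no extra imports, which is why `register_imports` only special-cases `Neg`.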