Add Cos, Exp, Gelu, Log and Neg unary ONNX ops (#1013)

* Add ONNX Cos OP

* Add ONNX Exp OP

* Add ONNX Gelu OP

* Add ONNX Log OP

* Allow approx_constant clippy rule in generated model code

* Add ONNX Neg OP

* Fix tests with custom imports for unary nodes

* Add scalar tests for Sqrt
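
All five ops lower to existing Burn tensor methods, so the importer mainly needs new unary codegen variants plus ONNX test models. As a rough sketch of what the generated code looks like for one of these nodes (illustrative names; the exact output is pinned down by the unary codegen tests further down in this diff):

pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
    // Each new op maps one-to-one onto a Burn tensor method.
    let tensor2 = tensor1.cos(); // likewise .exp(), .gelu(), .log(), .neg()
    tensor2
}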
Dilshod Tadjibaev 2023-11-29 14:57:52 -06:00 committed by GitHub
parent 87393b2070
commit fed618b265
35 changed files with 637 additions and 89 deletions

View File

@ -43,7 +43,7 @@ represent the corresponding Burn Op.
| [Conv2d][34] | ✅ | ✅ |
| [ConvInteger][37] | ❌ | ❌ |
| [ConvTranspose][38] | ❌ | ✅ |
| [Cos][39] | ❌ | ✅ |
| [Cos][39] | ✅ | ✅ |
| [Cosh][40] | ❌ | ❌ |
| [CumSum][41] | ❌ | ❌ |
| [DepthToSpace][42] | ❌ | ❌ |
@ -57,7 +57,7 @@ represent the corresponding Burn Op.
| [Elu][50] | ❌ | ❌ |
| [Equal][51] | ✅ | ✅ |
| [Erf][52] | ✅ | ✅ |
| [Exp][53] | ❌ | ✅ |
| [Exp][53] | ✅ | ✅ |
| [Expand][54] | ❌ | ❌ |
| [EyeLike][55] | ❌ | ❌ |
| [Flatten][56] | ✅ | ✅ |
@ -65,7 +65,7 @@ represent the corresponding Burn Op.
| [Gather][58] | ✅ | ✅ |
| [GatherElements][59] | ❌ | ❌ |
| [GatherND][60] | ❌ | ❌ |
| [Gelu][61] | ❌ | ✅ |
| [Gelu][61] | ✅ | ✅ |
| [Gemm][62] | ❌ | ❌ |
| [GlobalAveragePool][63] | ✅ | ✅ |
| [GlobalLpPool][64] | ❌ | ❌ |
@ -91,7 +91,7 @@ represent the corresponding Burn Op.
| [Less][84] | ❌ | ✅ |
| [LessOrEqual][85] | ❌ | ✅ |
| Linear | ✅ | ✅ |
| [Log][87] | ❌ | ✅ |
| [Log][87] | ✅ | ✅ |
| [LogSoftmax][88] | ✅ | ✅ |
| [Loop][89] | ❌ | ❌ |
| [LpNormalization][90] | ❌ | ❌ |
@ -113,7 +113,7 @@ represent the corresponding Burn Op.
| [Mod][106] | ❌ | ❌ |
| [Mul][107] | ❌ | ✅ |
| [Multinomial][108] | ❌ | ❌ |
| [Neg][109] | ❌ | ✅ |
| [Neg][109] | ✅ | ✅ |
| [NegativeLogLikelihoodLoss][110] | ❌ | ❌ |
| [NonMaxSuppression][112] | ❌ | ❌ |
| [NonZero][113] | ❌ | ❌ |

View File

@ -15,18 +15,23 @@ fn main() {
.input("tests/concat/concat.onnx")
.input("tests/conv1d/conv1d.onnx")
.input("tests/conv2d/conv2d.onnx")
.input("tests/cos/cos.onnx")
.input("tests/div/div.onnx")
.input("tests/dropout/dropout_opset16.onnx")
.input("tests/dropout/dropout_opset7.onnx")
.input("tests/equal/equal.onnx")
.input("tests/erf/erf.onnx")
.input("tests/exp/exp.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/gather/gather.onnx")
.input("tests/gelu/gelu.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/linear/linear.onnx")
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/log/log.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/reshape/reshape.onnx")

View File

Binary file not shown.

View File

@ -0,0 +1,40 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/cos/cos.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return torch.cos(x)
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "cos.onnx"
test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)
torch.onnx.export(model, (test_input), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()

View File

Binary file not shown.

View File

@ -0,0 +1,41 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/exp/exp.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return torch.exp(x)
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
import math
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "exp.onnx"
test_input = torch.tensor([[[[0, math.log(2.)]]]], device=device)
torch.onnx.export(model, (test_input), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()

Binary file not shown.

View File

@ -0,0 +1,44 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/gelu/gelu.onnx
import torch
import torch.nn as nn
import torch.onnx
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.gelu = nn.GELU()
def forward(self, x):
return self.gelu(x)
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "gelu.onnx"
test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)
torch.onnx.export(model, (test_input), onnx_name,
verbose=False,
opset_version=16,
# opset_version=20, TODO: uncomment this when PyTorch supports it
# Note: OPSET 20 is required for GELU to be exported otherwise
# op is broken down into multiple ops
operator_export_type=torch.onnx.OperatorExportTypes.ONNX)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()

View File

Binary file not shown.

View File

@ -0,0 +1,40 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/log/log.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return torch.log(x)
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "log.onnx"
test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)
torch.onnx.export(model, (test_input), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))
if __name__ == '__main__':
main()

Binary file not shown.

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/neg/neg.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x, y):
return torch.neg(x), -y
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "neg.onnx"
test_input1 = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)
test_input2 = 99.0
torch.onnx.export(model, (test_input1, test_input2), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
print("Test input1: {}, input2: {}".format(test_input1, test_input2))
output1, output2 = model.forward(test_input1, test_input2)
print("Test output1 data: {}".format(output1))
print("Test output2 data: {}".format(output2))
if __name__ == '__main__':
main()

View File

@ -22,18 +22,23 @@ include_models!(
concat,
conv1d,
conv2d,
cos,
div,
dropout_opset16,
dropout_opset7,
equal,
erf,
exp,
flatten,
gather,
gelu,
global_avr_pool,
linear,
log_softmax,
log,
maxpool2d,
mul,
neg,
recip,
relu,
reshape,
@ -331,12 +336,15 @@ mod tests {
fn sqrt() {
let model: sqrt::Model<Backend> = sqrt::Model::new();
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);
let input1 = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);
let input2 = 36f64;
let output = model.forward(input);
let expected = Data::from([[[[1.0, 2.0, 3.0, 5.0]]]]);
let (output1, output2) = model.forward(input1, input2);
let expected1 = Data::from([[[[1.0, 2.0, 3.0, 5.0]]]]);
let expected2 = 6.0;
assert_eq!(output.to_data(), expected);
assert_eq!(output1.to_data(), expected1);
assert_eq!(output2, expected2);
}
#[test]
@ -641,4 +649,69 @@ mod tests {
let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}
#[test]
fn cos() {
let model: cos::Model<Backend> = cos::Model::new();
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);
let output = model.forward(input);
let expected = Data::from([[[[0.5403, -0.6536, -0.9111, 0.9912]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}
#[test]
#[allow(clippy::approx_constant)]
fn exp() {
let model: exp::Model<Backend> = exp::Model::new();
let input = Tensor::<Backend, 4>::from_floats([[[[0.0000, 0.6931]]]]);
let output = model.forward(input);
let expected = Data::from([[[[1., 2.]]]]);
output.to_data().assert_approx_eq(&expected, 2);
}
#[test]
fn gelu() {
let model: gelu::Model<Backend> = gelu::Model::new();
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);
let output = model.forward(input);
let expected = Data::from([[[[0.8413, 3.9999, 9.0000, 25.0000]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}
#[test]
fn log() {
let model: log::Model<Backend> = log::Model::new();
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);
let output = model.forward(input);
let expected = Data::from([[[[0.0000, 1.3863, 2.1972, 3.2189]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}
#[test]
fn neg() {
let model: neg::Model<Backend> = neg::Model::new();
let input1 = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);
let input2 = 99f64;
let (output1, output2) = model.forward(input1, input2);
let expected1 = Data::from([[[[-1.0, -4.0, -9.0, -25.0]]]]);
let expected2 = -99f64;
output1.to_data().assert_approx_eq(&expected1, 4);
assert_eq!(output2, expected2);
}
}

burn-import/onnx-tests/tests/sqrt/sqrt.py Normal file → Executable file
View File

@ -10,9 +10,9 @@ class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return torch.sqrt(x)
def forward(self, x, y):
y_tensor = torch.tensor(y) # Convert y to a PyTorch tensor
return torch.sqrt(x), torch.sqrt(y_tensor)
def main():
# Set random seed for reproducibility
@ -23,19 +23,20 @@ def main():
model.eval()
device = torch.device("cpu")
onnx_name = "sqrt.onnx"
dummy_input = torch.randn(1, 4, 9, 25, device=device)
torch.onnx.export(model, (dummy_input), onnx_name,
test_input1 = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]])
test_input2 = 36.0
torch.onnx.export(model, (test_input1, test_input2), onnx_name,
verbose=False, opset_version=16)
print("Finished exporting model to {}".format(onnx_name))
# Output some test data for use in the test
test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]])
test_input1 = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]])
test_input2 = 36.0
print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))
print("Test input data: {}, {}".format(test_input1, test_input2))
output1, output2 = model.forward(test_input1, test_input2)
print("Test output data: {}, {}".format(output1, output2))
if __name__ == '__main__':

View File

@ -542,7 +542,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
// otherwise let_and_return error will be triggered by clippy.
// For now, we just disable the warning.
quote! {
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, #input_def) -> #output_type_def {
#body
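
The extra allow is needed because literal values baked into generated model code and its tests can sit close to well-known constants; the exp test earlier in this diff feeds 0.6931 (ln 2), which clippy reports. A minimal illustration, assuming default clippy settings and an illustrative function name:

fn example() -> f64 {
    // clippy::approx_constant fires on literals that approximate known constants:
    let ln_two = 0.6931; // warning: approximate value of `f64::consts::LN_2` found
    ln_two
}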

View File

@ -147,7 +147,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.avg_pool2d.forward(input);

View File

@ -197,7 +197,7 @@ pub(crate) mod tests {
use crate::burn::{
graph::BurnGraph,
node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen},
TensorType,
BurnImports, TensorType,
};
use burn::{
nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data,
@ -205,7 +205,8 @@ pub(crate) mod tests {
use proc_macro2::TokenStream;
use quote::quote;
pub(crate) fn one_node_graph<T: NodeCodegen<FullPrecisionSettings> + 'static>(
#[track_caller]
pub(crate) fn one_node_graph<T: NodeCodegen<FullPrecisionSettings> + Clone + 'static>(
node_gen: T,
forward: TokenStream,
input_names: Vec<String>,
@ -213,15 +214,16 @@ pub(crate) mod tests {
) {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(node_gen);
graph.register(node_gen.clone());
graph.register_input_output(input_names, output_names);
let mut imports = BurnImports::default();
node_gen.register_imports(&mut imports);
let imports = imports.codegen();
let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#imports
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
@ -236,7 +238,7 @@ pub(crate) mod tests {
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
#forward
}
};
@ -298,7 +300,7 @@ pub(crate) mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = tensor1.matmul(tensor2);
let tensor4 = self.conv2d.forward(tensor3);
@ -370,7 +372,7 @@ pub(crate) mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = tensor1.matmul(tensor2.clone());
let tensor4 = self.conv2d.forward(tensor2);

View File

@ -210,7 +210,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.norm.forward(input);

View File

@ -325,7 +325,7 @@ mod tests {
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4, Bool> {
let tensor3 = tensor1.equal(tensor2);

View File

@ -87,7 +87,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.clamp(0f64, 1f64);
@ -130,7 +130,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.clamp_min(0f64);
@ -173,7 +173,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.clamp_max(1f64);

View File

@ -92,7 +92,7 @@ mod tests {
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = burn::tensor::Tensor::cat([tensor1, tensor2].into(), 1);

View File

@ -188,7 +188,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.conv1d.forward(input);

View File

@ -187,7 +187,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.conv2d.forward(input);

View File

@ -131,7 +131,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.dropout.forward(input);

View File

@ -92,7 +92,7 @@ mod tests {
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 2>, tensor2: Tensor<B, 2, Int>) -> Tensor<B, 2> {
let tensor3 = tensor1.gather(1, tensor2);

View File

@ -152,7 +152,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.global_avg_pool1.forward(input);
@ -201,7 +201,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let output = self.global_avg_pool1.forward(input);

View File

@ -164,7 +164,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.linear.forward(input);

View File

@ -84,7 +84,7 @@ mod tests {
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = tensor1.matmul(tensor2);

View File

@ -148,7 +148,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.max_pool2d.forward(input);

View File

@ -76,7 +76,7 @@ mod tests {
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return)]
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.reshape([4, 4, 4, 4]);

View File

@ -1,5 +1,5 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, ToTokens, Type};
use crate::burn::{BurnImports, Scope, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
@ -21,13 +21,18 @@ pub struct UnaryNode {
#[derive(Clone)]
pub enum UnaryNodeKind {
Cast,
Cos,
Erf,
Exp,
Flatten,
Gelu,
Log,
LogSoftmax,
Softmax,
Relu,
Neg,
Reciprocal,
Relu,
Sigmoid,
Softmax,
Sqrt,
Tanh,
Transpose,
@ -37,13 +42,18 @@ impl UnaryNodeKind {
pub fn as_str(&self) -> &str {
match self {
Self::Cast => "cast",
Self::Cos => "cos",
Self::Erf => "erf",
Self::Exp => "exp",
Self::Flatten => "flatten",
Self::Gelu => "gelu",
Self::Log => "log",
Self::LogSoftmax => "log_softmax",
Self::Softmax => "softmax",
Self::Relu => "relu",
Self::Neg => "neg",
Self::Reciprocal => "reciprocal",
Self::Relu => "relu",
Self::Sigmoid => "sigmoid",
Self::Softmax => "softmax",
Self::Sqrt => "sqrt",
Self::Tanh => "tanh",
Self::Transpose => "transpose",
@ -97,6 +107,16 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
fn into_node(self) -> Node<PS> {
Node::Unary(self)
}
fn register_imports(&self, imports: &mut BurnImports) {
// Register the imports depending on the kind of the node.
match self.kind {
UnaryNodeKind::Neg => {
imports.register("core::ops::Neg");
}
_ => {}
}
}
}
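
The import is registered for Neg because the generated code calls `.neg()` as a method, and for primitive scalars that method comes from the `core::ops::Neg` trait, which must be in scope for method-call syntax (the `-` operator would not need it). A minimal sketch with an illustrative function name:

use core::ops::Neg; // without this, `scalar1.neg()` on an f64 does not resolve

fn negate(scalar1: f64) -> f64 {
    scalar1.neg() // equivalent to -scalar1
}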
impl UnaryNode {
@ -155,6 +175,31 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function))
}
pub(crate) fn cos(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.cos()};
Self::new(input, output, UnaryNodeKind::Cos, Rc::new(function))
}
pub(crate) fn exp(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.exp()};
Self::new(input, output, UnaryNodeKind::Exp, Rc::new(function))
}
pub(crate) fn gelu(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.gelu()};
Self::new(input, output, UnaryNodeKind::Gelu, Rc::new(function))
}
pub(crate) fn log(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.log()};
Self::new(input, output, UnaryNodeKind::Log, Rc::new(function))
}
pub(crate) fn neg(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.neg()};
Self::new(input, output, UnaryNodeKind::Neg, Rc::new(function))
}
/// Casts the input to the output type.
///
/// Currently this function only supports the following conversions:
@ -166,12 +211,23 @@ impl UnaryNode {
/// 4) tensor -> scalar
/// 5) scalar -> tensor
pub(crate) fn cast(input: Type, output: Type) -> Self {
let function = match output.clone() {
Type::Scalar(scalar) => {
let ty = scalar.ty();
move |input| quote! { #input as #ty }
match (input.clone(), output.clone()) {
(Type::Scalar(input_scalar), Type::Scalar(output_scalar)) => {
if input_scalar.kind == output_scalar.kind {
// If the input and output types are the same, we don't need to cast.
Self::new(input, output, UnaryNodeKind::Cast, Rc::new(|input| input))
} else {
// If the input and output types are different, we need to cast.
let ty = output_scalar.ty();
Self::new(
input,
output,
UnaryNodeKind::Cast,
Rc::new(move |input| quote! { #input as #ty }),
)
}
}
Type::Tensor(_tensor) => {
(Type::Tensor(_input_tensor), Type::Tensor(_output_tensor)) => {
// TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023)
// TODO: If the input is scalar and the output type is a tensor,
// we should generate another code block. (@antimora 8/4/2023)
@ -180,9 +236,7 @@ impl UnaryNode {
}
_ => panic!("output must be a tensor"),
};
Self::new(input, output, UnaryNodeKind::Cast, Rc::new(function))
}
}
}
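
Together with the new PartialEq/Eq derives on ScalarKind (later in this diff), the reworked cast can compare scalar kinds and skip redundant casts. Roughly, the two shapes of generated code look like this (illustrative signatures; the real output is exercised by the cast codegen tests):

// Same kinds: a plain pass-through, no cast expression.
pub fn forward(scalar1: f64) -> f64 {
    let scalar2 = scalar1;
    scalar2
}

// Different kinds: an `as` cast is still emitted.
pub fn forward_cast(scalar1: f64) -> i64 {
    let scalar2 = scalar1 as i64;
    scalar2
}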
@ -400,4 +454,118 @@ mod tests {
vec!["scalar2".to_string()],
);
}
#[test]
fn test_unary_codegen_cos() {
one_node_graph(
UnaryNode::cos(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.cos();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_exp() {
one_node_graph(
UnaryNode::exp(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.exp();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_gelu() {
one_node_graph(
UnaryNode::gelu(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.gelu();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_log() {
one_node_graph(
UnaryNode::log(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.log();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_neg_scalar() {
one_node_graph(
UnaryNode::neg(
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)),
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)),
),
quote! {
pub fn forward(&self, scalar1: f64) -> f64 {
let scalar2 = scalar1.neg();
scalar2
}
},
vec!["scalar1".to_string()],
vec!["scalar2".to_string()],
);
}
#[test]
fn test_unary_neg_tensor() {
one_node_graph(
UnaryNode::neg(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.neg();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
}

View File

@ -20,7 +20,7 @@ pub enum TensorKind {
Bool,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScalarKind {
Int32,
Int64,
@ -29,7 +29,7 @@ pub enum ScalarKind {
Bool,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScalarType {
pub name: Ident,
pub kind: ScalarKind,

View File

@ -62,38 +62,43 @@ pub fn dim_inference(
updater.update_tensor_inputs(node);
match node.node_type {
NodeType::Add => same_as_input(node),
NodeType::AveragePool2d => same_as_input(node),
NodeType::BatchNormalization => same_as_input(node),
NodeType::Cast => cast_update_outputs(node),
NodeType::Clip => same_as_input(node),
NodeType::Concat => concat_update_outputs(node),
NodeType::Constant => constant_update_outputs(node),
NodeType::Conv1d => conv1d_update_outputs(node),
NodeType::Conv2d => conv2d_update_outputs(node),
NodeType::MaxPool2d => same_as_input(node),
NodeType::Linear => linear_update_outputs(node),
NodeType::Flatten => flatten_update_outputs(node),
NodeType::GatherElements => same_as_input(node),
NodeType::Relu => same_as_input(node),
NodeType::LogSoftmax => same_as_input(node),
NodeType::BatchNormalization => same_as_input(node),
NodeType::Add => same_as_input(node),
NodeType::Sub => same_as_input(node),
NodeType::Mul => same_as_input(node),
NodeType::Cast => cast_update_outputs(node),
NodeType::Cos => same_as_input(node),
NodeType::Div => same_as_input(node),
NodeType::Erf => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::Softmax => same_as_input(node),
NodeType::ReduceMean => mean_update_outputs(node),
NodeType::Constant => constant_update_outputs(node),
NodeType::Equal => equal_update_outputs(node),
NodeType::Shape => shape_update_outputs(node),
NodeType::Unsqueeze => unsqueeze_update_outputs(node),
NodeType::Sigmoid => same_as_input(node),
NodeType::Transpose => same_as_input(node),
NodeType::Concat => concat_update_outputs(node),
NodeType::Reshape => reshape_update_outputs(node),
NodeType::Dropout => same_as_input(node),
NodeType::Equal => equal_update_outputs(node),
NodeType::Erf => same_as_input(node),
NodeType::Exp => same_as_input(node),
NodeType::Flatten => flatten_update_outputs(node),
NodeType::Gelu => same_as_input(node),
NodeType::GatherElements => same_as_input(node),
NodeType::GlobalAveragePool => same_as_input(node),
NodeType::AveragePool2d => same_as_input(node),
NodeType::Clip => same_as_input(node),
NodeType::Linear => linear_update_outputs(node),
NodeType::Log => same_as_input(node),
NodeType::LogSoftmax => same_as_input(node),
NodeType::MaxPool2d => same_as_input(node),
NodeType::Mul => same_as_input(node),
NodeType::Neg => same_as_input(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::ReduceMean => mean_update_outputs(node),
NodeType::Relu => same_as_input(node),
NodeType::Reshape => reshape_update_outputs(node),
NodeType::Shape => shape_update_outputs(node),
NodeType::Sigmoid => same_as_input(node),
NodeType::Softmax => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Sub => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Transpose => same_as_input(node),
NodeType::Unsqueeze => unsqueeze_update_outputs(node),
// Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
_ => temporary_pass_through_stub(node),
}

View File

@ -232,19 +232,24 @@ impl ONNXGraph {
NodeType::Div => graph.register(Self::div_conversion(node)),
NodeType::Equal => graph.register(Self::equal_conversion(node)),
NodeType::Erf => graph.register(Self::erf_conversion(node)),
NodeType::Exp => graph.register(Self::exp_conversion(node)),
NodeType::Clip => graph.register(Self::clip_conversion(node)),
NodeType::Cos => graph.register(Self::cos_conversion(node)),
NodeType::Conv1d => graph.register(Self::conv1d_conversion::<PS>(node)),
NodeType::Conv2d => graph.register(Self::conv2d_conversion::<PS>(node)),
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
NodeType::Neg => graph.register(Self::neg_conversion(node)),
NodeType::Linear => graph.register(Self::linear_conversion::<PS>(node)),
NodeType::BatchNormalization => {
graph.register(Self::batch_norm_conversion::<PS>(node))
}
NodeType::Relu => graph.register(Self::relu_conversion(node)),
NodeType::Gelu => graph.register(Self::gelu_conversion(node)),
NodeType::Flatten => graph.register(Self::flatten_conversion(node)),
NodeType::GatherElements => graph.register(Self::gather_conversion(node)),
NodeType::Log => graph.register(Self::log_conversion(node)),
NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
NodeType::Softmax => graph.register(Self::softmax_conversion(node)),
NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),
@ -395,6 +400,20 @@ impl ONNXGraph {
UnaryNode::relu(input, output)
}
fn gelu_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::gelu(input, output)
}
fn log_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::log(input, output)
}
fn flatten_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
@ -607,6 +626,26 @@ impl ONNXGraph {
GlobalAvgPoolNode::new(name, input, output)
}
fn cos_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::cos(input, output)
}
fn exp_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::exp(input, output)
}
fn neg_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
UnaryNode::neg(input, output)
}
}
/// Extract data from node states and convert it to `DataSerialize`.