mirror of https://github.com/tracel-ai/burn.git
Add Cos, Exp, Gelu, Log and Neg unary ONNX ops (#1013)
* Add ONNX Cos OP
* Add ONNX Exp OP
* Add ONNX Gelu OP
* Add ONNX Log OP
* Allow approx_constant clippy rule in generated model code
* Add ONNX Neg OP
* Fix tests with custom imports for unary nodes
* Add scalar tests for Sqrt
This commit is contained in:
parent 87393b2070
commit fed618b265
@@ -43,7 +43,7 @@ represent the corresponding Burn Op.
| [Conv2d][34] | ✅ | ✅ |
| [ConvInteger][37] | ❌ | ❌ |
| [ConvTranspose][38] | ❌ | ✅ |
| [Cos][39] | ❌ | ✅ |
| [Cos][39] | ✅ | ✅ |
| [Cosh][40] | ❌ | ❌ |
| [CumSum][41] | ❌ | ❌ |
| [DepthToSpace][42] | ❌ | ❌ |

@@ -57,7 +57,7 @@ represent the corresponding Burn Op.
| [Elu][50] | ❌ | ❌ |
| [Equal][51] | ✅ | ✅ |
| [Erf][52] | ✅ | ✅ |
| [Exp][53] | ❌ | ✅ |
| [Exp][53] | ✅ | ✅ |
| [Expand][54] | ❌ | ❌ |
| [EyeLike][55] | ❌ | ❌ |
| [Flatten][56] | ✅ | ✅ |

@@ -65,7 +65,7 @@ represent the corresponding Burn Op.
| [Gather][58] | ✅ | ✅ |
| [GatherElements][59] | ❌ | ❌ |
| [GatherND][60] | ❌ | ❌ |
| [Gelu][61] | ❌ | ✅ |
| [Gelu][61] | ✅ | ✅ |
| [Gemm][62] | ❌ | ❌ |
| [GlobalAveragePool][63] | ✅ | ✅ |
| [GlobalLpPool][64] | ❌ | ❌ |

@@ -91,7 +91,7 @@ represent the corresponding Burn Op.
| [Less][84] | ❌ | ✅ |
| [LessOrEqual][85] | ❌ | ✅ |
| Linear | ✅ | ✅ |
| [Log][87] | ❌ | ✅ |
| [Log][87] | ✅ | ✅ |
| [LogSoftmax][88] | ✅ | ✅ |
| [Loop][89] | ❌ | ❌ |
| [LpNormalization][90] | ❌ | ❌ |

@@ -113,7 +113,7 @@ represent the corresponding Burn Op.
| [Mod][106] | ❌ | ❌ |
| [Mul][107] | ❌ | ✅ |
| [Multinomial][108] | ❌ | ❌ |
| [Neg][109] | ❌ | ✅ |
| [Neg][109] | ✅ | ✅ |
| [NegativeLogLikelihoodLoss][110] | ❌ | ❌ |
| [NonMaxSuppression][112] | ❌ | ❌ |
| [NonZero][113] | ❌ | ❌ |
@@ -15,18 +15,23 @@ fn main() {
        .input("tests/concat/concat.onnx")
        .input("tests/conv1d/conv1d.onnx")
        .input("tests/conv2d/conv2d.onnx")
        .input("tests/cos/cos.onnx")
        .input("tests/div/div.onnx")
        .input("tests/dropout/dropout_opset16.onnx")
        .input("tests/dropout/dropout_opset7.onnx")
        .input("tests/equal/equal.onnx")
        .input("tests/erf/erf.onnx")
        .input("tests/exp/exp.onnx")
        .input("tests/flatten/flatten.onnx")
        .input("tests/gather/gather.onnx")
        .input("tests/gelu/gelu.onnx")
        .input("tests/global_avr_pool/global_avr_pool.onnx")
        .input("tests/linear/linear.onnx")
        .input("tests/log_softmax/log_softmax.onnx")
        .input("tests/log/log.onnx")
        .input("tests/maxpool2d/maxpool2d.onnx")
        .input("tests/mul/mul.onnx")
        .input("tests/neg/neg.onnx")
        .input("tests/recip/recip.onnx")
        .input("tests/relu/relu.onnx")
        .input("tests/reshape/reshape.onnx")
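For context: these `.input` calls live in the ONNX test crate's build script, which feeds each model through burn-import's code generator at build time. A minimal, hypothetical sketch of such a build script follows; the `ModelGen` builder with `out_dir` and `run_from_script` is assumed from burn-import's usual usage and is not part of this diff.

```rust
// Hypothetical, trimmed-down build.rs showing where the `.input(...)` chain above lives.
use burn_import::onnx::ModelGen;

fn main() {
    // Each `.input` registers an ONNX file whose Rust model code is generated at build time.
    ModelGen::new()
        .input("tests/cos/cos.onnx")
        .input("tests/exp/exp.onnx")
        .input("tests/neg/neg.onnx")
        .out_dir("model/")
        .run_from_script();
}
```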
@@ -0,0 +1,16 @@
(binary ONNX protobuf for cos.onnx, summarized here: a PyTorch 2.1.0 export of graph "main_graph" containing a single Cos node with input "onnx::Cos_0" and output "1")
@@ -0,0 +1,40 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/cos/cos.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        return torch.cos(x)


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "cos.onnx"
    test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)

    torch.onnx.export(model, (test_input), onnx_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    print("Test input data: {}".format(test_input))
    output = model.forward(test_input)
    print("Test output data: {}".format(output))


if __name__ == '__main__':
    main()
@@ -0,0 +1,16 @@
(binary ONNX protobuf for exp.onnx, summarized here: a PyTorch 2.1.0 export of graph "main_graph" containing a single Exp node with input "onnx::Exp_0" and output "1")
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/exp/exp.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        return torch.exp(x)


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)
    import math

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "exp.onnx"
    test_input = torch.tensor([[[[0, math.log(2.)]]]], device=device)

    torch.onnx.export(model, (test_input), onnx_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    print("Test input data: {}".format(test_input))
    output = model.forward(test_input)
    print("Test output data: {}".format(output))


if __name__ == '__main__':
    main()
Binary file not shown.
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/gelu/gelu.onnx

import torch
import torch.nn as nn
import torch.onnx


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.gelu(x)


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "gelu.onnx"
    test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)

    torch.onnx.export(model, (test_input), onnx_name,
                      verbose=False,
                      opset_version=16,
                      # opset_version=20, TODO: uncomment this when PyTorch supports it
                      # Note: OPSET 20 is required for GELU to be exported otherwise
                      # op is broken down into multiple ops
                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    print("Test input data: {}".format(test_input))
    output = model.forward(test_input)
    print("Test output data: {}".format(output))


if __name__ == '__main__':
    main()
@@ -0,0 +1,16 @@
(binary ONNX protobuf for log.onnx, summarized here: a PyTorch 2.1.0 export of graph "main_graph" containing a single Log node with input "onnx::Log_0" and output "1")
@@ -0,0 +1,40 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/log/log.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        return torch.log(x)


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "log.onnx"
    test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)

    torch.onnx.export(model, (test_input), onnx_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    print("Test input data: {}".format(test_input))
    output = model.forward(test_input)
    print("Test output data: {}".format(output))


if __name__ == '__main__':
    main()
Binary file not shown.
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/neg/neg.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y):
        return torch.neg(x), -y


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "neg.onnx"
    test_input1 = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)
    test_input2 = 99.0

    torch.onnx.export(model, (test_input1, test_input2), onnx_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    print("Test input1: {}, input2: {}".format(test_input1, test_input2))
    output1, output2 = model.forward(test_input1, test_input2)
    print("Test output1 data: {}".format(output1))
    print("Test output2 data: {}".format(output2))


if __name__ == '__main__':
    main()
@@ -22,18 +22,23 @@ include_models!(
    concat,
    conv1d,
    conv2d,
    cos,
    div,
    dropout_opset16,
    dropout_opset7,
    equal,
    erf,
    exp,
    flatten,
    gather,
    gelu,
    global_avr_pool,
    linear,
    log_softmax,
    log,
    maxpool2d,
    mul,
    neg,
    recip,
    relu,
    reshape,
@@ -331,12 +336,15 @@ mod tests {
    fn sqrt() {
        let model: sqrt::Model<Backend> = sqrt::Model::new();

        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);
        let input1 = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);
        let input2 = 36f64;

        let output = model.forward(input);
        let expected = Data::from([[[[1.0, 2.0, 3.0, 5.0]]]]);
        let (output1, output2) = model.forward(input1, input2);
        let expected1 = Data::from([[[[1.0, 2.0, 3.0, 5.0]]]]);
        let expected2 = 6.0;

        assert_eq!(output.to_data(), expected);
        assert_eq!(output1.to_data(), expected1);
        assert_eq!(output2, expected2);
    }

    #[test]
@@ -641,4 +649,69 @@ mod tests {
        let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]);
        output.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn cos() {
        let model: cos::Model<Backend> = cos::Model::new();

        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);

        let output = model.forward(input);
        let expected = Data::from([[[[0.5403, -0.6536, -0.9111, 0.9912]]]]);

        output.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn exp() {
        let model: exp::Model<Backend> = exp::Model::new();

        let input = Tensor::<Backend, 4>::from_floats([[[[0.0000, 0.6931]]]]);

        let output = model.forward(input);
        let expected = Data::from([[[[1., 2.]]]]);

        output.to_data().assert_approx_eq(&expected, 2);
    }

    #[test]
    fn gelu() {
        let model: gelu::Model<Backend> = gelu::Model::new();

        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);

        let output = model.forward(input);
        let expected = Data::from([[[[0.8413, 3.9999, 9.0000, 25.0000]]]]);

        output.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn log() {
        let model: log::Model<Backend> = log::Model::new();

        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);

        let output = model.forward(input);
        let expected = Data::from([[[[0.0000, 1.3863, 2.1972, 3.2189]]]]);

        output.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn neg() {
        let model: neg::Model<Backend> = neg::Model::new();

        let input1 = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]]);
        let input2 = 99f64;

        let (output1, output2) = model.forward(input1, input2);
        let expected1 = Data::from([[[[-1.0, -4.0, -9.0, -25.0]]]]);
        let expected2 = -99f64;

        output1.to_data().assert_approx_eq(&expected1, 4);

        assert_eq!(output2, expected2);
    }
}
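For readers checking the constants: the expected values in the new tests are simply the element-wise functions applied to the inputs from the generator scripts earlier in the diff. A small standalone sketch (not part of the commit) to reproduce them:

```rust
// Standalone check that the expected tensors above are the element-wise functions
// applied to the generator inputs [1.0, 4.0, 9.0, 25.0].
fn main() {
    for x in [1.0_f64, 4.0, 9.0, 25.0] {
        // cos -> 0.5403, -0.6536, -0.9111, 0.9912; ln -> 0.0000, 1.3863, 2.1972, 3.2189
        println!("cos({x}) = {:.4}, ln({x}) = {:.4}", x.cos(), x.ln());
    }
    // The exp test feeds [0.0, 0.6931] (roughly [0, ln 2]), hence the expected [1.0, 2.0].
    println!("exp(0.6931) = {:.4}", 0.6931_f64.exp());
    // GELU(x) = x * Phi(x): Phi(1.0) is about 0.8413 and Phi saturates to 1 for larger x,
    // which matches the gelu expectation [0.8413, 3.9999, 9.0000, 25.0000].
}
```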
Binary file not shown.
|
@ -10,9 +10,9 @@ class Model(nn.Module):
|
|||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.sqrt(x)
|
||||
|
||||
def forward(self, x, y):
|
||||
y_tensor = torch.tensor(y) # Convert y to a PyTorch tensor
|
||||
return torch.sqrt(x), torch.sqrt(y_tensor)
|
||||
|
||||
def main():
|
||||
# Set random seed for reproducibility
|
||||
|
@@ -23,19 +23,20 @@ def main():
    model.eval()
    device = torch.device("cpu")
    onnx_name = "sqrt.onnx"
    dummy_input = torch.randn(1, 4, 9, 25, device=device)

    torch.onnx.export(model, (dummy_input), onnx_name,
    test_input1 = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]])
    test_input2 = 36.0
    torch.onnx.export(model, (test_input1, test_input2), onnx_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]])
    test_input1 = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]])
    test_input2 = 36.0

    print("Test input data: {}".format(test_input))
    output = model.forward(test_input)
    print("Test output data: {}".format(output))
    print("Test input data: {}, {}".format(test_input1, test_input2))
    output1, output2 = model.forward(test_input1, test_input2)
    print("Test output data: {}, {}".format(output1, output2))


if __name__ == '__main__':
@@ -542,7 +542,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
        // otherwise let_and_return error will be triggered by clippy.
        // For now, we just disable the warning.
        quote! {
            #[allow(clippy::let_and_return)]
            #[allow(clippy::let_and_return, clippy::approx_constant)]
            pub fn forward(&self, #input_def) -> #output_type_def {
                #body
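Context for the widened allow (not part of the diff): `clippy::approx_constant` fires when a float literal is close to a well-known constant, and generated model code regularly embeds such literals as weights or test data. The `exp` test earlier in this diff uses `0.6931`, which is close to ln 2 and is exactly the kind of value that trips the lint:

```rust
// Minimal illustration of the lint the generated code now allows.
#[allow(clippy::approx_constant)]
fn main() {
    // Without the allow, clippy flags this literal as approximating core::f64::consts::LN_2.
    let nearly_ln_2 = 0.6931_f64;
    println!("exp({nearly_ln_2}) = {:.4}", nearly_ln_2.exp()); // ~2.0
}
```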
@@ -147,7 +147,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
            let output = self.avg_pool2d.forward(input);

@@ -197,7 +197,7 @@ pub(crate) mod tests {
    use crate::burn::{
        graph::BurnGraph,
        node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen},
        TensorType,
        BurnImports, TensorType,
    };
    use burn::{
        nn::conv::Conv2dConfig, nn::PaddingConfig2d, record::FullPrecisionSettings, tensor::Data,

@@ -205,7 +205,8 @@ pub(crate) mod tests {
    use proc_macro2::TokenStream;
    use quote::quote;

    pub(crate) fn one_node_graph<T: NodeCodegen<FullPrecisionSettings> + 'static>(
    #[track_caller]
    pub(crate) fn one_node_graph<T: NodeCodegen<FullPrecisionSettings> + Clone + 'static>(
        node_gen: T,
        forward: TokenStream,
        input_names: Vec<String>,

@@ -213,15 +214,16 @@ pub(crate) mod tests {
    ) {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(node_gen);
        graph.register(node_gen.clone());

        graph.register_input_output(input_names, output_names);

        let mut imports = BurnImports::default();
        node_gen.register_imports(&mut imports);
        let imports = imports.codegen();

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };
            #imports

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {

@@ -236,7 +238,7 @@ pub(crate) mod tests {
            }
        }

        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        #forward
    }
};
@@ -298,7 +300,7 @@ pub(crate) mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
            let tensor3 = tensor1.matmul(tensor2);
            let tensor4 = self.conv2d.forward(tensor3);

@@ -370,7 +372,7 @@ pub(crate) mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
            let tensor3 = tensor1.matmul(tensor2.clone());
            let tensor4 = self.conv2d.forward(tensor2);

@@ -210,7 +210,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
            let output = self.norm.forward(input);

@@ -325,7 +325,7 @@ mod tests {
            }
        }

        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4, Bool> {
            let tensor3 = tensor1.equal(tensor2);

@@ -87,7 +87,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
            let tensor2 = tensor1.clamp(0f64, 1f64);

@@ -130,7 +130,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
            let tensor2 = tensor1.clamp_min(0f64);

@@ -173,7 +173,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
            let tensor2 = tensor1.clamp_max(1f64);
@@ -92,7 +92,7 @@ mod tests {
            }
        }

        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
            let tensor3 = burn::tensor::Tensor::cat([tensor1, tensor2].into(), 1);

@@ -188,7 +188,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
            let output = self.conv1d.forward(input);

@@ -187,7 +187,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
            let output = self.conv2d.forward(input);

@@ -131,7 +131,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
            let output = self.dropout.forward(input);

@@ -92,7 +92,7 @@ mod tests {
            }
        }

        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 2>, tensor2: Tensor<B, 2, Int>) -> Tensor<B, 2> {
            let tensor3 = tensor1.gather(1, tensor2);

@@ -152,7 +152,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
            let output = self.global_avg_pool1.forward(input);

@@ -201,7 +201,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
            let output = self.global_avg_pool1.forward(input);

@@ -164,7 +164,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
            let output = self.linear.forward(input);

@@ -84,7 +84,7 @@ mod tests {
            }
        }

        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
            let tensor3 = tensor1.matmul(tensor2);

@@ -148,7 +148,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
            let output = self.max_pool2d.forward(input);

@@ -76,7 +76,7 @@ mod tests {
                phantom: core::marker::PhantomData,
            }
        }
        #[allow(clippy::let_and_return)]
        #[allow(clippy::let_and_return, clippy::approx_constant)]
        pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
            let tensor2 = tensor1.reshape([4, 4, 4, 4]);
@@ -1,5 +1,5 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, ToTokens, Type};
use crate::burn::{BurnImports, Scope, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

@@ -21,13 +21,18 @@ pub struct UnaryNode {
#[derive(Clone)]
pub enum UnaryNodeKind {
    Cast,
    Cos,
    Erf,
    Exp,
    Flatten,
    Gelu,
    Log,
    LogSoftmax,
    Softmax,
    Relu,
    Neg,
    Reciprocal,
    Relu,
    Sigmoid,
    Softmax,
    Sqrt,
    Tanh,
    Transpose,

@@ -37,13 +42,18 @@ impl UnaryNodeKind {
    pub fn as_str(&self) -> &str {
        match self {
            Self::Cast => "cast",
            Self::Cos => "cos",
            Self::Erf => "erf",
            Self::Exp => "exp",
            Self::Flatten => "flatten",
            Self::Gelu => "gelu",
            Self::Log => "log",
            Self::LogSoftmax => "log_softmax",
            Self::Softmax => "softmax",
            Self::Relu => "relu",
            Self::Neg => "neg",
            Self::Reciprocal => "reciprocal",
            Self::Relu => "relu",
            Self::Sigmoid => "sigmoid",
            Self::Softmax => "softmax",
            Self::Sqrt => "sqrt",
            Self::Tanh => "tanh",
            Self::Transpose => "transpose",
@@ -97,6 +107,16 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
    fn into_node(self) -> Node<PS> {
        Node::Unary(self)
    }

    fn register_imports(&self, imports: &mut BurnImports) {
        // Register the imports depending on the kind of the node.
        match self.kind {
            UnaryNodeKind::Neg => {
                imports.register("core::ops::Neg");
            }
            _ => {}
        }
    }
}

impl UnaryNode {

@@ -155,6 +175,31 @@ impl UnaryNode {
        Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function))
    }

    pub(crate) fn cos(input: Type, output: Type) -> Self {
        let function = move |input| quote! { #input.cos()};
        Self::new(input, output, UnaryNodeKind::Cos, Rc::new(function))
    }

    pub(crate) fn exp(input: Type, output: Type) -> Self {
        let function = move |input| quote! { #input.exp()};
        Self::new(input, output, UnaryNodeKind::Exp, Rc::new(function))
    }

    pub(crate) fn gelu(input: Type, output: Type) -> Self {
        let function = move |input| quote! { #input.gelu()};
        Self::new(input, output, UnaryNodeKind::Gelu, Rc::new(function))
    }

    pub(crate) fn log(input: Type, output: Type) -> Self {
        let function = move |input| quote! { #input.log()};
        Self::new(input, output, UnaryNodeKind::Log, Rc::new(function))
    }

    pub(crate) fn neg(input: Type, output: Type) -> Self {
        let function = move |input| quote! { #input.neg()};
        Self::new(input, output, UnaryNodeKind::Neg, Rc::new(function))
    }

    /// Casts the input to the output type.
    ///
    /// Currently this function only supports the following conversions:
@@ -166,12 +211,23 @@ impl UnaryNode {
    /// 4) tensor -> scalar
    /// 5) scalar -> tensor
    pub(crate) fn cast(input: Type, output: Type) -> Self {
        let function = match output.clone() {
            Type::Scalar(scalar) => {
                let ty = scalar.ty();
                move |input| quote! { #input as #ty }
        match (input.clone(), output.clone()) {
            (Type::Scalar(input_scalar), Type::Scalar(output_scalar)) => {
                if input_scalar.kind == output_scalar.kind {
                    // If the input and output types are the same, we don't need to cast.
                    Self::new(input, output, UnaryNodeKind::Cast, Rc::new(|input| input))
                } else {
                    // If the input and output types are different, we need to cast.
                    let ty = output_scalar.ty();
                    Self::new(
                        input,
                        output,
                        UnaryNodeKind::Cast,
                        Rc::new(move |input| quote! { #input as #ty }),
                    )
                }
            }
            Type::Tensor(_tensor) => {
            (Type::Tensor(_input_tensor), Type::Tensor(_output_tensor)) => {
                // TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023)
                // TODO: If the input is scalar and the output type is a tensor,
                // we should generate another code block. (@antimora 8/4/2023)

@@ -180,9 +236,7 @@ impl UnaryNode {
            }

            _ => panic!("output must be a tensor"),
        };

        Self::new(input, output, UnaryNodeKind::Cast, Rc::new(function))
    }
}
}
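A rough sketch of what the two scalar-cast branches above end up producing, with plain free functions standing in for the generated model methods (illustrative only; the exact generated shape depends on the surrounding UnaryNode codegen):

```rust
// Kinds differ: the node emits a plain `as` cast (from the `quote! { #input as #ty }` arm).
#[allow(clippy::let_and_return)]
fn cast_f64_to_f32(scalar1: f64) -> f32 {
    let scalar2 = scalar1 as f32;
    scalar2
}

// Kinds match: the node's function is the identity (`Rc::new(|input| input)`),
// so the value is passed through without a cast.
#[allow(clippy::let_and_return)]
fn cast_f64_to_f64(scalar1: f64) -> f64 {
    let scalar2 = scalar1;
    scalar2
}

fn main() {
    assert_eq!(cast_f64_to_f32(1.5), 1.5_f32);
    assert_eq!(cast_f64_to_f64(1.5), 1.5_f64);
    println!("both scalar cast paths behave as expected");
}
```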
@@ -400,4 +454,118 @@ mod tests {
            vec!["scalar2".to_string()],
        );
    }

    #[test]
    fn test_unary_codegen_cos() {
        one_node_graph(
            UnaryNode::cos(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 4)),
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = tensor1.cos();

                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );
    }

    #[test]
    fn test_unary_codegen_exp() {
        one_node_graph(
            UnaryNode::exp(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 4)),
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = tensor1.exp();

                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );
    }

    #[test]
    fn test_unary_codegen_gelu() {
        one_node_graph(
            UnaryNode::gelu(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 4)),
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = tensor1.gelu();

                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );
    }

    #[test]
    fn test_unary_codegen_log() {
        one_node_graph(
            UnaryNode::log(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 4)),
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = tensor1.log();

                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );
    }

    #[test]
    fn test_unary_neg_scalar() {
        one_node_graph(
            UnaryNode::neg(
                Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)),
                Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)),
            ),
            quote! {
                pub fn forward(&self, scalar1: f64) -> f64 {
                    let scalar2 = scalar1.neg();

                    scalar2
                }
            },
            vec!["scalar1".to_string()],
            vec!["scalar2".to_string()],
        );
    }

    #[test]
    fn test_unary_neg_tensor() {
        one_node_graph(
            UnaryNode::neg(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 4)),
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = tensor1.neg();

                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );
    }
}
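Why `register_imports` above adds `core::ops::Neg` only for the `Neg` kind: the scalar path (see `test_unary_neg_scalar`) generates `scalar1.neg()`, and `.neg()` on a primitive float is a trait method, so the trait has to be in scope in the generated module; the tensor path goes through Burn's own tensor `neg` and needed no extra import here. A standalone illustration (not from the diff):

```rust
// Calling `.neg()` on an f64 resolves through the `Neg` trait, which must be imported.
use core::ops::Neg;

fn main() {
    let scalar1: f64 = 99.0;
    let scalar2 = scalar1.neg(); // fails to compile without `use core::ops::Neg;`
    assert_eq!(scalar2, -99.0);
    println!("neg({scalar1}) = {scalar2}");
}
```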
@@ -20,7 +20,7 @@ pub enum TensorKind {
    Bool,
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScalarKind {
    Int32,
    Int64,

@@ -29,7 +29,7 @@ pub enum ScalarKind {
    Bool,
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScalarType {
    pub name: Ident,
    pub kind: ScalarKind,
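The `PartialEq`/`Eq` derives here are what allow `UnaryNode::cast` above to compare `input_scalar.kind == output_scalar.kind`. A minimal sketch with a local, abridged copy of the enum purely for illustration:

```rust
// `==` on an enum requires PartialEq; the derive added in this hunk enables exactly that.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
enum ScalarKind {
    Int32,
    Int64,
    Float64,
    Bool,
}

fn main() {
    let input_kind = ScalarKind::Float64;
    let output_kind = ScalarKind::Float64;
    // Mirrors the check in UnaryNode::cast: identical kinds mean the cast can be skipped.
    assert!(input_kind == output_kind);
    println!("kinds match: {:?} == {:?}", input_kind, output_kind);
}
```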
@@ -62,38 +62,43 @@ pub fn dim_inference(
    updater.update_tensor_inputs(node);

    match node.node_type {
        NodeType::Add => same_as_input(node),
        NodeType::AveragePool2d => same_as_input(node),
        NodeType::BatchNormalization => same_as_input(node),
        NodeType::Cast => cast_update_outputs(node),
        NodeType::Clip => same_as_input(node),
        NodeType::Concat => concat_update_outputs(node),
        NodeType::Constant => constant_update_outputs(node),
        NodeType::Conv1d => conv1d_update_outputs(node),
        NodeType::Conv2d => conv2d_update_outputs(node),
        NodeType::MaxPool2d => same_as_input(node),
        NodeType::Linear => linear_update_outputs(node),
        NodeType::Flatten => flatten_update_outputs(node),
        NodeType::GatherElements => same_as_input(node),
        NodeType::Relu => same_as_input(node),
        NodeType::LogSoftmax => same_as_input(node),
        NodeType::BatchNormalization => same_as_input(node),
        NodeType::Add => same_as_input(node),
        NodeType::Sub => same_as_input(node),
        NodeType::Mul => same_as_input(node),
        NodeType::Cast => cast_update_outputs(node),
        NodeType::Cos => same_as_input(node),
        NodeType::Div => same_as_input(node),
        NodeType::Erf => same_as_input(node),
        NodeType::Sqrt => same_as_input(node),
        NodeType::Tanh => same_as_input(node),
        NodeType::Reciprocal => same_as_input(node),
        NodeType::Softmax => same_as_input(node),
        NodeType::ReduceMean => mean_update_outputs(node),
        NodeType::Constant => constant_update_outputs(node),
        NodeType::Equal => equal_update_outputs(node),
        NodeType::Shape => shape_update_outputs(node),
        NodeType::Unsqueeze => unsqueeze_update_outputs(node),
        NodeType::Sigmoid => same_as_input(node),
        NodeType::Transpose => same_as_input(node),
        NodeType::Concat => concat_update_outputs(node),
        NodeType::Reshape => reshape_update_outputs(node),
        NodeType::Dropout => same_as_input(node),
        NodeType::Equal => equal_update_outputs(node),
        NodeType::Erf => same_as_input(node),
        NodeType::Exp => same_as_input(node),
        NodeType::Flatten => flatten_update_outputs(node),
        NodeType::Gelu => same_as_input(node),
        NodeType::GatherElements => same_as_input(node),
        NodeType::GlobalAveragePool => same_as_input(node),
        NodeType::AveragePool2d => same_as_input(node),
        NodeType::Clip => same_as_input(node),
        NodeType::Linear => linear_update_outputs(node),
        NodeType::Log => same_as_input(node),
        NodeType::LogSoftmax => same_as_input(node),
        NodeType::MaxPool2d => same_as_input(node),
        NodeType::Mul => same_as_input(node),
        NodeType::Neg => same_as_input(node),
        NodeType::Reciprocal => same_as_input(node),
        NodeType::ReduceMean => mean_update_outputs(node),
        NodeType::Relu => same_as_input(node),
        NodeType::Reshape => reshape_update_outputs(node),
        NodeType::Shape => shape_update_outputs(node),
        NodeType::Sigmoid => same_as_input(node),
        NodeType::Softmax => same_as_input(node),
        NodeType::Sqrt => same_as_input(node),
        NodeType::Sub => same_as_input(node),
        NodeType::Tanh => same_as_input(node),
        NodeType::Transpose => same_as_input(node),
        NodeType::Unsqueeze => unsqueeze_update_outputs(node),
        // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
        _ => temporary_pass_through_stub(node),
    }
@@ -232,19 +232,24 @@ impl ONNXGraph {
            NodeType::Div => graph.register(Self::div_conversion(node)),
            NodeType::Equal => graph.register(Self::equal_conversion(node)),
            NodeType::Erf => graph.register(Self::erf_conversion(node)),
            NodeType::Exp => graph.register(Self::exp_conversion(node)),
            NodeType::Clip => graph.register(Self::clip_conversion(node)),
            NodeType::Cos => graph.register(Self::cos_conversion(node)),
            NodeType::Conv1d => graph.register(Self::conv1d_conversion::<PS>(node)),
            NodeType::Conv2d => graph.register(Self::conv2d_conversion::<PS>(node)),
            NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
            NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
            NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
            NodeType::Neg => graph.register(Self::neg_conversion(node)),
            NodeType::Linear => graph.register(Self::linear_conversion::<PS>(node)),
            NodeType::BatchNormalization => {
                graph.register(Self::batch_norm_conversion::<PS>(node))
            }
            NodeType::Relu => graph.register(Self::relu_conversion(node)),
            NodeType::Gelu => graph.register(Self::gelu_conversion(node)),
            NodeType::Flatten => graph.register(Self::flatten_conversion(node)),
            NodeType::GatherElements => graph.register(Self::gather_conversion(node)),
            NodeType::Log => graph.register(Self::log_conversion(node)),
            NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
            NodeType::Softmax => graph.register(Self::softmax_conversion(node)),
            NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),

@@ -395,6 +400,20 @@ impl ONNXGraph {
        UnaryNode::relu(input, output)
    }

    fn gelu_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.get(0).unwrap().to_type();
        let output = node.outputs.get(0).unwrap().to_type();

        UnaryNode::gelu(input, output)
    }

    fn log_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.get(0).unwrap().to_type();
        let output = node.outputs.get(0).unwrap().to_type();

        UnaryNode::log(input, output)
    }

    fn flatten_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.get(0).unwrap().to_type();
        let output = node.outputs.get(0).unwrap().to_type();

@@ -607,6 +626,26 @@ impl ONNXGraph {

        GlobalAvgPoolNode::new(name, input, output)
    }

    fn cos_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.get(0).unwrap().to_type();
        let output = node.outputs.get(0).unwrap().to_type();

        UnaryNode::cos(input, output)
    }

    fn exp_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.get(0).unwrap().to_type();
        let output = node.outputs.get(0).unwrap().to_type();

        UnaryNode::exp(input, output)
    }

    fn neg_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.get(0).unwrap().to_type();
        let output = node.outputs.get(0).unwrap().to_type();
        UnaryNode::neg(input, output)
    }
}

/// Extract data from node states and convert it to `DataSerialize`.