mirror of https://github.com/tracel-ai/burn.git
[ONNX] Add not op and extend cast support to tensors (#1634)
* Add Not ONNX op support
* Extend Cast ONNX support to tensors
* Fix clippy
parent 7377bbe31c
commit 6d96e8d808
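Before the diff: at the tensor-API level the new support is thin glue over methods Burn already has. The generated code maps ONNX Cast to the kind conversions .bool()/.int()/.float() and ONNX Not to .bool_not(), as the codegen tests below confirm. A minimal sketch of those calls, assuming the burn and burn-ndarray crates (the NdArray backend is an illustrative choice, not taken from this diff):

    use burn::tensor::{Bool, Int, Tensor};
    use burn_ndarray::NdArray;

    fn main() {
        let device = Default::default();

        // ONNX Cast lowers to kind conversions on the tensor.
        let x = Tensor::<NdArray, 2>::from_floats([[0., 1.], [2., 3.]], &device);
        let as_int: Tensor<NdArray, 2, Int> = x.clone().int();
        let as_bool: Tensor<NdArray, 2, Bool> = x.bool();

        // ONNX Not lowers to bool_not() on a Bool tensor.
        let negated = as_bool.bool_not();
        println!("{as_int}\n{negated}");
    }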
Cargo.lock:

@@ -126,9 +126,9 @@ dependencies = [

 [[package]]
 name = "anyhow"
-version = "1.0.82"
+version = "1.0.81"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
+checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"

 [[package]]
 name = "arboard"
@@ -3517,9 +3517,9 @@ dependencies = [

 [[package]]
 name = "quote"
-version = "1.0.36"
+version = "1.0.35"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
+checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
 dependencies = [
  "proc-macro2",
 ]
Cargo.toml (workspace):

@@ -47,7 +47,7 @@ pretty_assertions = "1.4"
 proc-macro2 = "1.0.79"
 protobuf = "3.3"
 protobuf-codegen = "3.3"
-quote = "1.0.36"
+quote = "1.0.33"
 percent-encoding = "2.3.1"
 r2d2 = "0.8.10"
 r2d2_sqlite = { version = "0.23.0" }
SUPPORTED-ONNX-OPS.md:

@@ -118,7 +118,7 @@ represent the corresponding Burn Op.
 | [NegativeLogLikelihoodLoss][110] | ❌ | ❌ |
 | [NonMaxSuppression][112] | ❌ | ❌ |
 | [NonZero][113] | ❌ | ❌ |
-| [Not][114] | ❌ | ✅ |
+| [Not][114] | ✅ | ✅ |
 | [OneHot][115] | ❌ | ✅ |
 | [Optional][116] | ❌ | ❌ |
 | [OptionalGetElement][117] | ❌ | ❌ |
onnx-tests/build.rs:

@@ -10,6 +10,7 @@ fn main() {
         .input("tests/add/add.onnx")
         .input("tests/avg_pool2d/avg_pool2d.onnx")
         .input("tests/batch_norm/batch_norm.onnx")
+        .input("tests/cast/cast.onnx")
         .input("tests/clip/clip_opset16.onnx")
         .input("tests/clip/clip_opset7.onnx")
         .input("tests/concat/concat.onnx")
@@ -32,6 +33,7 @@ fn main() {
         .input("tests/maxpool2d/maxpool2d.onnx")
         .input("tests/mul/mul.onnx")
         .input("tests/neg/neg.onnx")
+        .input("tests/not/not.onnx")
         .input("tests/recip/recip.onnx")
         .input("tests/relu/relu.onnx")
         .input("tests/leaky_relu/leaky_relu.onnx")
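These .input(...) calls chain off the ModelGen builder that burn-import exposes for build scripts; a trimmed sketch of the surrounding code (the out_dir value is illustrative, not taken from this diff):

    use burn_import::onnx::ModelGen;

    fn main() {
        // Generates Rust model sources from the listed ONNX files at build time.
        ModelGen::new()
            .input("tests/cast/cast.onnx")
            .input("tests/not/not.onnx")
            .out_dir("model/")
            .run_from_script();
    }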
onnx-tests/tests/cast/cast.onnx:

Binary file not shown.
onnx-tests/tests/cast/cast.py:

@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+
+# used to generate model: onnx-tests/tests/cast/cast.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(
+        self,
+        x_bool,
+        x_int,
+        x_float,
+        x_scalar,
+    ):
+        # NOTE: we clone same-type casts for int and bool, otherwise the exporter would
+        # link other type casts to the output of the bool cast, leading to additional casts
+        return (
+            x_bool.clone().bool(),
+            x_bool.int(),
+            x_bool.float(),
+            x_int.bool(),
+            x_int.clone().int(),
+            x_int.float(),
+            x_float.bool(),
+            x_float.int(),
+            x_float.float(),
+            x_scalar.int(),
+        )
+
+
+def main():
+    # Set random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    onnx_name = "cast.onnx"
+    test_bool = torch.ones((2, 1), device=device, dtype=torch.bool)
+    test_int = torch.ones((2, 1), device=device, dtype=torch.int)
+    test_float = torch.ones((2, 1), device=device, dtype=torch.float)
+    test_scalar = torch.ones(1, device=device, dtype=torch.float).squeeze()
+    test_input = (test_bool, test_int, test_float, test_scalar)
+
+    # NOTE: torch exports logical_not with a cast node even if the input is already bool
+    # https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L2204-L2207
+    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)
+
+    print(f"Finished exporting model to {onnx_name}")
+
+    # Output some test data for use in the test
+    print(f"Test input data: {test_input}")
+    output = model.forward(*test_input)
+    print(f"Test output data: {output}")
+
+
+if __name__ == "__main__":
+    main()
onnx-tests/tests/not/not.onnx:

Binary file not shown. (ONNX protobuf exported by pytorch 2.1.2; the graph is a Cast node feeding a Not node.)
onnx-tests/tests/not/not.py:

@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+
+# used to generate model: onnx-tests/tests/not/not.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x):
+        return torch.logical_not(x)
+
+
+def main():
+    # Set random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    onnx_name = "not.onnx"
+    test_input = torch.tensor([[[[True, False, True, False]]]], device=device)
+
+    # NOTE: torch exports logical_not with a cast node even if the input is already bool
+    # https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L2204-L2207
+    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)
+
+    print(f"Finished exporting model to {onnx_name}")
+
+    # Output some test data for use in the test
+    print(f"Test input data: {test_input}")
+    output = model.forward(test_input)
+    print(f"Test output data: {output}")
+
+
+if __name__ == "__main__":
+    main()
onnx-tests integration tests:

@@ -17,6 +17,7 @@ include_models!(
     add,
     avg_pool2d,
     batch_norm,
+    cast,
     clip_opset16,
     clip_opset7,
     concat,
@@ -40,6 +41,7 @@ include_models!(
     maxpool2d,
     mul,
     neg,
+    not,
     recip,
     reduce_mean,
     relu,
@@ -64,7 +66,7 @@ mod tests {

     use super::*;

-    use burn::tensor::{Data, Int, Shape, Tensor};
+    use burn::tensor::{Bool, Data, Int, Shape, Tensor};

     use float_cmp::ApproxEq;

@@ -854,6 +856,22 @@ mod tests {
         assert_eq!(output2, expected2);
     }

+    #[test]
+    fn not() {
+        let device = Default::default();
+        let model: not::Model<Backend> = not::Model::new(&device);
+
+        let input = Tensor::<Backend, 4, Bool>::from_bool(
+            Data::from([[[[true, false, true, false]]]]),
+            &device,
+        );
+
+        let output = model.forward(input).to_data();
+        let expected = Data::from([[[[false, true, false, true]]]]);
+
+        assert_eq!(output, expected);
+    }
+
     #[test]
     fn test_model_creation_with_a_default_device() {
         let device = Default::default();
@@ -908,4 +926,52 @@ mod tests {
         let output = model.forward(input);
         assert_eq!(output.shape(), expected_shape);
     }
+
+    #[test]
+    fn cast() {
+        let device = Default::default();
+        let model: cast::Model<Backend> = cast::Model::new(&device);
+
+        let input_bool =
+            Tensor::<Backend, 2, Bool>::from_bool(Data::from([[true], [true]]), &device);
+        let input_int = Tensor::<Backend, 2, Int>::from_ints([[1], [1]], &device);
+        let input_float = Tensor::<Backend, 2>::from_floats([[1.], [1.]], &device);
+        let input_scalar = 1f32;
+
+        let (
+            output1,
+            output2,
+            output3,
+            output4,
+            output5,
+            output6,
+            output7,
+            output8,
+            output9,
+            output_scalar,
+        ) = model.forward(
+            input_bool.clone(),
+            input_int.clone(),
+            input_float.clone(),
+            input_scalar,
+        );
+        let expected_bool = input_bool.to_data();
+        let expected_int = input_int.to_data();
+        let expected_float = input_float.to_data();
+        let expected_scalar = 1;
+
+        assert_eq!(output1.to_data(), expected_bool);
+        assert_eq!(output2.to_data(), expected_int);
+        output3.to_data().assert_approx_eq(&expected_float, 4);
+
+        assert_eq!(output4.to_data(), expected_bool);
+        assert_eq!(output5.to_data(), expected_int);
+        output6.to_data().assert_approx_eq(&expected_float, 4);
+
+        assert_eq!(output7.to_data(), expected_bool);
+        assert_eq!(output8.to_data(), expected_int);
+        output9.to_data().assert_approx_eq(&expected_float, 4);
+
+        assert_eq!(output_scalar, expected_scalar);
+    }
 }
burn-import unary node codegen (UnaryNode):

@@ -1,5 +1,5 @@
 use super::{Node, NodeCodegen};
-use crate::burn::{BurnImports, Scope, ToTokens, Type};
+use crate::burn::{BurnImports, Scope, TensorKind, ToTokens, Type};
 use burn::record::PrecisionSettings;
 use proc_macro2::TokenStream;
 use quote::quote;
@@ -20,7 +20,8 @@ pub struct UnaryNode {
 /// Type of unary node.
 #[derive(Clone)]
 pub enum UnaryNodeKind {
-    Cast,
+    // Input and output tensor types (required for codegen imports)
+    Cast(Option<TensorKind>, Option<TensorKind>),
     Cos,
     Erf,
     Exp,
@@ -29,6 +30,7 @@ pub enum UnaryNodeKind {
     Log,
     LogSoftmax,
     Neg,
+    Not,
     ReduceMean,
     Reciprocal,
     LeakyRelu,
@@ -44,7 +46,7 @@ pub enum UnaryNodeKind {
 impl UnaryNodeKind {
     pub fn as_str(&self) -> &str {
         match self {
-            Self::Cast => "cast",
+            Self::Cast(..) => "cast",
             Self::Cos => "cos",
             Self::Erf => "erf",
             Self::Exp => "exp",
@@ -53,6 +55,7 @@ impl UnaryNodeKind {
             Self::Log => "log",
             Self::LogSoftmax => "log_softmax",
             Self::Neg => "neg",
+            Self::Not => "not",
             Self::ReduceMean => "reduce_mean",
             Self::Reciprocal => "reciprocal",
             Self::LeakyRelu => "leaky_relu",
@@ -120,6 +123,17 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
             UnaryNodeKind::Neg => {
                 imports.register("core::ops::Neg");
             }
+            UnaryNodeKind::Not => {
+                imports.register("burn::tensor::Bool");
+            }
+            UnaryNodeKind::Cast(Some(input_kind), Some(output_kind)) => {
+                if input_kind == TensorKind::Bool || output_kind == TensorKind::Bool {
+                    imports.register("burn::tensor::Bool");
+                }
+                if input_kind == TensorKind::Int || output_kind == TensorKind::Int {
+                    imports.register("burn::tensor::Int");
+                }
+            }
             _ => {}
         }
     }
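Why Cast now carries its tensor kinds: the generated signature can name Bool or Int, and the matching use items must be registered or the emitted file will not compile. A compilable sketch of the code shape a Float -> Int cast produces (the function name and NdArray backend are illustrative):

    use burn::tensor::backend::Backend;
    use burn::tensor::{Int, Tensor};
    use burn_ndarray::NdArray;

    // `Int` appears in the signature, which is why the Cast arm above
    // registers the burn::tensor::Int import.
    fn cast_like_generated<B: Backend>(tensor1: Tensor<B, 4>) -> Tensor<B, 4, Int> {
        tensor1.int()
    }

    fn main() {
        let device = Default::default();
        let t = Tensor::<NdArray, 4>::ones([1, 1, 2, 2], &device);
        println!("{}", cast_like_generated(t));
    }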
@@ -217,42 +231,61 @@ impl UnaryNode {
         Self::new(input, output, UnaryNodeKind::Neg, Rc::new(function))
     }

+    pub(crate) fn not(input: Type, output: Type) -> Self {
+        // Not ONNX operator is constrained to bool tensors, so no need to check the type.
+        let function = move |input| quote! { #input.bool_not()};
+        Self::new(input, output, UnaryNodeKind::Not, Rc::new(function))
+    }
+
     /// Casts the input to the output type.
-    ///
-    /// Currently this function only supports the following conversions:
-    ///     1) scalar -> scalar
-    ///
-    /// TODO: Implement the following conversions:
-    ///     2) tensor int -> tensor float
-    ///     3) tensor float -> tensor int
-    ///     4) tensor -> scalar
-    ///     5) scalar -> tensor
     pub(crate) fn cast(input: Type, output: Type) -> Self {
         match (input.clone(), output.clone()) {
             (Type::Scalar(input_scalar), Type::Scalar(output_scalar)) => {
                 if input_scalar.kind == output_scalar.kind {
                     // If the input and output types are the same, we don't need to cast.
-                    Self::new(input, output, UnaryNodeKind::Cast, Rc::new(|input| input))
+                    Self::new(
+                        input,
+                        output,
+                        UnaryNodeKind::Cast(None, None),
+                        Rc::new(|input| input),
+                    )
                 } else {
                     // If the input and output types are different, we need to cast.
                     let ty = output_scalar.ty();
                     Self::new(
                         input,
                         output,
-                        UnaryNodeKind::Cast,
+                        UnaryNodeKind::Cast(None, None),
                         Rc::new(move |input| quote! { #input as #ty }),
                     )
                 }
             }
-            (Type::Tensor(_input_tensor), Type::Tensor(_output_tensor)) => {
-                // TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023)
-                // TODO: If the input is scalar and the output type is a tensor,
-                // we should generate another code block. (@antimora 8/4/2023)
-                // Tensor::from_data(Data::from([#input]).convert()).unsqueeze();
-                todo!()
+            (Type::Tensor(input_tensor), Type::Tensor(output_tensor)) => {
+                if input_tensor.kind == output_tensor.kind {
+                    // If the input and output types are the same, we don't need to cast.
+                    Self::new(
+                        input,
+                        output,
+                        UnaryNodeKind::Cast(Some(input_tensor.kind), Some(output_tensor.kind)),
+                        Rc::new(|input| input),
+                    )
+                } else {
+                    // If the input and output types are different, we need to cast.
+                    let function = match output_tensor.kind {
+                        TensorKind::Bool => move |input| quote! { #input.bool()},
+                        TensorKind::Int => move |input| quote! { #input.int()},
+                        TensorKind::Float => move |input| quote! { #input.float()},
+                    };
+
+                    Self::new(
+                        input,
+                        output,
+                        UnaryNodeKind::Cast(Some(input_tensor.kind), Some(output_tensor.kind)),
+                        Rc::new(function),
+                    )
+                }
             }
-            _ => panic!("output must be a tensor"),
+            _ => panic!("output must be a tensor or scalar"),
         }
     }
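The constructors above store closures that splice the input identifier into a method call. A standalone sketch of that token-stream pattern, assuming only the proc-macro2 and quote crates:

    use proc_macro2::TokenStream;
    use quote::quote;

    fn main() {
        // Same pattern as UnaryNode::cast: choose a conversion method, then
        // wrap the input tokens in a call to it.
        let cast_fn = |input: TokenStream| quote! { #input.int() };
        let tokens = cast_fn(quote! { tensor1 });
        // Token streams print with token-level spacing, e.g. `tensor1 . int ()`.
        println!("{tokens}");
    }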
@@ -553,6 +586,51 @@ mod tests {
             vec!["scalar1".to_string()],
             vec!["scalar2".to_string()],
         );
+        one_node_graph(
+            UnaryNode::cast(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_int("tensor2", 4)),
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4, Int> {
+                    let tensor2 = tensor1.int();
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+        one_node_graph(
+            UnaryNode::cast(
+                Type::Tensor(TensorType::new_int("tensor1", 4)),
+                Type::Tensor(TensorType::new_float("tensor2", 4)),
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4, Int>) -> Tensor<B, 4> {
+                    let tensor2 = tensor1.float();
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+        one_node_graph(
+            UnaryNode::cast(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_bool("tensor2", 4)),
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4, Bool> {
+                    let tensor2 = tensor1.bool();
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
     }

     #[test]
@@ -687,4 +765,23 @@ mod tests {
             vec!["tensor2".to_string()],
         );
     }
+
+    #[test]
+    fn test_unary_codegen_not() {
+        one_node_graph(
+            UnaryNode::not(
+                Type::Tensor(TensorType::new_bool("tensor1", 4)),
+                Type::Tensor(TensorType::new_bool("tensor2", 4)),
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4, Bool>) -> Tensor<B, 4, Bool> {
+                    let tensor2 = tensor1.bool_not();
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+    }
 }
burn-import type definitions (TensorKind):

@@ -13,7 +13,7 @@ pub struct TensorType {
     pub shape: Option<Vec<usize>>,
 }

-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum TensorKind {
     Int,
     Float,
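The added derives matter because the new Cast codegen compares kinds with ==. A minimal reproduction with a mirrored enum (a hypothetical stand-in, not the real TensorKind):

    // Mirror of TensorKind with the new derives; equality is what the Cast
    // constructor and import registration rely on.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum TensorKind {
        Int,
        Float,
        Bool,
    }

    fn main() {
        assert!(TensorKind::Int == TensorKind::Int);
        assert!(TensorKind::Int != TensorKind::Bool);
        println!("kind comparisons behave as expected");
    }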
burn-import dim_inference:

@@ -38,6 +38,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::MaxPool2d => same_as_input(node),
         NodeType::Mul => same_as_input(node),
         NodeType::Neg => same_as_input(node),
+        NodeType::Not => same_as_input(node),
         NodeType::Reciprocal => same_as_input(node),
         NodeType::ReduceMean => reduce_mean_update_outputs(node),
         NodeType::Relu => same_as_input(node),
@@ -135,6 +136,7 @@ fn cast_update_outputs(node: &mut Node) {
     if node.inputs.len() != 1 {
         panic!("Cast: multiple inputs are not supported");
     }
+    let input = &mut node.inputs[0];
     let output = &mut node.outputs[0];

     // Extract cast type and update the output tensor
@@ -145,6 +147,7 @@ fn cast_update_outputs(node: &mut Node) {
             DataType::INT32 => ElementType::Int32,
             DataType::INT64 => ElementType::Int64,
             DataType::DOUBLE => ElementType::Float64,
+            DataType::BOOL => ElementType::Bool,
             _ => panic!("Cast: unsupported type"),
         },
         _ => panic!("'to' attribute must be an Int64"),
@@ -152,19 +155,25 @@ fn cast_update_outputs(node: &mut Node) {
         None => panic!("Constant node must have a value attribute"),
     };

-    match output.ty.clone() {
+    match input.ty.clone() {
         ArgType::Tensor(tensor) => {
             if tensor.dim == 0 {
                 // treat 0-dim tensor as scalar
                 output.ty = ArgType::Scalar(elem_type);
+                input.ty = ArgType::Scalar(tensor.elem_type);
             } else {
-                todo!("Cast: support casting from different tensor types");
+                // Cast input and output are the same shape, but possibly different types
+                output.ty = ArgType::Tensor(TensorType {
+                    elem_type,
+                    dim: tensor.dim,
+                    shape: tensor.shape.clone(),
+                });
             }
         }
         ArgType::Scalar(_scalar) => {
             output.ty = ArgType::Scalar(elem_type);
         }
-        _ => panic!("Cast: only scalar input is valid"),
+        _ => panic!("Cast: only scalar and tensor inputs are valid"),
     }
 }
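The new rule in short: a rank-0 Cast input collapses to a scalar on both the input and output side, while any higher-rank input keeps its rank and shape and only swaps the element type. A self-contained sketch with stand-in types (hypothetical, not burn-import's real ArgType):

    #[derive(Clone, Debug, PartialEq)]
    enum ArgTy {
        Scalar(&'static str),
        Tensor { elem: &'static str, dim: usize },
    }

    // Stand-in for cast_update_outputs: rank-0 tensors collapse to scalars,
    // anything else only changes element type.
    fn cast_output(input: &ArgTy, to_elem: &'static str) -> ArgTy {
        match input {
            ArgTy::Tensor { dim: 0, .. } | ArgTy::Scalar(_) => ArgTy::Scalar(to_elem),
            ArgTy::Tensor { dim, .. } => ArgTy::Tensor { elem: to_elem, dim: *dim },
        }
    }

    fn main() {
        let zero_dim = ArgTy::Tensor { elem: "f32", dim: 0 };
        assert_eq!(cast_output(&zero_dim, "i64"), ArgTy::Scalar("i64"));

        let matrix = ArgTy::Tensor { elem: "f32", dim: 2 };
        assert_eq!(
            cast_output(&matrix, "bool"),
            ArgTy::Tensor { elem: "bool", dim: 2 }
        );
    }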
burn-import graph conversion (OnnxGraph):

@@ -237,6 +237,7 @@ impl OnnxGraph {
             NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
             NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
             NodeType::Neg => graph.register(Self::neg_conversion(node)),
+            NodeType::Not => graph.register(Self::not_conversion(node)),
             NodeType::Linear => graph.register(Self::linear_conversion::<PS>(node)),
             NodeType::BatchNormalization => {
                 graph.register(Self::batch_norm_conversion::<PS>(node))
@@ -697,6 +698,13 @@ impl OnnxGraph {
         let output = node.outputs.first().unwrap().to_type();
         UnaryNode::neg(input, output)
     }

+    fn not_conversion(node: Node) -> UnaryNode {
+        let input = node.inputs.first().unwrap().to_type();
+        let output = node.outputs.first().unwrap().to_type();
+        UnaryNode::not(input, output)
+    }
+
     fn pow_conversion(node: Node) -> BinaryNode {
         let lhs = node.inputs.first().unwrap().to_type();
         let rhs = node.inputs.get(1).unwrap().to_type();
Cargo.toml (another crate):

@@ -7,7 +7,7 @@ license = "MIT OR Apache-2.0"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

 [dependencies]
-anyhow = "1.0.82"
+anyhow = "1.0.81"
 clap = { version = "4.5.4", features = ["derive"] }
 derive_more = { version = "0.99.17", features = ["display"], default-features = false }
 env_logger = "0.11.3"