[ONNX] Add not op and extend cast support to tensors (#1634)

* Add not onnx op support

* Extend cast onnx support to tensors

* Fix clippy
This commit is contained in:
Guillaume Lagrange 2024-04-16 08:45:25 -04:00 committed by GitHub
parent 7377bbe31c
commit 6d96e8d808
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 340 additions and 34 deletions

8
Cargo.lock generated
View File

@ -126,9 +126,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.82"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
[[package]]
name = "arboard"
@ -3517,9 +3517,9 @@ dependencies = [
[[package]]
name = "quote"
version = "1.0.36"
version = "1.0.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
dependencies = [
"proc-macro2",
]

View File

@ -47,7 +47,7 @@ pretty_assertions = "1.4"
proc-macro2 = "1.0.79"
protobuf = "3.3"
protobuf-codegen = "3.3"
quote = "1.0.36"
quote = "1.0.33"
percent-encoding = "2.3.1"
r2d2 = "0.8.10"
r2d2_sqlite = { version = "0.23.0" }

View File

@ -118,7 +118,7 @@ represent the corresponding Burn Op.
| [NegativeLogLikelihoodLoss][110] | ❌ | ❌ |
| [NonMaxSuppression][112] | ❌ | ❌ |
| [NonZero][113] | ❌ | ❌ |
| [Not][114] | | ✅ |
| [Not][114] | | ✅ |
| [OneHot][115] | ❌ | ✅ |
| [Optional][116] | ❌ | ❌ |
| [OptionalGetElement][117] | ❌ | ❌ |

View File

@ -10,6 +10,7 @@ fn main() {
.input("tests/add/add.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/cast/cast.onnx")
.input("tests/clip/clip_opset16.onnx")
.input("tests/clip/clip_opset7.onnx")
.input("tests/concat/concat.onnx")
@ -32,6 +33,7 @@ fn main() {
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/leaky_relu/leaky_relu.onnx")

Binary file not shown.

View File

@ -0,0 +1,64 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/cast/cast.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(
self,
x_bool,
x_int,
x_float,
x_scalar,
):
# NOTE: we clone same-type casts for int and bool, otherwise the exporter would
# link other type casts to the output of the bool cast, leading to additional casts
return (
x_bool.clone().bool(),
x_bool.int(),
x_bool.float(),
x_int.bool(),
x_int.clone().int(),
x_int.float(),
x_float.bool(),
x_float.int(),
x_float.float(),
x_scalar.int(),
)
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "cast.onnx"
test_bool = torch.ones((2, 1), device=device, dtype=torch.bool)
test_int = torch.ones((2, 1), device=device, dtype=torch.int)
test_float = torch.ones((2, 1), device=device, dtype=torch.float)
test_scalar = torch.ones(1, device=device, dtype=torch.float).squeeze()
test_input = (test_bool, test_int, test_float, test_scalar)
# NOTE: torch exports logical_not with a cast node even if the input is already bool
# https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L2204-L2207
torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)
print(f"Finished exporting model to {onnx_name}")
# Output some test data for use in the test
print(f"Test input data: {test_input}")
output = model.forward(*test_input)
print(f"Test output data: {output}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,19 @@
pytorch2.1.2:©
6
onnx::Cast_0/Cast_output_0/Cast"Cast*
to  

/Cast_output_02/Not"Not
main_graphZ&
onnx::Cast_0
 



b
2
 



B

View File

@ -0,0 +1,41 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/not/not.onnx
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return torch.logical_not(x)
def main():
# Set random seed for reproducibility
torch.manual_seed(0)
# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "not.onnx"
test_input = torch.tensor([[[[True, False, True, False]]]], device=device)
# NOTE: torch exports logical_not with a cast node even if the input is already bool
# https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L2204-L2207
torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)
print(f"Finished exporting model to {onnx_name}")
# Output some test data for use in the test
print(f"Test input data: {test_input}")
output = model.forward(test_input)
print(f"Test output data: {output}")
if __name__ == "__main__":
main()

View File

@ -17,6 +17,7 @@ include_models!(
add,
avg_pool2d,
batch_norm,
cast,
clip_opset16,
clip_opset7,
concat,
@ -40,6 +41,7 @@ include_models!(
maxpool2d,
mul,
neg,
not,
recip,
reduce_mean,
relu,
@ -64,7 +66,7 @@ mod tests {
use super::*;
use burn::tensor::{Data, Int, Shape, Tensor};
use burn::tensor::{Bool, Data, Int, Shape, Tensor};
use float_cmp::ApproxEq;
@ -854,6 +856,22 @@ mod tests {
assert_eq!(output2, expected2);
}
#[test]
fn not() {
let device = Default::default();
let model: not::Model<Backend> = not::Model::new(&device);
let input = Tensor::<Backend, 4, Bool>::from_bool(
Data::from([[[[true, false, true, false]]]]),
&device,
);
let output = model.forward(input).to_data();
let expected = Data::from([[[[false, true, false, true]]]]);
assert_eq!(output, expected);
}
#[test]
fn test_model_creation_with_a_default_device() {
let device = Default::default();
@ -908,4 +926,52 @@ mod tests {
let output = model.forward(input);
assert_eq!(output.shape(), expected_shape);
}
#[test]
fn cast() {
let device = Default::default();
let model: cast::Model<Backend> = cast::Model::new(&device);
let input_bool =
Tensor::<Backend, 2, Bool>::from_bool(Data::from([[true], [true]]), &device);
let input_int = Tensor::<Backend, 2, Int>::from_ints([[1], [1]], &device);
let input_float = Tensor::<Backend, 2>::from_floats([[1.], [1.]], &device);
let input_scalar = 1f32;
let (
output1,
output2,
output3,
output4,
output5,
output6,
output7,
output8,
output9,
output_scalar,
) = model.forward(
input_bool.clone(),
input_int.clone(),
input_float.clone(),
input_scalar,
);
let expected_bool = input_bool.to_data();
let expected_int = input_int.to_data();
let expected_float = input_float.to_data();
let expected_scalar = 1;
assert_eq!(output1.to_data(), expected_bool);
assert_eq!(output2.to_data(), expected_int);
output3.to_data().assert_approx_eq(&expected_float, 4);
assert_eq!(output4.to_data(), expected_bool);
assert_eq!(output5.to_data(), expected_int);
output6.to_data().assert_approx_eq(&expected_float, 4);
assert_eq!(output7.to_data(), expected_bool);
assert_eq!(output8.to_data(), expected_int);
output9.to_data().assert_approx_eq(&expected_float, 4);
assert_eq!(output_scalar, expected_scalar);
}
}

View File

@ -1,5 +1,5 @@
use super::{Node, NodeCodegen};
use crate::burn::{BurnImports, Scope, ToTokens, Type};
use crate::burn::{BurnImports, Scope, TensorKind, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
@ -20,7 +20,8 @@ pub struct UnaryNode {
/// Type of unary node.
#[derive(Clone)]
pub enum UnaryNodeKind {
Cast,
// Input and output tensor types (required for codegen imports)
Cast(Option<TensorKind>, Option<TensorKind>),
Cos,
Erf,
Exp,
@ -29,6 +30,7 @@ pub enum UnaryNodeKind {
Log,
LogSoftmax,
Neg,
Not,
ReduceMean,
Reciprocal,
LeakyRelu,
@ -44,7 +46,7 @@ pub enum UnaryNodeKind {
impl UnaryNodeKind {
pub fn as_str(&self) -> &str {
match self {
Self::Cast => "cast",
Self::Cast(..) => "cast",
Self::Cos => "cos",
Self::Erf => "erf",
Self::Exp => "exp",
@ -53,6 +55,7 @@ impl UnaryNodeKind {
Self::Log => "log",
Self::LogSoftmax => "log_softmax",
Self::Neg => "neg",
Self::Not => "not",
Self::ReduceMean => "reduce_mean",
Self::Reciprocal => "reciprocal",
Self::LeakyRelu => "leaky_relu",
@ -120,6 +123,17 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
UnaryNodeKind::Neg => {
imports.register("core::ops::Neg");
}
UnaryNodeKind::Not => {
imports.register("burn::tensor::Bool");
}
UnaryNodeKind::Cast(Some(input_kind), Some(output_kind)) => {
if input_kind == TensorKind::Bool || output_kind == TensorKind::Bool {
imports.register("burn::tensor::Bool");
}
if input_kind == TensorKind::Int || output_kind == TensorKind::Int {
imports.register("burn::tensor::Int");
}
}
_ => {}
}
}
@ -217,42 +231,61 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::Neg, Rc::new(function))
}
pub(crate) fn not(input: Type, output: Type) -> Self {
// Not ONNX operator is constrained to bool tensors, so no need to check the type.
let function = move |input| quote! { #input.bool_not()};
Self::new(input, output, UnaryNodeKind::Not, Rc::new(function))
}
/// Casts the input to the output type.
///
/// Currently this function only supports the following conversions:
/// 1) scalar -> scalar
///
/// TODO: Implement the following conversions:
/// 2) tensor int -> tensor float
/// 3) tensor float -> tensor int
/// 4) tensor -> scalar
/// 5) scalar -> tensor
pub(crate) fn cast(input: Type, output: Type) -> Self {
match (input.clone(), output.clone()) {
(Type::Scalar(input_scalar), Type::Scalar(output_scalar)) => {
if input_scalar.kind == output_scalar.kind {
// If the input and output types are the same, we don't need to cast.
Self::new(input, output, UnaryNodeKind::Cast, Rc::new(|input| input))
Self::new(
input,
output,
UnaryNodeKind::Cast(None, None),
Rc::new(|input| input),
)
} else {
// If the input and output types are different, we need to cast.
let ty = output_scalar.ty();
Self::new(
input,
output,
UnaryNodeKind::Cast,
UnaryNodeKind::Cast(None, None),
Rc::new(move |input| quote! { #input as #ty }),
)
}
}
(Type::Tensor(_input_tensor), Type::Tensor(_output_tensor)) => {
// TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023)
// TODO: If the input is scalar and the output type is a tensor,
// we should generate another code block. (@antimora 8/4/2023)
// Tensor::from_data(Data::from([#input]).convert()).unsqueeze();
todo!()
}
(Type::Tensor(input_tensor), Type::Tensor(output_tensor)) => {
if input_tensor.kind == output_tensor.kind {
// If the input and output types are the same, we don't need to cast.
Self::new(
input,
output,
UnaryNodeKind::Cast(Some(input_tensor.kind), Some(output_tensor.kind)),
Rc::new(|input| input),
)
} else {
// If the input and output types are different, we need to cast.
let function = match output_tensor.kind {
TensorKind::Bool => move |input| quote! { #input.bool()},
TensorKind::Int => move |input| quote! { #input.int()},
TensorKind::Float => move |input| quote! { #input.float()},
};
_ => panic!("output must be a tensor"),
Self::new(
input,
output,
UnaryNodeKind::Cast(Some(input_tensor.kind), Some(output_tensor.kind)),
Rc::new(function),
)
}
}
_ => panic!("output must be a tensor or scalar"),
}
}
@ -553,6 +586,51 @@ mod tests {
vec!["scalar1".to_string()],
vec!["scalar2".to_string()],
);
one_node_graph(
UnaryNode::cast(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_int("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4, Int> {
let tensor2 = tensor1.int();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
one_node_graph(
UnaryNode::cast(
Type::Tensor(TensorType::new_int("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4, Int>) -> Tensor<B, 4> {
let tensor2 = tensor1.float();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
one_node_graph(
UnaryNode::cast(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_bool("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4, Bool> {
let tensor2 = tensor1.bool();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
#[test]
@ -687,4 +765,23 @@ mod tests {
vec!["tensor2".to_string()],
);
}
#[test]
fn test_unary_codegen_not() {
one_node_graph(
UnaryNode::not(
Type::Tensor(TensorType::new_bool("tensor1", 4)),
Type::Tensor(TensorType::new_bool("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4, Bool>) -> Tensor<B, 4, Bool> {
let tensor2 = tensor1.bool_not();
tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
}

View File

@ -13,7 +13,7 @@ pub struct TensorType {
pub shape: Option<Vec<usize>>,
}
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorKind {
Int,
Float,

View File

@ -38,6 +38,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::MaxPool2d => same_as_input(node),
NodeType::Mul => same_as_input(node),
NodeType::Neg => same_as_input(node),
NodeType::Not => same_as_input(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::ReduceMean => reduce_mean_update_outputs(node),
NodeType::Relu => same_as_input(node),
@ -135,6 +136,7 @@ fn cast_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("Cast: multiple inputs are not supported");
}
let input = &mut node.inputs[0];
let output = &mut node.outputs[0];
// Extract cast type and update the output tensor
@ -145,6 +147,7 @@ fn cast_update_outputs(node: &mut Node) {
DataType::INT32 => ElementType::Int32,
DataType::INT64 => ElementType::Int64,
DataType::DOUBLE => ElementType::Float64,
DataType::BOOL => ElementType::Bool,
_ => panic!("Cast: unsupported type"),
},
_ => panic!("'to' attribute must be an Int64"),
@ -152,19 +155,25 @@ fn cast_update_outputs(node: &mut Node) {
None => panic!("Constant node must have a value attribute"),
};
match output.ty.clone() {
match input.ty.clone() {
ArgType::Tensor(tensor) => {
if tensor.dim == 0 {
// treat 0-dim tensor as scalar
output.ty = ArgType::Scalar(elem_type);
input.ty = ArgType::Scalar(tensor.elem_type);
} else {
todo!("Cast: support casting from different tensor types");
// Cast input and output are the same shape, but possibly different types
output.ty = ArgType::Tensor(TensorType {
elem_type,
dim: tensor.dim,
shape: tensor.shape.clone(),
});
}
}
ArgType::Scalar(_scalar) => {
output.ty = ArgType::Scalar(elem_type);
}
_ => panic!("Cast: only scalar input is valid"),
_ => panic!("Cast: only scalar and tensor inputs are valid"),
}
}

View File

@ -237,6 +237,7 @@ impl OnnxGraph {
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
NodeType::Neg => graph.register(Self::neg_conversion(node)),
NodeType::Not => graph.register(Self::not_conversion(node)),
NodeType::Linear => graph.register(Self::linear_conversion::<PS>(node)),
NodeType::BatchNormalization => {
graph.register(Self::batch_norm_conversion::<PS>(node))
@ -697,6 +698,13 @@ impl OnnxGraph {
let output = node.outputs.first().unwrap().to_type();
UnaryNode::neg(input, output)
}
fn not_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
UnaryNode::not(input, output)
}
fn pow_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();

View File

@ -7,7 +7,7 @@ license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.82"
anyhow = "1.0.81"
clap = { version = "4.5.4", features = ["derive"] }
derive_more = { version = "0.99.17", features = ["display"], default-features = false }
env_logger = "0.11.3"