mirror of https://github.com/tracel-ai/burn.git
Allow ONNX scalar greater/less with scalar (#2146)
This commit is contained in:
parent
e75eebfc31
commit
12caca7909
|
@ -38,12 +38,16 @@ fn main() {
|
|||
.input("tests/gelu/gelu.onnx")
|
||||
.input("tests/global_avr_pool/global_avr_pool.onnx")
|
||||
.input("tests/greater/greater.onnx")
|
||||
.input("tests/greater/greater_scalar.onnx")
|
||||
.input("tests/greater_or_equal/greater_or_equal.onnx")
|
||||
.input("tests/greater_or_equal/greater_or_equal_scalar.onnx")
|
||||
.input("tests/hard_sigmoid/hard_sigmoid.onnx")
|
||||
.input("tests/layer_norm/layer_norm.onnx")
|
||||
.input("tests/leaky_relu/leaky_relu.onnx")
|
||||
.input("tests/less/less.onnx")
|
||||
.input("tests/less/less_scalar.onnx")
|
||||
.input("tests/less_or_equal/less_or_equal.onnx")
|
||||
.input("tests/less_or_equal/less_or_equal_scalar.onnx")
|
||||
.input("tests/linear/linear.onnx")
|
||||
.input("tests/log/log.onnx")
|
||||
.input("tests/log_softmax/log_softmax.onnx")
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/greater/greater.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x,y)
|
||||
|
||||
def main():
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.set_printoptions(precision=8)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
onnx_name = "greater_scalar.onnx"
|
||||
|
||||
test_input1 = torch.randn(4, 4, device=device)
|
||||
test_input2 = torch.tensor(1.0)
|
||||
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
print("Test input data: {} {}".format(test_input1, test_input2))
|
||||
output = model.forward(test_input1, test_input2)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/less_or_equal/less_or_equal.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.ge(x,y)
|
||||
|
||||
def main():
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.set_printoptions(precision=8)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
onnx_name = "greater_or_equal_scalar.onnx"
|
||||
|
||||
test_input1 = torch.randn(4, 4, device=device)
|
||||
test_input2 = torch.tensor(1.0)
|
||||
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
print("Test input data: {} {}".format(test_input1, test_input2))
|
||||
output = model.forward(test_input1, test_input2)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/less/less.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x,y)
|
||||
|
||||
def main():
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.set_printoptions(precision=8)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
onnx_name = "less_scalar.onnx"
|
||||
|
||||
test_input1 = torch.randn(4, 4, device=device)
|
||||
test_input2 = torch.tensor(1.0)
|
||||
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
print("Test input data: {} {}".format(test_input1, test_input2))
|
||||
output = model.forward(test_input1, test_input2)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/less_or_equal/less_or_equal.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.le(x,y)
|
||||
|
||||
def main():
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.set_printoptions(precision=8)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
onnx_name = "less_or_equal_scalar.onnx"
|
||||
|
||||
test_input1 = torch.randn(4, 4, device=device)
|
||||
test_input2 = torch.tensor(1.0)
|
||||
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
print("Test input data: {} {}".format(test_input1, test_input2))
|
||||
output = model.forward(test_input1, test_input2)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -47,12 +47,16 @@ include_models!(
|
|||
gelu,
|
||||
global_avr_pool,
|
||||
greater,
|
||||
greater_scalar,
|
||||
greater_or_equal,
|
||||
greater_or_equal_scalar,
|
||||
hard_sigmoid,
|
||||
layer_norm,
|
||||
leaky_relu,
|
||||
less,
|
||||
less_scalar,
|
||||
less_or_equal,
|
||||
less_or_equal_scalar,
|
||||
linear,
|
||||
log,
|
||||
log_softmax,
|
||||
|
@ -1642,6 +1646,20 @@ mod tests {
|
|||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn greater_scalar() {
|
||||
let device = Default::default();
|
||||
let model: greater_scalar::Model<Backend> = greater_scalar::Model::new(&device);
|
||||
|
||||
let input1 = Tensor::<Backend, 2>::from_floats([[1.0, 4.0, 9.0, 0.5]], &device);
|
||||
let input2 = 1.0;
|
||||
|
||||
let output = model.forward(input1, input2);
|
||||
let expected = TensorData::from([[false, true, true, false]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn less() {
|
||||
let device = Default::default();
|
||||
|
@ -1656,6 +1674,20 @@ mod tests {
|
|||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn less_scalar() {
|
||||
let device = Default::default();
|
||||
let model: less_scalar::Model<Backend> = less_scalar::Model::new(&device);
|
||||
|
||||
let input1 = Tensor::<Backend, 2>::from_floats([[1.0, 4.0, 9.0, 0.5]], &device);
|
||||
let input2 = 1.0;
|
||||
|
||||
let output = model.forward(input1, input2);
|
||||
let expected = TensorData::from([[false, false, false, true]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn greater_or_equal() {
|
||||
let device = Default::default();
|
||||
|
@ -1670,6 +1702,21 @@ mod tests {
|
|||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn greater_or_equal_scalar() {
|
||||
let device = Default::default();
|
||||
let model: greater_or_equal_scalar::Model<Backend> =
|
||||
greater_or_equal_scalar::Model::new(&device);
|
||||
|
||||
let input1 = Tensor::<Backend, 2>::from_floats([[1.0, 4.0, 9.0, 0.5]], &device);
|
||||
let input2 = 1.0;
|
||||
|
||||
let output = model.forward(input1, input2);
|
||||
let expected = TensorData::from([[true, true, true, false]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn less_or_equal() {
|
||||
let device = Default::default();
|
||||
|
@ -1684,6 +1731,20 @@ mod tests {
|
|||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn less_or_equal_scalar() {
|
||||
let device = Default::default();
|
||||
let model: less_or_equal_scalar::Model<Backend> = less_or_equal_scalar::Model::new(&device);
|
||||
|
||||
let input1 = Tensor::<Backend, 2>::from_floats([[1.0, 4.0, 9.0, 0.5]], &device);
|
||||
let input2 = 1.0;
|
||||
|
||||
let output = model.forward(input1, input2);
|
||||
let expected = TensorData::from([[true, false, false, true]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_creation_with_a_default_device() {
|
||||
let device = Default::default();
|
||||
|
|
|
@ -206,7 +206,14 @@ impl BinaryNode {
|
|||
pub(crate) fn greater(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = match (&lhs, &rhs) {
|
||||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.greater(#rhs) },
|
||||
_ => panic!("greater is supported for tensor only"),
|
||||
(Type::Tensor(_), Type::Scalar(_)) => {
|
||||
move |lhs, rhs| quote! { #lhs.greater_elem(#rhs) }
|
||||
}
|
||||
(Type::Scalar(_), Type::Tensor(_)) => {
|
||||
// L > R == R < L
|
||||
move |lhs, rhs| quote! { #rhs.lower_elem(#lhs) }
|
||||
}
|
||||
(lhs, rhs) => panic!("greater is not supported for {lhs:?} > {rhs:?}"),
|
||||
};
|
||||
Self::new(lhs, rhs, output, BinaryType::Greater, Arc::new(function))
|
||||
}
|
||||
|
@ -216,7 +223,14 @@ impl BinaryNode {
|
|||
(Type::Tensor(_), Type::Tensor(_)) => {
|
||||
move |lhs, rhs| quote! { #lhs.greater_equal(#rhs) }
|
||||
}
|
||||
_ => panic!("greater_equal is supported for tensor only"),
|
||||
(Type::Tensor(_), Type::Scalar(_)) => {
|
||||
move |lhs, rhs| quote! { #lhs.greater_equal_elem(#rhs) }
|
||||
}
|
||||
(Type::Scalar(_), Type::Tensor(_)) => {
|
||||
// L >= R == R <= L
|
||||
move |lhs, rhs| quote! { #rhs.lower_equal_elem(#lhs) }
|
||||
}
|
||||
(lhs, rhs) => panic!("greater_equal is not supported for {lhs:?} > {rhs:?}"),
|
||||
};
|
||||
Self::new(
|
||||
lhs,
|
||||
|
@ -230,7 +244,12 @@ impl BinaryNode {
|
|||
pub(crate) fn lower(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = match (&lhs, &rhs) {
|
||||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.lower(#rhs) },
|
||||
_ => panic!("lower is supported for tensor only"),
|
||||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.lower_elem(#rhs) },
|
||||
(Type::Scalar(_), Type::Tensor(_)) => {
|
||||
// L < R == R > L
|
||||
move |lhs, rhs| quote! { #rhs.greater_elem(#lhs) }
|
||||
}
|
||||
(lhs, rhs) => panic!("lower is not supported for {lhs:?} > {rhs:?}"),
|
||||
};
|
||||
Self::new(lhs, rhs, output, BinaryType::Less, Arc::new(function))
|
||||
}
|
||||
|
@ -238,7 +257,14 @@ impl BinaryNode {
|
|||
pub(crate) fn lower_equal(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = match (&lhs, &rhs) {
|
||||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.lower_equal(#rhs) },
|
||||
_ => panic!("lower_equal is supported for tensor only"),
|
||||
(Type::Tensor(_), Type::Scalar(_)) => {
|
||||
move |lhs, rhs| quote! { #lhs.lower_equal_elem(#rhs) }
|
||||
}
|
||||
(Type::Scalar(_), Type::Tensor(_)) => {
|
||||
// L <= R == R >= L
|
||||
move |lhs, rhs| quote! { #rhs.greater_equal_elem(#lhs) }
|
||||
}
|
||||
(lhs, rhs) => panic!("lower_equal is not supported for {lhs:?} > {rhs:?}"),
|
||||
};
|
||||
Self::new(
|
||||
lhs,
|
||||
|
@ -418,21 +444,41 @@ mod tests {
|
|||
test_binary_operator_on_tensors!(greater);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_greater_scalar() {
|
||||
test_binary_operator_on_tensor_and_scalar!(greater, greater_elem);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_greater_or_equal() {
|
||||
test_binary_operator_on_tensors!(greater_equal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_greater_or_equal_scalar() {
|
||||
test_binary_operator_on_tensor_and_scalar!(greater_equal, greater_equal_elem);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_less() {
|
||||
test_binary_operator_on_tensors!(lower);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_less_scalar() {
|
||||
test_binary_operator_on_tensor_and_scalar!(lower, lower_elem);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_less_or_equal() {
|
||||
test_binary_operator_on_tensors!(lower_equal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_less_or_equal_scalar() {
|
||||
test_binary_operator_on_tensor_and_scalar!(lower_equal, lower_equal_elem);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_equal_tensors() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
|
Loading…
Reference in New Issue