feat: added reduce min onnx import (#1894)

jachym 2024-06-18 15:04:24 +02:00 committed by GitHub
parent bbd35fd457
commit 96468fc3c9
9 changed files with 210 additions and 1 deletion


@@ -141,7 +141,7 @@ represent the corresponding Burn Op.
| [ReduceLogSumExp][134] | ❌ | ❌ |
| [ReduceMax][135]       | ✅ | ✅ |
| [ReduceMean][136]      | ✅ | ✅ |
| [ReduceMin][137]       | ✅ | ✅ |
| [ReduceProd][138]      | ❌ | ✅ |
| [ReduceSum][139]       | ✅ | ✅ |
| [ReduceSumSquare][140] | ❌ | ❌ |
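
For orientation, here is how an imported model containing ReduceMin is typically consumed downstream — a minimal build-script sketch assuming the standard burn-import flow; the ONNX path is illustrative and not part of this commit:

// build.rs — a minimal sketch, assuming the usual burn-import setup.
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        // Illustrative path; any graph with a ReduceMin node now imports.
        .input("src/model/reduce_min.onnx")
        .out_dir("model/")
        .run_from_script();
}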


@@ -52,6 +52,7 @@ fn main() {
        .input("tests/leaky_relu/leaky_relu.onnx")
        .input("tests/prelu/prelu.onnx")
        .input("tests/reduce_max/reduce_max.onnx")
        .input("tests/reduce_min/reduce_min.onnx")
        .input("tests/reduce_mean/reduce_mean.onnx")
        .input("tests/reduce_sum/reduce_sum_opset13.onnx")
        .input("tests/reduce_sum/reduce_sum_opset11.onnx")


@@ -62,6 +62,7 @@ include_models!(
    range,
    recip,
    reduce_max,
    reduce_min,
    reduce_mean,
    reduce_sum_opset13,
    reduce_sum_opset11,
@@ -728,6 +729,22 @@ mod tests {
        assert_eq!(output_value.to_data(), expected);
    }

    #[test]
    fn reduce_min() {
        let device = Default::default();
        let model: reduce_min::Model<Backend> = reduce_min::Model::new(&device);

        // Run the model
        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
        let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
        let expected_scalar = Data::from([1.]);
        let expected = Data::from([[[[1.]]]]);

        assert_eq!(output_scalar.to_data(), expected_scalar);
        assert_eq!(output_tensor.to_data(), input.to_data());
        assert_eq!(output_value.to_data(), expected);
    }

    #[test]
    fn reduce_mean() {
        let device = Default::default();


@@ -0,0 +1,47 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/reduce_min/reduce_min.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        return (
            # ReduceMin, keepdims=0, axes=None
            torch.min(x),
            # ReduceMin, keepdims=1, axes=[1]
            torch.min(x, dim=1, keepdim=True).values,
            # ReduceMin, keepdims=1, axes=[-1]
            torch.min(x, dim=-1, keepdim=True).values,
        )


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "reduce_min.onnx"

    test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)
    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)

    print(f"Finished exporting model to {onnx_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    # Pass the tensor itself (not unpacked with *) so the printed output
    # matches the exported graph's 4D input.
    output = model.forward(test_input)
    print(f"Test output data: {output}")


if __name__ == "__main__":
    main()
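
Running `python3 reduce_min.py` from its tests/reduce_min/ directory regenerates reduce_min.onnx, which the build-script addition above then picks up via its new `.input(...)` entry.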


@@ -33,6 +33,7 @@ pub enum UnaryNodeKind {
    Neg,
    Not,
    ReduceMax,
    ReduceMin,
    ReduceMean,
    ReduceSum,
    Reciprocal,
@@ -62,6 +63,7 @@ impl UnaryNodeKind {
            Self::Neg => "neg",
            Self::Not => "not",
            Self::ReduceMax => "reduce_max",
            Self::ReduceMin => "reduce_min",
            Self::ReduceMean => "reduce_mean",
            Self::ReduceSum => "reduce_sum",
            Self::Reciprocal => "reciprocal",
@@ -331,6 +333,35 @@ impl UnaryNode {
        }
    }

    pub(crate) fn reduce_min(input: Type, output: Type, dim: Option<usize>) -> Self {
        if let Type::Tensor(ref tensor) = output {
            if let Some(dim) = dim {
                if tensor.kind == TensorKind::Bool {
                    // Min is only implemented on numeric tensors
                    panic!("ReduceMin is not supported for boolean");
                }

                // ReduceMin, keepdims=1, axes=[dim]
                let dim = dim.to_tokens();
                Self::new(
                    input,
                    output,
                    UnaryNodeKind::ReduceMin,
                    Rc::new(move |input| quote! { #input.min_dim(#dim) }),
                )
            } else {
                // ReduceMin, keepdims=0, axes=None
                Self::new(
                    input,
                    output,
                    UnaryNodeKind::ReduceMin,
                    Rc::new(move |input| quote! { #input.min() }),
                )
            }
        } else {
            panic!("ReduceMin only supports tensor output");
        }
    }

    pub(crate) fn reduce_mean(input: Type, output: Type, dim: Option<usize>) -> Self {
        // ReduceMean is constrained to numeric tensors, so no need to check for bool.
        if let Type::Tensor(_) = output {
@@ -629,6 +660,43 @@ mod tests {
        );
    }

    #[test]
    fn test_unary_codegen_reduce_min() {
        one_node_graph(
            UnaryNode::reduce_min(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 4)),
                Some(1),
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = tensor1.min_dim(1);
                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );

        one_node_graph(
            UnaryNode::reduce_min(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 1)),
                None,
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
                    let tensor2 = tensor1.min();
                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );
    }

    #[test]
    fn test_unary_codegen_reduce_mean() {
        one_node_graph(


@@ -55,6 +55,7 @@ pub fn dim_inference(node: &mut Node) {
        NodeType::Range => range_update_outputs(node),
        NodeType::Reciprocal => same_as_input(node),
        NodeType::ReduceMax => reduce_max_update_outputs(node),
        NodeType::ReduceMin => reduce_min_update_outputs(node),
        NodeType::ReduceMean => reduce_mean_update_outputs(node),
        NodeType::ReduceSum => reduce_sum_update_outputs(node),
        NodeType::Relu => same_as_input(node),
@@ -716,6 +717,30 @@ fn reduce_max_update_outputs(node: &mut Node) {
    }
}

/// Infers the shape of a ReduceMin node and replaces the shape of the output tensor.
fn reduce_min_update_outputs(node: &mut Node) {
    if node.inputs.len() != 1 {
        panic!("ReduceMin: multiple inputs are not supported");
    }

    let node_input = &mut node.inputs[0];
    let tensor = match node_input.clone().ty {
        ArgType::Tensor(tensor) => tensor,
        _ => panic!("Only tensor input is valid"),
    };

    let dim_only = match node.attrs.get("axes") {
        Some(value) => match &value {
            AttributeValue::Int64(_) => true,
            AttributeValue::Int64s(ints) => ints.len() == 1,
            _ => false,
        },
        None => false,
    };

    if dim_only {
        // A single reduced axis with keepdims keeps the input rank.
        node.outputs[0].ty = ArgType::Tensor(tensor);
    } else {
        // A full reduction collapses the output to a rank-1 tensor.
        node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
    }
}
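
The two branches above mirror Burn's public tensor API, where min_dim preserves rank and min collapses everything to a rank-1 tensor — a small sketch of the shapes the inference must predict (the function and names are illustrative, not from this commit):

use burn::tensor::{backend::Backend, Tensor};

// Illustrative only: the output ranks reduce_min_update_outputs must match.
fn expected_ranks<B: Backend>(x: Tensor<B, 4>) {
    // keepdims=1, axes=[1]: same rank, dimension 1 becomes size 1.
    let kept: Tensor<B, 4> = x.clone().min_dim(1);
    // keepdims=0, axes=None: full reduction to a rank-1 tensor.
    let all: Tensor<B, 1> = x.min();
    let _ = (kept, all);
}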

/// Infers the shape of a ReduceSum node and replaces the shape of the output tensor.
fn reduce_sum_update_outputs(node: &mut Node) {
    let node_input = &mut node.inputs[0];


@@ -902,6 +902,48 @@ pub fn reduce_max_config(node: &Node) -> Option<usize> {
    }
}

pub fn reduce_min_config(node: &Node) -> Option<usize> {
    let mut axes = Vec::new();
    let mut keepdims = 1;

    let tensor = match node.inputs.first().unwrap().clone().ty {
        ArgType::Tensor(tensor) => tensor,
        _ => panic!("Only tensor input is valid"),
    };

    // Extract the attributes
    for (key, value) in node.attrs.iter() {
        match key.as_str() {
            "axes" => axes = value.clone().into_i64s(),
            "keepdims" => keepdims = value.clone().into_i64(),
            _ => {}
        }
    }

    if axes.len() > 1 {
        panic!("ReduceMin: reducing on multiple dimensions is not supported")
    }

    if axes.is_empty() && keepdims == 1 {
        panic!("ReduceMin: axes must be provided with keepdims")
    }

    if !axes.is_empty() && keepdims == 0 {
        // Not supported in Burn
        panic!("ReduceMin: the reduce operation must preserve the reduced dimension")
    }

    if axes.is_empty() {
        None
    } else {
        let mut dim = axes[0];

        if dim < 0 {
            // Accepted range is [-r, r-1] where r = rank(data)
            dim += tensor.dim as i64;
        }
        Some(dim as usize)
    }
}
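
A worked example of the negative-axis normalization above — a self-contained sketch, not code from this commit:

fn main() {
    // ONNX accepts axes in [-r, r-1]; for a rank-4 input, axes = [-1]
    // normalizes to dimension 3, the last axis.
    let rank: i64 = 4;
    let mut dim: i64 = -1;
    if dim < 0 {
        dim += rank;
    }
    assert_eq!(dim, 3); // reduce_min_config would return Some(3)
}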

pub fn reduce_mean_config(node: &Node) -> Option<usize> {
    let mut axes = Vec::new();
    let mut keepdims = 1;


@@ -289,6 +289,7 @@ impl OnnxGraph {
            NodeType::Min => graph.register(Self::min_conversion(node)),
            NodeType::Range => graph.register(Self::range_conversion(node)),
            NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
            NodeType::ReduceMin => graph.register(Self::reduce_min_conversion(node)),
            NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
            NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)),
            NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
@@ -640,6 +641,14 @@ impl OnnxGraph {
        UnaryNode::reduce_max(input, output, dim)
    }

    fn reduce_min_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.first().unwrap().to_type();
        let output = node.outputs.first().unwrap().to_type();
        let dim = reduce_min_config(&node);

        UnaryNode::reduce_min(input, output, dim)
    }

    fn reduce_mean_conversion(node: Node) -> UnaryNode {
        let input = node.inputs.first().unwrap().to_type();
        let output = node.outputs.first().unwrap().to_type();