Mirror of https://github.com/tracel-ai/burn.git

Add reduce mean ONNX op support (#1637)

* Add reduce mean onnx op support
* Fix comment

Parent: 340a12463a
Commit: d5f20e2711
@@ -140,7 +140,7 @@ represent the corresponding Burn Op.
 | [ReduceLogSum][133] | ❌ | ❌ |
 | [ReduceLogSumExp][134] | ❌ | ❌ |
 | [ReduceMax][135] | ❌ | ✅ |
-| [ReduceMean][136] | ❌ | ✅ |
+| [ReduceMean][136] | ✅ | ✅ |
 | [ReduceMin][137] | ❌ | ✅ |
 | [ReduceProd][138] | ❌ | ✅ |
 | [ReduceSum][139] | ❌ | ✅ |
@@ -35,6 +35,7 @@ fn main() {
         .input("tests/recip/recip.onnx")
         .input("tests/relu/relu.onnx")
         .input("tests/leaky_relu/leaky_relu.onnx")
+        .input("tests/reduce_mean/reduce_mean.onnx")
         .input("tests/reshape/reshape.onnx")
         .input("tests/sigmoid/sigmoid.onnx")
         .input("tests/sin/sin.onnx")
@@ -41,6 +41,7 @@ include_models!(
     mul,
     neg,
     recip,
+    reduce_mean,
     relu,
     reshape,
     sigmoid,
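A note on how the two hunks above fit together: the onnx-tests build script registers each listed `.onnx` file with burn-import's model generator (presumably the `ModelGen` builder, which is not shown in this diff), producing a Rust module at build time; the matching `reduce_mean` entry in `include_models!` then pulls that generated module into the test crate, which is what makes `reduce_mean::Model` available to the test added below.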
@@ -443,6 +444,22 @@ mod tests {
         output3.to_data().assert_approx_eq(&expected3, 3);
     }
 
+    #[test]
+    fn reduce_mean() {
+        let device = Default::default();
+        let model: reduce_mean::Model<Backend> = reduce_mean::Model::new(&device);
+
+        // Run the model
+        let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
+        let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
+        let expected_scalar = Data::from([9.75]);
+        let expected = Data::from([[[[9.75]]]]);
+
+        assert_eq!(output_scalar.to_data(), expected_scalar);
+        assert_eq!(output_tensor.to_data(), input.to_data());
+        assert_eq!(output_value.to_data(), expected);
+    }
+
     #[test]
     fn reshape() {
         // Initialize the model without weights (because the exported file does not contain them)
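The expected values in this test follow directly from the input; a quick standalone check of the arithmetic (plain Rust, no Burn dependency):

fn main() {
    // The model input is a 1x1x1x4 tensor.
    let input = [1.0_f32, 4.0, 9.0, 25.0];

    // Mean over all elements: (1 + 4 + 9 + 25) / 4 = 39 / 4 = 9.75, which is
    // exactly representable in f32, so the assert_eq! in the test is safe.
    let mean: f32 = input.iter().sum::<f32>() / input.len() as f32;
    assert_eq!(mean, 9.75);

    // The second model output averages over dim 1, which has size 1, so each
    // "average" is over a single element and the data is unchanged; that is
    // why the test compares output_tensor against the input itself.
}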
Binary file not shown.
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+# used to generate model: onnx-tests/tests/reduce_mean/reduce_mean.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x):
+        return (
+            # ReduceMean, keepdims=0, axes=None
+            torch.mean(x),
+            # ReduceMean, keepdims=1, axes=[1]
+            torch.mean(x, dim=1, keepdim=True),
+            # ReduceMean, keepdims=1, axes=[-1]
+            torch.mean(x, dim=-1, keepdim=True),
+        )
+
+
+def main():
+    # Set random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    onnx_name = "reduce_mean.onnx"
+    test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)
+
+    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16)
+
+    print(f"Finished exporting model to {onnx_name}")
+
+    # Output some test data for use in the test
+    print(f"Test input data: {test_input}")
+    output = model.forward(*test_input)
+    print(f"Test output data: {output}")
+
+
+if __name__ == "__main__":
+    main()
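Running this script with PyTorch installed regenerates reduce_mean.onnx, the binary checked in above, and prints the input/output pairs that seed the expected values in the Rust test; on export, each torch.mean call lowers to an ONNX ReduceMean node with the axes/keepdims attributes noted in the inline comments, giving one node for each configuration the importer must handle.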
@@ -29,6 +29,7 @@ pub enum UnaryNodeKind {
     Log,
     LogSoftmax,
     Neg,
+    ReduceMean,
     Reciprocal,
     LeakyRelu,
     Relu,
@@ -52,6 +53,7 @@ impl UnaryNodeKind {
             Self::Log => "log",
             Self::LogSoftmax => "log_softmax",
             Self::Neg => "neg",
+            Self::ReduceMean => "reduce_mean",
             Self::Reciprocal => "reciprocal",
             Self::LeakyRelu => "leaky_relu",
             Self::Relu => "relu",
@@ -253,6 +255,32 @@ impl UnaryNode {
             _ => panic!("output must be a tensor"),
         }
     }
 
+    pub(crate) fn reduce_mean(input: Type, output: Type, dim: Option<usize>) -> Self {
+        // ReduceMean is constrained to numeric tensors, so no need to check for bool.
+        if let Type::Tensor(_) = output {
+            if let Some(dim) = dim {
+                // ReduceMean, keepdims=1, axes=[dim]
+                let dim = dim.to_tokens();
+                Self::new(
+                    input,
+                    output,
+                    UnaryNodeKind::ReduceMean,
+                    Rc::new(move |input| quote! { #input.mean_dim(#dim) }),
+                )
+            } else {
+                // ReduceMean, keepdims=0, axes=None
+                Self::new(
+                    input,
+                    output,
+                    UnaryNodeKind::ReduceMean,
+                    Rc::new(move |input| quote! { #input.mean() }),
+                )
+            }
+        } else {
+            panic!("ReduceMean only supports tensor output");
+        }
+    }
 }
 
 #[cfg(test)]
@@ -437,6 +465,43 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_unary_codegen_reduce_mean() {
+        one_node_graph(
+            UnaryNode::reduce_mean(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_float("tensor2", 4)),
+                Some(1),
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
+                    let tensor2 = tensor1.mean_dim(1);
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+
+        one_node_graph(
+            UnaryNode::reduce_mean(
+                Type::Tensor(TensorType::new_float("tensor1", 4)),
+                Type::Tensor(TensorType::new_float("tensor2", 1)),
+                None,
+            ),
+            quote! {
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
+                    let tensor2 = tensor1.mean();
+
+                    tensor2
+                }
+            },
+            vec!["tensor1".to_string()],
+            vec!["tensor2".to_string()],
+        );
+    }
+
     #[test]
     fn test_unary_codegen_reciprocal() {
         one_node_graph(
@@ -39,7 +39,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::Mul => same_as_input(node),
         NodeType::Neg => same_as_input(node),
         NodeType::Reciprocal => same_as_input(node),
-        NodeType::ReduceMean => mean_update_outputs(node),
+        NodeType::ReduceMean => reduce_mean_update_outputs(node),
         NodeType::Relu => same_as_input(node),
         NodeType::Reshape => reshape_update_outputs(node),
         NodeType::Shape => shape_update_outputs(node),
|
@ -206,12 +206,11 @@ fn reshape_update_outputs(node: &mut Node) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mean_update_outputs(node: &mut Node) {
|
fn reduce_mean_update_outputs(node: &mut Node) {
|
||||||
if node.inputs.len() != 1 {
|
if node.inputs.len() != 1 {
|
||||||
panic!("Mean: multiple inputs are not supported");
|
panic!("Mean: multiple inputs are not supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract the configuration of the linear layer (inputs are known)
|
|
||||||
let node_input = &mut node.inputs[0];
|
let node_input = &mut node.inputs[0];
|
||||||
let tensor = match node_input.clone().ty {
|
let tensor = match node_input.clone().ty {
|
||||||
ArgType::Tensor(tensor) => tensor,
|
ArgType::Tensor(tensor) => tensor,
|
||||||
|
@@ -230,6 +229,9 @@ fn mean_update_outputs(node: &mut Node) {
     if dim_only {
         node.outputs[0].ty = ArgType::Tensor(tensor);
    } else {
+        // NOTE: ReduceMean w/o keepdims reduces to a scalar value, but Burn doesn't have
+        // 0-dim tensor so we can't track or perform other ops on that value
+        // node.outputs[0].ty = ArgType::Scalar(tensor.elem_type);
        node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
     }
 }
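What that rank-1 workaround means for downstream code, as a hedged sketch (the NdArray backend choice and the standalone program are illustrative, not part of this change): a full ReduceMean that ONNX would type as a scalar surfaces in Burn as a one-element rank-1 tensor, and a plain value can still be extracted from it.

use burn::tensor::Tensor;
use burn_ndarray::NdArray;

type Backend = NdArray<f32>;

fn main() {
    let device = Default::default();
    let x = Tensor::<Backend, 2>::from_floats([[1.0, 4.0], [9.0, 25.0]], &device);

    // ONNX ReduceMean with axes=None, keepdims=0 would produce a scalar;
    // Burn has no 0-dim tensor, so the importer types the output as a
    // rank-1 tensor holding a single element.
    let reduced: Tensor<Backend, 1> = x.mean();

    // A single-element tensor can still be collapsed to a plain value.
    let value: f32 = reduced.into_scalar();
    assert_eq!(value, 9.75);
}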
@@ -660,3 +660,47 @@ fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d {
         panic!("Padding configuration ({:?}) not supported", pads);
     }
 }
+
+pub fn reduce_mean_config(node: &Node) -> Option<usize> {
+    let mut axes = Vec::new();
+    let mut keepdims = 1;
+
+    let tensor = match node.inputs.first().unwrap().clone().ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    // Extract the attributes
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "axes" => axes = value.clone().into_i64s(),
+            "keepdims" => keepdims = value.clone().into_i64(),
+            _ => {}
+        }
+    }
+
+    if axes.len() > 1 {
+        panic!("ReduceMean: reducing on multiple dimensions is not supported")
+    }
+
+    if axes.is_empty() && keepdims == 1 {
+        panic!("ReduceMean: axes must be provided with keepdims")
+    }
+
+    if !axes.is_empty() && keepdims == 0 {
+        // Not supported in Burn
+        panic!("ReduceMean: the reduce operation must preserve the reduced dimension")
+    }
+
+    if axes.is_empty() {
+        None
+    } else {
+        let mut dim = axes[0];
+
+        if dim < 0 {
+            // Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
+            dim += tensor.dim as i64;
+        }
+        Some(dim as usize)
+    }
+}
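The negative-axis branch deserves a worked example. A minimal standalone sketch of the same normalization (the helper name is illustrative, not from this change): ONNX accepts axes in [-r, r-1] for a rank-r input, so a negative axis is shifted by the rank before being handed to Burn as a non-negative dim.

// Illustrative helper mirroring the `dim < 0` branch of reduce_mean_config.
fn normalize_axis(axis: i64, rank: usize) -> usize {
    let dim = if axis < 0 { axis + rank as i64 } else { axis };
    assert!(
        (0..rank as i64).contains(&dim),
        "axis {axis} out of range for rank {rank}"
    );
    dim as usize
}

fn main() {
    // For the rank-4 test model, the exported axes=[-1] resolves to dim 3,
    // matching the `torch.mean(x, dim=-1, keepdim=True)` output above.
    assert_eq!(normalize_axis(-1, 4), 3);
    assert_eq!(normalize_axis(1, 4), 1);
}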
@@ -252,6 +252,7 @@ impl OnnxGraph {
             NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),
             NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
             NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
+            NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
             NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
             NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
             NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
|
@ -463,6 +464,15 @@ impl OnnxGraph {
|
||||||
|
|
||||||
ReshapeNode::new(input, output, shape)
|
ReshapeNode::new(input, output, shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reduce_mean_conversion(node: Node) -> UnaryNode {
|
||||||
|
let input = node.inputs.first().unwrap().to_type();
|
||||||
|
let output = node.outputs.first().unwrap().to_type();
|
||||||
|
let dim = reduce_mean_config(&node);
|
||||||
|
|
||||||
|
UnaryNode::reduce_mean(input, output, dim)
|
||||||
|
}
|
||||||
|
|
||||||
fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
|
fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
|
||||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||||
|
|