From 53c77ae6463cac080ee26b4f2415ecb058b67544 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 23 Jul 2024 10:58:20 -0400 Subject: [PATCH] Convert compatible prelu weights to rank 1 (#2054) --- crates/burn-import/src/onnx/to_burn.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 3b32ba9c4..77ec889ea 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -974,9 +974,19 @@ impl ParsedOnnxGraph { fn prelu_conversion(node: Node) -> PReluNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - let weight = extract_data_serialize::(1, &node).unwrap(); + let mut weight = extract_data_serialize::(1, &node).unwrap(); let config = PReluConfig::new(); let name = &node.name; + + if weight.shape.len() > 1 { + if weight.shape[1..].iter().product::() == 1 { + // Burn accepts rank 1 alpha weight + weight.shape = weight.shape[..1].to_vec(); + } else { + panic!("Invalid PRelu weight with shape {:?}", weight.shape); + } + } + PReluNode::new(name, input, output, weight, config) } fn conv_transpose2d_conversion(node: Node) -> ConvTranspose2dNode {