Convert compatible prelu weights to rank 1 (#2054)

This commit is contained in:
Guillaume Lagrange 2024-07-23 10:58:20 -04:00 committed by GitHub
parent 4c7353230e
commit 53c77ae646
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 1 deletions

View File

@ -974,9 +974,19 @@ impl ParsedOnnxGraph {
fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
let mut weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
let config = PReluConfig::new();
let name = &node.name;
if weight.shape.len() > 1 {
if weight.shape[1..].iter().product::<usize>() == 1 {
// Burn accepts rank 1 alpha weight
weight.shape = weight.shape[..1].to_vec();
} else {
panic!("Invalid PRelu weight with shape {:?}", weight.shape);
}
}
PReluNode::new(name, input, output, weight, config)
}
fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode {