mirror of https://github.com/tracel-ai/burn.git
Convert compatible prelu weights to rank 1 (#2054)
This commit is contained in:
parent
4c7353230e
commit
53c77ae646
|
@ -974,9 +974,19 @@ impl ParsedOnnxGraph {
|
|||
fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
|
||||
let input = TensorType::from(node.inputs.first().unwrap());
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
|
||||
let mut weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
|
||||
let config = PReluConfig::new();
|
||||
let name = &node.name;
|
||||
|
||||
if weight.shape.len() > 1 {
|
||||
if weight.shape[1..].iter().product::<usize>() == 1 {
|
||||
// Burn accepts rank 1 alpha weight
|
||||
weight.shape = weight.shape[..1].to_vec();
|
||||
} else {
|
||||
panic!("Invalid PRelu weight with shape {:?}", weight.shape);
|
||||
}
|
||||
}
|
||||
|
||||
PReluNode::new(name, input, output, weight, config)
|
||||
}
|
||||
fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode {
|
||||
|
|
Loading…
Reference in New Issue