mirror of https://github.com/tracel-ai/burn.git
pad-input-fix: adding support for pads as attributes (#2195)
* pad-input-fix: adding support for pads as attributes * fix: making asked changes * clippy fix
This commit is contained in:
parent
f5a1eca3ce
commit
2c12d58cd8
|
@ -768,9 +768,20 @@ pub fn tile_config(node: &Node) -> TileConfig {
|
|||
|
||||
/// Create a PadConfig from the attributes of the node
|
||||
pub fn pad_config(node: &Node) -> PadConfig {
|
||||
fn get_pads_input(node: &Node) -> Vec<i64> {
|
||||
// If the input is not provided, return an empty vector
|
||||
if node.inputs.get(1).is_none() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
match &node.inputs[1].value {
|
||||
Some(Data::Int64s(shape)) => shape.clone(),
|
||||
_ => panic!("Tensor data type must be int64"),
|
||||
}
|
||||
}
|
||||
fn get_pads(node: &Node) -> Vec<usize> {
|
||||
if node.inputs.len() < 2 {
|
||||
panic!("Pad: must provide at least two inputs")
|
||||
if node.inputs.is_empty() {
|
||||
panic!("Pad: must provide data as input")
|
||||
}
|
||||
if node.inputs.len() >= 4 {
|
||||
panic!("Pad: axes input is not supported")
|
||||
|
@ -781,19 +792,41 @@ pub fn pad_config(node: &Node) -> PadConfig {
|
|||
_ => panic!("Pad: Only tensor input is valid"),
|
||||
};
|
||||
|
||||
let pads: Vec<usize> = match &node.inputs[1].value {
|
||||
Some(Data::Int64s(shape)) => shape
|
||||
.iter()
|
||||
.map(|&x| {
|
||||
if x < 0 {
|
||||
// TODO: support negative pads
|
||||
panic!("Pad: Negative pad is not supported");
|
||||
//TODO : handle more possible attributes
|
||||
let mut pads: Vec<usize> = get_pads_input(node)
|
||||
.into_iter()
|
||||
.map(|x| x as usize)
|
||||
.collect();
|
||||
|
||||
for (key, value) in node.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"pads" => {
|
||||
pads = value
|
||||
.clone()
|
||||
.into_i64s()
|
||||
.iter()
|
||||
.map(|&x| {
|
||||
if x < 0 {
|
||||
panic!("Pad: Negative pad is not supported");
|
||||
}
|
||||
x as usize
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
"mode" => {
|
||||
let mode = value.clone().into_string();
|
||||
if mode != "constant" {
|
||||
panic!("only constant mode is supported, given mode is {}", mode);
|
||||
}
|
||||
x as usize
|
||||
})
|
||||
.collect(),
|
||||
_ => panic!("Pad: pads data type must be int64"),
|
||||
};
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if pads.is_empty() {
|
||||
panic!("Pad: pads should be given as attribute or as input");
|
||||
}
|
||||
|
||||
if pads.len() != input_dim * 2 {
|
||||
panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]");
|
||||
|
@ -823,7 +856,7 @@ pub fn pad_config(node: &Node) -> PadConfig {
|
|||
}
|
||||
fn get_constant_value(node: &Node) -> f32 {
|
||||
// TODO: support int, boolean
|
||||
node.inputs
|
||||
let mut constant_value = node.inputs
|
||||
.get(2)
|
||||
.and_then(|input| match &input.value {
|
||||
Some(Data::Float16s(constant_value)) => {
|
||||
|
@ -840,7 +873,15 @@ pub fn pad_config(node: &Node) -> PadConfig {
|
|||
Some(Data::Float64(constant_value)) => Some(*constant_value as f32),
|
||||
_ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"),
|
||||
})
|
||||
.unwrap_or(0.0)
|
||||
.unwrap_or(0.0);
|
||||
|
||||
if node.attrs.contains_key("value") {
|
||||
constant_value = node.attrs.get("value").map(|value| match value {
|
||||
AttributeValue::Float32(value) => *value,
|
||||
_ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"),
|
||||
}).expect("constant_value should have had a value now");
|
||||
}
|
||||
constant_value
|
||||
}
|
||||
|
||||
let pads = get_pads(node);
|
||||
|
|
Loading…
Reference in New Issue