pad-input-fix: adding support for pads as attributes (#2195)

* pad-input-fix: adding support for pads as attributes

* fix: making asked changes

* clippy fix
This commit is contained in:
mepatrick73 2024-08-23 12:46:14 -04:00 committed by GitHub
parent f5a1eca3ce
commit 2c12d58cd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 57 additions and 16 deletions

View File

@ -768,9 +768,20 @@ pub fn tile_config(node: &Node) -> TileConfig {
/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
fn get_pads_input(node: &Node) -> Vec<i64> {
// If the input is not provided, return an empty vector
if node.inputs.get(1).is_none() {
return Vec::new();
}
match &node.inputs[1].value {
Some(Data::Int64s(shape)) => shape.clone(),
_ => panic!("Tensor data type must be int64"),
}
}
fn get_pads(node: &Node) -> Vec<usize> {
if node.inputs.len() < 2 {
panic!("Pad: must provide at least two inputs")
if node.inputs.is_empty() {
panic!("Pad: must provide data as input")
}
if node.inputs.len() >= 4 {
panic!("Pad: axes input is not supported")
@ -781,19 +792,41 @@ pub fn pad_config(node: &Node) -> PadConfig {
_ => panic!("Pad: Only tensor input is valid"),
};
let pads: Vec<usize> = match &node.inputs[1].value {
Some(Data::Int64s(shape)) => shape
.iter()
.map(|&x| {
if x < 0 {
// TODO: support negative pads
panic!("Pad: Negative pad is not supported");
//TODO : handle more possible attributes
let mut pads: Vec<usize> = get_pads_input(node)
.into_iter()
.map(|x| x as usize)
.collect();
for (key, value) in node.attrs.iter() {
match key.as_str() {
"pads" => {
pads = value
.clone()
.into_i64s()
.iter()
.map(|&x| {
if x < 0 {
panic!("Pad: Negative pad is not supported");
}
x as usize
})
.collect()
}
"mode" => {
let mode = value.clone().into_string();
if mode != "constant" {
panic!("only constant mode is supported, given mode is {}", mode);
}
x as usize
})
.collect(),
_ => panic!("Pad: pads data type must be int64"),
};
}
_ => {}
}
}
if pads.is_empty() {
panic!("Pad: pads should be given as attribute or as input");
}
if pads.len() != input_dim * 2 {
panic!("Pad: pads should be a 1D tensor of shape [2 * num_axes]");
@ -823,7 +856,7 @@ pub fn pad_config(node: &Node) -> PadConfig {
}
fn get_constant_value(node: &Node) -> f32 {
// TODO: support int, boolean
node.inputs
let mut constant_value = node.inputs
.get(2)
.and_then(|input| match &input.value {
Some(Data::Float16s(constant_value)) => {
@ -840,7 +873,15 @@ pub fn pad_config(node: &Node) -> PadConfig {
Some(Data::Float64(constant_value)) => Some(*constant_value as f32),
_ => panic!("Pad: only float values are currently supported for constant value, submit an issue on github"),
})
.unwrap_or(0.0)
.unwrap_or(0.0);
if node.attrs.contains_key("value") {
constant_value = node.attrs.get("value").map(|value| match value {
AttributeValue::Float32(value) => *value,
_ => panic!("Pad: only float32 values are currently supported for constant value as attribute, submit an issue on github"),
}).expect("constant_value should have had a value now");
}
constant_value
}
let pads = get_pads(node);