Support for CumSum in ONNX models. (#1340)

This commit is contained in:
Laurent Mazare 2023-11-17 22:03:40 +00:00 committed by GitHub
parent 9ab3f9729f
commit d31f11035f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 0 deletions

View File

@ -741,6 +741,25 @@ pub fn simple_eval(
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
"CumSum" => {
let exclusive = get_attr_opt::<i64>(node, "exclusive")?
.copied()
.unwrap_or(0);
let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
if exclusive != 0 {
bail!("only exclusive == 0 is supported in CumSum")
}
if reverse != 0 {
bail!("only reverse == 0 is supported in CumSum")
}
let input = get(&node.input[0])?;
let axis = get(&node.input[1])?
.to_dtype(DType::U32)?
.to_vec0::<u32>()?;
let output = input.cumsum(axis as usize)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}