Support for CumSum in ONNX models. (#1340)
This commit is contained in:
parent
9ab3f9729f
commit
d31f11035f
|
@ -741,6 +741,25 @@ pub fn simple_eval(
|
|||
let output = input.to_dtype(dtype)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
|
||||
"CumSum" => {
|
||||
let exclusive = get_attr_opt::<i64>(node, "exclusive")?
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
|
||||
if exclusive != 0 {
|
||||
bail!("only exclusive == 0 is supported in CumSum")
|
||||
}
|
||||
if reverse != 0 {
|
||||
bail!("only reverse == 0 is supported in CumSum")
|
||||
}
|
||||
let input = get(&node.input[0])?;
|
||||
let axis = get(&node.input[1])?
|
||||
.to_dtype(DType::U32)?
|
||||
.to_vec0::<u32>()?;
|
||||
let output = input.cumsum(axis as usize)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue