feat: adding shape support for gather ONNX operation (#2128)

This commit is contained in:
mepatrick73 2024-08-08 13:18:03 -04:00 committed by GitHub
parent 0802d063d8
commit 27ca6cee95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 23 additions and 12 deletions

View File

@ -810,11 +810,6 @@ fn gather_update_outputs(node: &mut Node) {
panic!("Gather requires two inputs: data and indices"); panic!("Gather requires two inputs: data and indices");
} }
let input_tensor = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor,
ty => panic!("Only tensor input is valid but received: {:?}", ty),
};
let indices_tensor = match &node.inputs[1].ty { let indices_tensor = match &node.inputs[1].ty {
ArgType::Tensor(tensor) => tensor, ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor indices is valid"), _ => panic!("Only tensor indices is valid"),
@ -824,12 +819,28 @@ fn gather_update_outputs(node: &mut Node) {
panic!("Gather: indices tensor rank above 1 not supported") panic!("Gather: indices tensor rank above 1 not supported")
} }
match &node.inputs[0].ty {
ArgType::Tensor(input_tensor) => {
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
let output_rank = indices_tensor.dim + input_tensor.dim - 1; let output_rank = indices_tensor.dim + input_tensor.dim - 1;
node.outputs[0].ty = ArgType::Tensor(TensorType { node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: input_tensor.elem_type.clone(),
dim: output_rank, dim: output_rank,
shape: None, shape: None,
elem_type: input_tensor.elem_type.clone(),
}); });
} }
ArgType::Shape(_dim) => {
let shape_dim = 1;
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
let output_rank = indices_tensor.dim + shape_dim - 1;
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: ElementType::Int64,
dim: output_rank,
shape: None,
})
}
ty => panic!("Only tensor/shape input is valid but received: {:?}", ty),
}
}