mirror of https://github.com/tracel-ai/burn.git
feat: adding shape support for gather ONNX operation (#2128)
This commit is contained in:
parent
0802d063d8
commit
27ca6cee95
|
@ -810,11 +810,6 @@ fn gather_update_outputs(node: &mut Node) {
|
||||||
panic!("Gather requires two inputs: data and indices");
|
panic!("Gather requires two inputs: data and indices");
|
||||||
}
|
}
|
||||||
|
|
||||||
let input_tensor = match &node.inputs[0].ty {
|
|
||||||
ArgType::Tensor(tensor) => tensor,
|
|
||||||
ty => panic!("Only tensor input is valid but received: {:?}", ty),
|
|
||||||
};
|
|
||||||
|
|
||||||
let indices_tensor = match &node.inputs[1].ty {
|
let indices_tensor = match &node.inputs[1].ty {
|
||||||
ArgType::Tensor(tensor) => tensor,
|
ArgType::Tensor(tensor) => tensor,
|
||||||
_ => panic!("Only tensor indices is valid"),
|
_ => panic!("Only tensor indices is valid"),
|
||||||
|
@ -824,12 +819,28 @@ fn gather_update_outputs(node: &mut Node) {
|
||||||
panic!("Gather: indices tensor rank above 1 not supported")
|
panic!("Gather: indices tensor rank above 1 not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
match &node.inputs[0].ty {
|
||||||
|
ArgType::Tensor(input_tensor) => {
|
||||||
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
|
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
|
||||||
let output_rank = indices_tensor.dim + input_tensor.dim - 1;
|
let output_rank = indices_tensor.dim + input_tensor.dim - 1;
|
||||||
|
|
||||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||||
|
elem_type: input_tensor.elem_type.clone(),
|
||||||
dim: output_rank,
|
dim: output_rank,
|
||||||
shape: None,
|
shape: None,
|
||||||
elem_type: input_tensor.elem_type.clone(),
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
ArgType::Shape(_dim) => {
|
||||||
|
let shape_dim = 1;
|
||||||
|
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
|
||||||
|
let output_rank = indices_tensor.dim + shape_dim - 1;
|
||||||
|
|
||||||
|
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||||
|
elem_type: ElementType::Int64,
|
||||||
|
dim: output_rank,
|
||||||
|
shape: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ty => panic!("Only tensor/shape input is valid but received: {:?}", ty),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue