mirror of https://github.com/tracel-ai/burn.git
Fix indices dim check in gather_update_outputs (#2149)
This commit is contained in:
parent
12caca7909
commit
0eec293e28
|
@ -810,19 +810,20 @@ fn gather_update_outputs(node: &mut Node) {
|
|||
panic!("Gather requires two inputs: data and indices");
|
||||
}
|
||||
|
||||
let indices_tensor = match &node.inputs[1].ty {
|
||||
ArgType::Tensor(tensor) => tensor,
|
||||
_ => panic!("Only tensor indices is valid"),
|
||||
let indices_dim = match &node.inputs[1].ty {
|
||||
ArgType::Tensor(tensor) => tensor.dim,
|
||||
ArgType::Scalar(_) => 0,
|
||||
_ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty),
|
||||
};
|
||||
|
||||
if indices_tensor.dim > 1 {
|
||||
if indices_dim > 1 {
|
||||
panic!("Gather: indices tensor rank above 1 not supported")
|
||||
}
|
||||
|
||||
match &node.inputs[0].ty {
|
||||
ArgType::Tensor(input_tensor) => {
|
||||
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
|
||||
let output_rank = indices_tensor.dim + input_tensor.dim - 1;
|
||||
let output_rank = indices_dim + input_tensor.dim - 1;
|
||||
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
elem_type: input_tensor.elem_type.clone(),
|
||||
|
@ -833,7 +834,7 @@ fn gather_update_outputs(node: &mut Node) {
|
|||
ArgType::Shape(_dim) => {
|
||||
let shape_dim = 1;
|
||||
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
|
||||
let output_rank = indices_tensor.dim + shape_dim - 1;
|
||||
let output_rank = indices_dim + shape_dim - 1;
|
||||
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
elem_type: ElementType::Int64,
|
||||
|
|
Loading…
Reference in New Issue