diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 8acff2934..557e29112 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -810,11 +810,6 @@ fn gather_update_outputs(node: &mut Node) { panic!("Gather requires two inputs: data and indices"); } - let input_tensor = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor, - ty => panic!("Only tensor input is valid but received: {:?}", ty), - }; - let indices_tensor = match &node.inputs[1].ty { ArgType::Tensor(tensor) => tensor, _ => panic!("Only tensor indices is valid"), @@ -824,12 +819,28 @@ fn gather_update_outputs(node: &mut Node) { panic!("Gather: indices tensor rank above 1 not supported") } - // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input - let output_rank = indices_tensor.dim + input_tensor.dim - 1; + match &node.inputs[0].ty { + ArgType::Tensor(input_tensor) => { + // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input + let output_rank = indices_tensor.dim + input_tensor.dim - 1; - node.outputs[0].ty = ArgType::Tensor(TensorType { - dim: output_rank, - shape: None, - elem_type: input_tensor.elem_type.clone(), - }); + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: input_tensor.elem_type.clone(), + dim: output_rank, + shape: None, + }); + } + ArgType::Shape(_dim) => { + let shape_dim = 1; + // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input + let output_rank = indices_tensor.dim + shape_dim - 1; + + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Int64, + dim: output_rank, + shape: None, + }) + } + ty => panic!("Only tensor/shape input is valid but received: {:?}", ty), + } }