diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 1fab060a6b62..5e94633a7b60 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -830,19 +830,22 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, Value input = op->getOperand(0); llvm::SmallVector reduceShape; + SmallVector dynDims; for (unsigned i = 0; i < inputTy.getRank(); i++) { - if (axis != i) + if (axis != i) { reduceShape.push_back(inputTy.getDimSize(i)); + if (inputTy.isDynamicDim(i)) + dynDims.push_back(rewriter.create(loc, input, i)); + } } Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType()); // First fill the output buffer with the init value. - auto initTensor = - rewriter - .create(loc, ArrayRef({}), reduceShape, - resultTy.getElementType()) - .result(); + auto initTensor = rewriter + .create(loc, dynDims, reduceShape, + resultTy.getElementType()) + .result(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 3706c4131dcb..27487a4b8e8b 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -736,6 +736,72 @@ func @reduce_float(%arg0: tensor<5x4xf32>) -> () { // ----- +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> + +// CHECK-LABEL: @reduce_float_dyn +func @reduce_float_dyn(%arg0: tensor) -> () { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 4] + // CHECK: %[[CST0:.+]] = arith.constant 0.0 + // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST0]], %[[INIT]]) + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor) outs(%[[FILL]] : tensor) + // CHECK: ^bb0(%arg1: f32, %arg2: f32) + // CHECK: %[[RES:.+]] = arith.addf %arg1, %arg2 : f32 + // CHECK: linalg.yield %[[RES]] : f32 + // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor into tensor + // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor into tensor + %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor) -> tensor + return +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: @reduce_float_dyn_nonzero_batch +func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () { + // CHECK: %[[C1:.+]] = arith.constant 1 + // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C1]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, %[[DYN]]] + // CHECK: %[[CST1:.+]] = arith.constant 1.0 + // CHECK: %[[FILL:.+]] = linalg.fill(%[[CST1]], %[[INIT]]) + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<5x?x4xf32>) outs(%[[FILL]] : tensor<5x?xf32>) + // CHECK: ^bb0(%arg1: f32, %arg2: f32) + // CHECK: %[[RES:.+]] = arith.mulf %arg1, %arg2 : f32 + // CHECK: linalg.yield %[[RES]] : f32 + // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<5x?xf32> into tensor + // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor into tensor<5x?x1xf32> + %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32> + return +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)> + +// CHECK-LABEL: @reduce_float_dyn_multiple +func @reduce_float_dyn_multiple(%arg0: tensor) -> () { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]] + // CHECK: %[[CMIN:.+]] = arith.constant -3.40282347E+38 + // CHECK: %[[FILL:.+]] = linalg.fill(%[[CMIN]], %[[INIT]]) + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor) outs(%[[FILL]] : tensor) + // CHECK: ^bb0(%arg1: f32, %arg2: f32) + // CHECK: %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32 + // CHECK: %[[RES:.+]] = select %[[CMP]], %arg1, %arg2 : f32 + // CHECK: linalg.yield %[[RES]] : f32 + // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor into tensor + %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor) -> tensor + return +} + +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>