From 0d6e4199e32a3a5942f920bf13c0a0ddf10d2579 Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Mon, 28 Jun 2021 18:40:49 -0700 Subject: [PATCH] [mlir][vector] Order parallel indices before transposing the input in multireductions The current code does not preserve the order of the parallel dimensions when doing multi-reductions and thus we can end up in scenarios where the result shape does not match the desired shape after reduction. This patch fixes that by ensuring that the parallel indices are in order and then concatenates them to the reduction dimensions so that the reduction dimensions are innermost. Differential Revision: https://reviews.llvm.org/D104884 --- mlir/lib/Dialect/Vector/VectorTransforms.cpp | 50 ++++++++----------- .../vector-multi-reduction-lowering.mlir | 45 ++++++++++++++++- 2 files changed, 65 insertions(+), 30 deletions(-) diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index e04d48d6ca84..3a10fb3de641 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -37,6 +37,7 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/VectorInterfaces.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -3915,42 +3916,33 @@ struct InnerDimReductionConversion auto loc = multiReductionOp.getLoc(); auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - auto reductionDims = llvm::to_vector<4>( - llvm::map_range(multiReductionOp.reduction_dims().cast(), - [](Attribute attr) -> int64_t { - return attr.cast().getInt(); - })); - llvm::sort(reductionDims); - - int64_t reductionSize = multiReductionOp.reduction_dims().size(); - - // Fails if already inner most reduction. - bool innerMostReduction = true; - for (int i = 0; i < reductionSize; ++i) { - if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) { - innerMostReduction = false; - } + // Separate reduction and parallel dims + auto reductionDimsRange = + multiReductionOp.reduction_dims().getAsValueRange(); + auto reductionDims = llvm::to_vector<4>(llvm::map_range( + reductionDimsRange, [](APInt a) { return a.getZExtValue(); })); + llvm::SmallDenseSet reductionDimsSet(reductionDims.begin(), + reductionDims.end()); + int64_t reductionSize = reductionDims.size(); + SmallVector parallelDims; + for (int64_t i = 0; i < srcRank; i++) { + if (!reductionDimsSet.contains(i)) + parallelDims.push_back(i); } - if (innerMostReduction) + + // Add transpose only if inner-most dimensions are not reductions + if (parallelDims == + llvm::to_vector<4>(llvm::seq(0, parallelDims.size()))) return failure(); - // Permutes the indices so reduction dims are inner most dims. - SmallVector indices; - for (int i = 0; i < srcRank; ++i) { - indices.push_back(i); - } - int ir = reductionSize - 1; - int id = srcRank - 1; - while (ir >= 0) { - std::swap(indices[reductionDims[ir--]], indices[id--]); - } - - // Sets inner most dims as reduction. + SmallVector indices; + indices.append(parallelDims.begin(), parallelDims.end()); + indices.append(reductionDims.begin(), reductionDims.end()); + auto transposeOp = rewriter.create(loc, src, indices); SmallVector reductionMask(srcRank, false); for (int i = 0; i < reductionSize; ++i) { reductionMask[srcRank - i - 1] = true; } - auto transposeOp = rewriter.create(loc, src, indices); rewriter.replaceOpWithNewOp( multiReductionOp, transposeOp.result(), reductionMask, multiReductionOp.kind()); diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir index 6cfc4e035719..4121262722e3 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -61,6 +61,49 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x // CHECK-LABEL: func @vector_multi_reduction_transposed // CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32> // CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32> -// CHEKC: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32> +// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32> // CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32> // CHECK: return %[[RESULT]] + +func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32> + return %0 : vector<2x4xf32> +} +// CHECK-LABEL: func @vector_multi_reduction_ordering +// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32> +// CHECK: %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<8xf32> +// CHECK: %[[C0:.+]] = constant 0 : i32 +// CHECK: %[[C1:.+]] = constant 1 : i32 +// CHECK: %[[C2:.+]] = constant 2 : i32 +// CHECK: %[[C3:.+]] = constant 3 : i32 +// CHECK: %[[C4:.+]] = constant 4 : i32 +// CHECK: %[[C5:.+]] = constant 5 : i32 +// CHECK: %[[C6:.+]] = constant 6 : i32 +// CHECK: %[[C7:.+]] = constant 7 : i32 +// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32> +// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0] +// CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<8xf32> +// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1] +// CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<8xf32> +// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2] +// CHECK: %[[RV2:.+]] = vector.reduction "mul", %[[V2]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : i32] : vector<8xf32> +// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3] +// CHECK: %[[RV3:.+]] = vector.reduction "mul", %[[V3]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : i32] : vector<8xf32> +// CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0] +// CHECK: %[[RV4:.+]] = vector.reduction "mul", %[[V4]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : i32] : vector<8xf32> +// CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1] +// CHECK: %[[RV5:.+]] = vector.reduction "mul", %[[V5]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : i32] : vector<8xf32> +// CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2] +// CHECK: %[[RV6:.+]] = vector.reduction "mul", %[[V6]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : i32] : vector<8xf32> +// CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3] +// CHECK: %[[RV7:.+]] = vector.reduction "mul", %[[V7]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : i32] : vector<8xf32> +// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32> +// CHECK: return %[[RESHAPED_VEC]]