[mlir][vector] Order parallel indices before transposing the input in multireductions

The current code does not preserve the order of the parallel
dimensions when doing multi-reductions and thus we can end
up in scenarios where the result shape does not match the
desired shape after reduction.

This patch fixes that by ensuring that the parallel indices
are in order and then concatenates them to the reduction dimensions
so that the reduction dimensions are innermost.

Differential Revision: https://reviews.llvm.org/D104884
This commit is contained in:
harsh-nod 2021-06-28 18:40:49 -07:00 committed by thomasraoux
parent ab546ead3b
commit 0d6e4199e3
2 changed files with 65 additions and 30 deletions

View File

@ -37,6 +37,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@ -3915,42 +3916,33 @@ struct InnerDimReductionConversion
auto loc = multiReductionOp.getLoc();
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
auto reductionDims = llvm::to_vector<4>(
llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
[](Attribute attr) -> int64_t {
return attr.cast<IntegerAttr>().getInt();
}));
llvm::sort(reductionDims);
int64_t reductionSize = multiReductionOp.reduction_dims().size();
// Fails if already inner most reduction.
bool innerMostReduction = true;
for (int i = 0; i < reductionSize; ++i) {
if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
innerMostReduction = false;
}
// Separate reduction and parallel dims
auto reductionDimsRange =
multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
auto reductionDims = llvm::to_vector<4>(llvm::map_range(
reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
reductionDims.end());
int64_t reductionSize = reductionDims.size();
SmallVector<int64_t, 4> parallelDims;
for (int64_t i = 0; i < srcRank; i++) {
if (!reductionDimsSet.contains(i))
parallelDims.push_back(i);
}
if (innerMostReduction)
// Add transpose only if inner-most dimensions are not reductions
if (parallelDims ==
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))
return failure();
// Permutes the indices so reduction dims are inner most dims.
SmallVector<int64_t> indices;
for (int i = 0; i < srcRank; ++i) {
indices.push_back(i);
}
int ir = reductionSize - 1;
int id = srcRank - 1;
while (ir >= 0) {
std::swap(indices[reductionDims[ir--]], indices[id--]);
}
// Sets inner most dims as reduction.
SmallVector<int64_t, 4> indices;
indices.append(parallelDims.begin(), parallelDims.end());
indices.append(reductionDims.begin(), reductionDims.end());
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
SmallVector<bool> reductionMask(srcRank, false);
for (int i = 0; i < reductionSize; ++i) {
reductionMask[srcRank - i - 1] = true;
}
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
multiReductionOp, transposeOp.result(), reductionMask,
multiReductionOp.kind());

View File

@ -61,6 +61,49 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x
// CHECK-LABEL: func @vector_multi_reduction_transposed
// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32>
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
// CHEKC: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
// CHECK: return %[[RESULT]]
func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
return %0 : vector<2x4xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_ordering
// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32>
// CHECK: %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<8xf32>
// CHECK: %[[C0:.+]] = constant 0 : i32
// CHECK: %[[C1:.+]] = constant 1 : i32
// CHECK: %[[C2:.+]] = constant 2 : i32
// CHECK: %[[C3:.+]] = constant 3 : i32
// CHECK: %[[C4:.+]] = constant 4 : i32
// CHECK: %[[C5:.+]] = constant 5 : i32
// CHECK: %[[C6:.+]] = constant 6 : i32
// CHECK: %[[C7:.+]] = constant 7 : i32
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
// CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<8xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
// CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<8xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
// CHECK: %[[RV2:.+]] = vector.reduction "mul", %[[V2]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : i32] : vector<8xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
// CHECK: %[[RV3:.+]] = vector.reduction "mul", %[[V3]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : i32] : vector<8xf32>
// CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
// CHECK: %[[RV4:.+]] = vector.reduction "mul", %[[V4]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : i32] : vector<8xf32>
// CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
// CHECK: %[[RV5:.+]] = vector.reduction "mul", %[[V5]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : i32] : vector<8xf32>
// CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
// CHECK: %[[RV6:.+]] = vector.reduction "mul", %[[V6]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : i32] : vector<8xf32>
// CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
// CHECK: %[[RV7:.+]] = vector.reduction "mul", %[[V7]] : vector<3xf32> into f32
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : i32] : vector<8xf32>
// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
// CHECK: return %[[RESHAPED_VEC]]