forked from OSchip/llvm-project
[mlir][vector] Order parallel indices before transposing the input in multireductions
The current code does not preserve the order of the parallel dimensions when doing multi-reductions and thus we can end up in scenarios where the result shape does not match the desired shape after reduction. This patch fixes that by ensuring that the parallel indices are in order and then concatenates them to the reduction dimensions so that the reduction dimensions are innermost. Differential Revision: https://reviews.llvm.org/D104884
This commit is contained in:
parent
ab546ead3b
commit
0d6e4199e3
|
@ -37,6 +37,7 @@
|
|||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/VectorInterfaces.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
@ -3915,42 +3916,33 @@ struct InnerDimReductionConversion
|
|||
auto loc = multiReductionOp.getLoc();
|
||||
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
||||
|
||||
auto reductionDims = llvm::to_vector<4>(
|
||||
llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
|
||||
[](Attribute attr) -> int64_t {
|
||||
return attr.cast<IntegerAttr>().getInt();
|
||||
}));
|
||||
llvm::sort(reductionDims);
|
||||
|
||||
int64_t reductionSize = multiReductionOp.reduction_dims().size();
|
||||
|
||||
// Fails if already inner most reduction.
|
||||
bool innerMostReduction = true;
|
||||
for (int i = 0; i < reductionSize; ++i) {
|
||||
if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
|
||||
innerMostReduction = false;
|
||||
// Separate reduction and parallel dims
|
||||
auto reductionDimsRange =
|
||||
multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
|
||||
auto reductionDims = llvm::to_vector<4>(llvm::map_range(
|
||||
reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
|
||||
llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
|
||||
reductionDims.end());
|
||||
int64_t reductionSize = reductionDims.size();
|
||||
SmallVector<int64_t, 4> parallelDims;
|
||||
for (int64_t i = 0; i < srcRank; i++) {
|
||||
if (!reductionDimsSet.contains(i))
|
||||
parallelDims.push_back(i);
|
||||
}
|
||||
}
|
||||
if (innerMostReduction)
|
||||
|
||||
// Add transpose only if inner-most dimensions are not reductions
|
||||
if (parallelDims ==
|
||||
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))
|
||||
return failure();
|
||||
|
||||
// Permutes the indices so reduction dims are inner most dims.
|
||||
SmallVector<int64_t> indices;
|
||||
for (int i = 0; i < srcRank; ++i) {
|
||||
indices.push_back(i);
|
||||
}
|
||||
int ir = reductionSize - 1;
|
||||
int id = srcRank - 1;
|
||||
while (ir >= 0) {
|
||||
std::swap(indices[reductionDims[ir--]], indices[id--]);
|
||||
}
|
||||
|
||||
// Sets inner most dims as reduction.
|
||||
SmallVector<int64_t, 4> indices;
|
||||
indices.append(parallelDims.begin(), parallelDims.end());
|
||||
indices.append(reductionDims.begin(), reductionDims.end());
|
||||
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
|
||||
SmallVector<bool> reductionMask(srcRank, false);
|
||||
for (int i = 0; i < reductionSize; ++i) {
|
||||
reductionMask[srcRank - i - 1] = true;
|
||||
}
|
||||
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
|
||||
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
|
||||
multiReductionOp, transposeOp.result(), reductionMask,
|
||||
multiReductionOp.kind());
|
||||
|
|
|
@ -61,6 +61,49 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x
|
|||
// CHECK-LABEL: func @vector_multi_reduction_transposed
|
||||
// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32>
|
||||
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
|
||||
// CHEKC: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
|
||||
// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
|
||||
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
|
||||
return %0 : vector<2x4xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @vector_multi_reduction_ordering
|
||||
// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32>
|
||||
// CHECK: %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<8xf32>
|
||||
// CHECK: %[[C0:.+]] = constant 0 : i32
|
||||
// CHECK: %[[C1:.+]] = constant 1 : i32
|
||||
// CHECK: %[[C2:.+]] = constant 2 : i32
|
||||
// CHECK: %[[C3:.+]] = constant 3 : i32
|
||||
// CHECK: %[[C4:.+]] = constant 4 : i32
|
||||
// CHECK: %[[C5:.+]] = constant 5 : i32
|
||||
// CHECK: %[[C6:.+]] = constant 6 : i32
|
||||
// CHECK: %[[C7:.+]] = constant 7 : i32
|
||||
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
|
||||
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
|
||||
// CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<3xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<8xf32>
|
||||
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
|
||||
// CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<3xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<8xf32>
|
||||
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
|
||||
// CHECK: %[[RV2:.+]] = vector.reduction "mul", %[[V2]] : vector<3xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : i32] : vector<8xf32>
|
||||
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
|
||||
// CHECK: %[[RV3:.+]] = vector.reduction "mul", %[[V3]] : vector<3xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : i32] : vector<8xf32>
|
||||
// CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
|
||||
// CHECK: %[[RV4:.+]] = vector.reduction "mul", %[[V4]] : vector<3xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : i32] : vector<8xf32>
|
||||
// CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
|
||||
// CHECK: %[[RV5:.+]] = vector.reduction "mul", %[[V5]] : vector<3xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : i32] : vector<8xf32>
|
||||
// CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
|
||||
// CHECK: %[[RV6:.+]] = vector.reduction "mul", %[[V6]] : vector<3xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : i32] : vector<8xf32>
|
||||
// CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
|
||||
// CHECK: %[[RV7:.+]] = vector.reduction "mul", %[[V7]] : vector<3xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : i32] : vector<8xf32>
|
||||
// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
|
||||
// CHECK: return %[[RESHAPED_VEC]]
|
||||
|
|
Loading…
Reference in New Issue