Add patterns to lower vector.multi_reduction into a sequence of vector.reduction

Three patterns are added to convert into vector.multi_reduction into a
sequence of vector.reduction as the following:

- Transpose the inputs so inner most dimensions are always reduction.
- Reduce rank of vector.multi_reduction into 2d with inner most
reduction dim (get the 2d canical form)
- 2D canonical form is converted into a sequence of vector.reduction.

There are two things we might worth in a follow up diff:

- An scf.for (maybe optionally) around vector.reduction instead of unrolling it.
- Breakdown the vector.reduction into a sequence of vector.reduction
(e.g tree-based reduction) instead of relying on how downstream dialects
handle it.
  Note: this will requires passing target-vector-length

Differential Revision: https://reviews.llvm.org/D101570
This commit is contained in:
Ahmed Taei 2021-04-29 14:05:23 -07:00
parent 6e6ae6c727
commit 499e89fc91
4 changed files with 287 additions and 0 deletions

View File

@ -92,6 +92,10 @@ void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);
// Collect a set of patterns to convert vector.multi_reduction op into
// a sequence of vector.reduction ops.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr

View File

@ -3575,6 +3575,198 @@ private:
const bool enableIndexOptimizations;
};
// Converts vector.multi_reduction into inner-most reduction form by inserting
// vector.transpose
struct InnerDimReductionConversion
: public OpRewritePattern<vector::MultiDimReductionOp> {
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
auto src = multiReductionOp.source();
auto loc = multiReductionOp.getLoc();
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
auto reductionDims = llvm::to_vector<4>(
llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
[](Attribute attr) -> int64_t {
return attr.cast<IntegerAttr>().getInt();
}));
llvm::sort(reductionDims);
int64_t reductionSize = multiReductionOp.reduction_dims().size();
// Fails if already inner most reduction.
bool innerMostReduction = true;
for (int i = 0; i < reductionSize; ++i) {
if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
innerMostReduction = false;
}
}
if (innerMostReduction)
return failure();
// Permutes the indices so reduction dims are inner most dims.
SmallVector<int64_t> indices;
for (int i = 0; i < srcRank; ++i) {
indices.push_back(i);
}
int ir = reductionSize - 1;
int id = srcRank - 1;
while (ir >= 0) {
std::swap(indices[reductionDims[ir--]], indices[id--]);
}
// Sets inner most dims as reduction.
SmallVector<bool> reductionMask(srcRank, false);
for (int i = 0; i < reductionSize; ++i) {
reductionMask[srcRank - i - 1] = true;
}
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
multiReductionOp, transposeOp.result(), reductionMask,
multiReductionOp.kind());
return success();
}
};
// Reduces the rank of vector.mult_reduction nd -> 2d given all reduction
// dimensions are inner most.
struct ReduceMultiDimReductionRank
: public OpRewritePattern<vector::MultiDimReductionOp> {
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
auto srcShape = multiReductionOp.getSourceVectorType().getShape();
if (srcRank == 2)
return failure();
auto loc = multiReductionOp.getLoc();
auto reductionDims = llvm::to_vector<4>(
llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
[](Attribute attr) -> int64_t {
return attr.cast<IntegerAttr>().getInt();
}));
llvm::sort(reductionDims);
// Fails if not inner most reduction.
int64_t reductionSize = reductionDims.size();
bool innerMostReduction = true;
for (int i = 0; i < reductionSize; ++i) {
if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
innerMostReduction = false;
}
}
if (!innerMostReduction)
return failure();
// Extracts 2d rank reduction shape.
int innerDims = 1;
int outterDims = 1;
SmallVector<int64_t> innerDimsShape;
for (int i = 0; i < srcRank; ++i) {
if (i < (srcRank - reductionSize)) {
innerDims *= srcShape[i];
innerDimsShape.push_back(srcShape[i]);
} else {
outterDims *= srcShape[i];
}
}
// Creates shape cast for the inputs n_d -> 2d
auto castedType = VectorType::get(
{innerDims, outterDims},
multiReductionOp.getSourceVectorType().getElementType());
auto castedOp = rewriter.create<vector::ShapeCastOp>(
loc, castedType, multiReductionOp.source());
// Creates the canonical form of 2d vector.multi_reduction with inner most
// dim as reduction.
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
loc, castedOp.result(), ArrayRef<bool>{false, true},
multiReductionOp.kind());
// Creates shape cast for the output 2d -> nd
auto outputCastedType = VectorType::get(
innerDimsShape,
multiReductionOp.getSourceVectorType().getElementType());
Value castedOutputOp = rewriter.create<vector::ShapeCastOp>(
loc, outputCastedType, newOp.dest());
rewriter.replaceOp(multiReductionOp, castedOutputOp);
return success();
}
};
// Converts 2d vector.multi_reduction with inner most reduction dimension into a
// sequence of vector.reduction ops.
struct TwoDimMultiReductionToReduction
: public OpRewritePattern<vector::MultiDimReductionOp> {
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
if (srcRank != 2)
return failure();
if (multiReductionOp.getReductionMask()[0] ||
!multiReductionOp.getReductionMask()[1])
return failure();
auto loc = multiReductionOp.getLoc();
Value result =
multiReductionOp.getDestVectorType().getElementType().isIntOrIndex()
? rewriter.create<ConstantOp>(
loc, multiReductionOp.getDestVectorType(),
DenseElementsAttr::get(multiReductionOp.getDestVectorType(),
0))
: rewriter.create<ConstantOp>(
loc, multiReductionOp.getDestVectorType(),
DenseElementsAttr::get(multiReductionOp.getDestVectorType(),
0.0f));
int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
// TODO: Add vector::CombiningKind attribute instead of string to
// vector.reduction.
auto getKindStr = [](vector::CombiningKind kind) {
switch (kind) {
case vector::CombiningKind::ADD:
return "add";
case vector::CombiningKind::MUL:
return "mul";
case vector::CombiningKind::MIN:
return "min";
case vector::CombiningKind::MAX:
return "max";
case vector::CombiningKind::AND:
return "and";
case vector::CombiningKind::OR:
return "or";
case vector::CombiningKind::XOR:
return "xor";
}
};
for (int i = 0; i < outerDim; ++i) {
auto v = rewriter.create<vector::ExtractOp>(
loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
auto reducedValue = rewriter.create<vector::ReductionOp>(
loc, multiReductionOp.getDestVectorType().getElementType(),
rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
ValueRange{});
result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
result, i);
}
rewriter.replaceOp(multiReductionOp, result);
return success();
}
};
void mlir::vector::populateVectorMaskMaterializationPatterns(
RewritePatternSet &patterns, bool enableIndexOptimizations) {
patterns.add<VectorCreateMaskOpConversion,
@ -3645,3 +3837,9 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
TransferReadPermutationLowering, TransferOpReduceRank>(
patterns.getContext());
}
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns) {
patterns.add<InnerDimReductionConversion, ReduceMultiDimReductionRank,
TwoDimMultiReductionToReduction>(patterns.getContext());
}

View File

@ -0,0 +1,66 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: func @vector_multi_reduction
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
// CHECK: %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<2xf32>
// CHECK: %[[C0:.+]] = constant 0 : i32
// CHECK: %[[C1:.+]] = constant 1 : i32
// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
// CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<4xf32> into f32
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
// CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<4xf32> into f32
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]]
func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
}
// CHECK-LABEL: func @vector_reduction_inner
// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>
// CHECK: %[[FLAT_RESULT_VEC_0:.+]] = constant dense<0> : vector<6xi32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : i32
// CHECK-DAG: %[[C1:.+]] = constant 1 : i32
// CHECK-DAG: %[[C2:.+]] = constant 2 : i32
// CHECK-DAG: %[[C3:.+]] = constant 3 : i32
// CHECK-DAG: %[[C4:.+]] = constant 4 : i32
// CHECK-DAG: %[[C5:.+]] = constant 5 : i32
// CHECK: %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32>
// CHECK: %[[V0R:.+]] = vector.reduction "add", %[[V0]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : i32] : vector<6xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32>
// CHECK: %[[V1R:.+]] = vector.reduction "add", %[[V1]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : i32] : vector<6xi32>
// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32>
// CHECK: %[[V2R:.+]] = vector.reduction "add", %[[V2]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : i32] : vector<6xi32>
// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32>
// CHECK: %[[V3R:.+]] = vector.reduction "add", %[[V3]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : i32] : vector<6xi32>
// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32>
// CHECK: %[[V4R:.+]] = vector.reduction "add", %[[V4]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : i32] : vector<6xi32>
/// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32>
// CHECK: %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32
// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
// CHECK: return %[[RESULT]]
func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
return %0 : vector<2x5xf32>
}
// CHECK-LABEL: func @vector_multi_reduction_transposed
// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32>
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
// CHEKC: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
// CHECK: return %[[RESULT]]

View File

@ -376,6 +376,19 @@ struct TestVectorTransferLoweringPatterns
}
};
struct TestVectorMultiReductionLoweringPatterns
: public PassWrapper<TestVectorMultiReductionLoweringPatterns,
FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
populateVectorMultiReductionLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
struct TestProgressiveVectorToSCFLoweringPatterns
: public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns,
FunctionPass> {
@ -439,6 +452,12 @@ void registerTestVectorConversions() {
PassRegistration<TestProgressiveVectorToSCFLoweringPatterns> transferOpToSCF(
"test-progressive-convert-vector-to-scf",
"Test conversion patterns to progressively lower transfer ops to SCF");
PassRegistration<TestVectorMultiReductionLoweringPatterns>
multiDimReductionOpLoweringPass(
"test-vector-multi-reduction-lowering-patterns",
"Test conversion patterns to lower vector.multi_reduction to other "
"vector ops");
}
} // namespace test
} // namespace mlir