forked from OSchip/llvm-project
Add patterns to lower vector.multi_reduction into a sequence of vector.reduction
Three patterns are added to convert vector.multi_reduction into a sequence of vector.reduction ops, as follows: - Transpose the inputs so that the inner-most dimensions are always the reduction dimensions. - Reduce the rank of vector.multi_reduction to 2-D with the inner-most dimension as the reduction dimension (the 2-D canonical form). - Convert the 2-D canonical form into a sequence of vector.reduction ops. There are two things that might be worth doing in a follow-up diff: - Emit an scf.for (maybe optionally) around vector.reduction instead of unrolling it. - Break down the vector.reduction into a sequence of vector.reduction ops (e.g. tree-based reduction) instead of relying on how downstream dialects handle it. Note: this will require passing the target vector length. Differential Revision: https://reviews.llvm.org/D101570
This commit is contained in:
parent
6e6ae6c727
commit
499e89fc91
|
@ -92,6 +92,10 @@ void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
|
|||
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
|
||||
bool enableIndexOptimizations);
|
||||
|
||||
// Collect a set of patterns to convert vector.multi_reduction op into
|
||||
// a sequence of vector.reduction ops.
|
||||
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// An attribute that specifies the combining function for `vector.contract`,
|
||||
/// and `vector.reduction`.
|
||||
class CombiningKindAttr
|
||||
|
|
|
@ -3575,6 +3575,198 @@ private:
|
|||
const bool enableIndexOptimizations;
|
||||
};
|
||||
|
||||
// Converts vector.multi_reduction into inner-most reduction form by inserting
|
||||
// vector.transpose
|
||||
struct InnerDimReductionConversion
|
||||
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
||||
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto src = multiReductionOp.source();
|
||||
auto loc = multiReductionOp.getLoc();
|
||||
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
||||
|
||||
auto reductionDims = llvm::to_vector<4>(
|
||||
llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
|
||||
[](Attribute attr) -> int64_t {
|
||||
return attr.cast<IntegerAttr>().getInt();
|
||||
}));
|
||||
llvm::sort(reductionDims);
|
||||
|
||||
int64_t reductionSize = multiReductionOp.reduction_dims().size();
|
||||
|
||||
// Fails if already inner most reduction.
|
||||
bool innerMostReduction = true;
|
||||
for (int i = 0; i < reductionSize; ++i) {
|
||||
if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
|
||||
innerMostReduction = false;
|
||||
}
|
||||
}
|
||||
if (innerMostReduction)
|
||||
return failure();
|
||||
|
||||
// Permutes the indices so reduction dims are inner most dims.
|
||||
SmallVector<int64_t> indices;
|
||||
for (int i = 0; i < srcRank; ++i) {
|
||||
indices.push_back(i);
|
||||
}
|
||||
int ir = reductionSize - 1;
|
||||
int id = srcRank - 1;
|
||||
while (ir >= 0) {
|
||||
std::swap(indices[reductionDims[ir--]], indices[id--]);
|
||||
}
|
||||
|
||||
// Sets inner most dims as reduction.
|
||||
SmallVector<bool> reductionMask(srcRank, false);
|
||||
for (int i = 0; i < reductionSize; ++i) {
|
||||
reductionMask[srcRank - i - 1] = true;
|
||||
}
|
||||
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
|
||||
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
|
||||
multiReductionOp, transposeOp.result(), reductionMask,
|
||||
multiReductionOp.kind());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Reduces the rank of vector.mult_reduction nd -> 2d given all reduction
|
||||
// dimensions are inner most.
|
||||
struct ReduceMultiDimReductionRank
|
||||
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
||||
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
||||
auto srcShape = multiReductionOp.getSourceVectorType().getShape();
|
||||
if (srcRank == 2)
|
||||
return failure();
|
||||
|
||||
auto loc = multiReductionOp.getLoc();
|
||||
auto reductionDims = llvm::to_vector<4>(
|
||||
llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
|
||||
[](Attribute attr) -> int64_t {
|
||||
return attr.cast<IntegerAttr>().getInt();
|
||||
}));
|
||||
llvm::sort(reductionDims);
|
||||
|
||||
// Fails if not inner most reduction.
|
||||
int64_t reductionSize = reductionDims.size();
|
||||
bool innerMostReduction = true;
|
||||
for (int i = 0; i < reductionSize; ++i) {
|
||||
if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
|
||||
innerMostReduction = false;
|
||||
}
|
||||
}
|
||||
if (!innerMostReduction)
|
||||
return failure();
|
||||
|
||||
// Extracts 2d rank reduction shape.
|
||||
int innerDims = 1;
|
||||
int outterDims = 1;
|
||||
SmallVector<int64_t> innerDimsShape;
|
||||
for (int i = 0; i < srcRank; ++i) {
|
||||
if (i < (srcRank - reductionSize)) {
|
||||
innerDims *= srcShape[i];
|
||||
innerDimsShape.push_back(srcShape[i]);
|
||||
} else {
|
||||
outterDims *= srcShape[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Creates shape cast for the inputs n_d -> 2d
|
||||
auto castedType = VectorType::get(
|
||||
{innerDims, outterDims},
|
||||
multiReductionOp.getSourceVectorType().getElementType());
|
||||
auto castedOp = rewriter.create<vector::ShapeCastOp>(
|
||||
loc, castedType, multiReductionOp.source());
|
||||
|
||||
// Creates the canonical form of 2d vector.multi_reduction with inner most
|
||||
// dim as reduction.
|
||||
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
|
||||
loc, castedOp.result(), ArrayRef<bool>{false, true},
|
||||
multiReductionOp.kind());
|
||||
|
||||
// Creates shape cast for the output 2d -> nd
|
||||
auto outputCastedType = VectorType::get(
|
||||
innerDimsShape,
|
||||
multiReductionOp.getSourceVectorType().getElementType());
|
||||
Value castedOutputOp = rewriter.create<vector::ShapeCastOp>(
|
||||
loc, outputCastedType, newOp.dest());
|
||||
|
||||
rewriter.replaceOp(multiReductionOp, castedOutputOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Converts 2d vector.multi_reduction with inner most reduction dimension into a
|
||||
// sequence of vector.reduction ops.
|
||||
struct TwoDimMultiReductionToReduction
|
||||
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
||||
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
||||
if (srcRank != 2)
|
||||
return failure();
|
||||
|
||||
if (multiReductionOp.getReductionMask()[0] ||
|
||||
!multiReductionOp.getReductionMask()[1])
|
||||
return failure();
|
||||
|
||||
auto loc = multiReductionOp.getLoc();
|
||||
|
||||
Value result =
|
||||
multiReductionOp.getDestVectorType().getElementType().isIntOrIndex()
|
||||
? rewriter.create<ConstantOp>(
|
||||
loc, multiReductionOp.getDestVectorType(),
|
||||
DenseElementsAttr::get(multiReductionOp.getDestVectorType(),
|
||||
0))
|
||||
: rewriter.create<ConstantOp>(
|
||||
loc, multiReductionOp.getDestVectorType(),
|
||||
DenseElementsAttr::get(multiReductionOp.getDestVectorType(),
|
||||
0.0f));
|
||||
|
||||
int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
|
||||
|
||||
// TODO: Add vector::CombiningKind attribute instead of string to
|
||||
// vector.reduction.
|
||||
auto getKindStr = [](vector::CombiningKind kind) {
|
||||
switch (kind) {
|
||||
case vector::CombiningKind::ADD:
|
||||
return "add";
|
||||
case vector::CombiningKind::MUL:
|
||||
return "mul";
|
||||
case vector::CombiningKind::MIN:
|
||||
return "min";
|
||||
case vector::CombiningKind::MAX:
|
||||
return "max";
|
||||
case vector::CombiningKind::AND:
|
||||
return "and";
|
||||
case vector::CombiningKind::OR:
|
||||
return "or";
|
||||
case vector::CombiningKind::XOR:
|
||||
return "xor";
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0; i < outerDim; ++i) {
|
||||
auto v = rewriter.create<vector::ExtractOp>(
|
||||
loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
|
||||
auto reducedValue = rewriter.create<vector::ReductionOp>(
|
||||
loc, multiReductionOp.getDestVectorType().getElementType(),
|
||||
rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
|
||||
ValueRange{});
|
||||
result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
|
||||
result, i);
|
||||
}
|
||||
rewriter.replaceOp(multiReductionOp, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void mlir::vector::populateVectorMaskMaterializationPatterns(
|
||||
RewritePatternSet &patterns, bool enableIndexOptimizations) {
|
||||
patterns.add<VectorCreateMaskOpConversion,
|
||||
|
@ -3645,3 +3837,9 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
|
|||
TransferReadPermutationLowering, TransferOpReduceRank>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
// Registers the patterns that lower vector.multi_reduction: transpose to
// inner-most reduction form, rank reduction to 2-D, and 2-D lowering to
// vector.reduction ops.
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  patterns.add<InnerDimReductionConversion>(ctx);
  patterns.add<ReduceMultiDimReductionRank>(ctx);
  patterns.add<TwoDimMultiReductionToReduction>(ctx);
}
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
|
||||
|
||||
func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @vector_multi_reduction
|
||||
// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
|
||||
// CHECK: %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<2xf32>
|
||||
// CHECK: %[[C0:.+]] = constant 0 : i32
|
||||
// CHECK: %[[C1:.+]] = constant 1 : i32
|
||||
// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
|
||||
// CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<4xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<2xf32>
|
||||
// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
|
||||
// CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<4xf32> into f32
|
||||
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32>
|
||||
// CHECK: return %[[RESULT_VEC]]
|
||||
|
||||
func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
|
||||
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
|
||||
return %0 : vector<2x3xi32>
|
||||
}
|
||||
// CHECK-LABEL: func @vector_reduction_inner
|
||||
// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>
|
||||
// CHECK: %[[FLAT_RESULT_VEC_0:.+]] = constant dense<0> : vector<6xi32>
|
||||
// CHECK-DAG: %[[C0:.+]] = constant 0 : i32
|
||||
// CHECK-DAG: %[[C1:.+]] = constant 1 : i32
|
||||
// CHECK-DAG: %[[C2:.+]] = constant 2 : i32
|
||||
// CHECK-DAG: %[[C3:.+]] = constant 3 : i32
|
||||
// CHECK-DAG: %[[C4:.+]] = constant 4 : i32
|
||||
// CHECK-DAG: %[[C5:.+]] = constant 5 : i32
|
||||
// CHECK: %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
|
||||
// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32>
|
||||
// CHECK: %[[V0R:.+]] = vector.reduction "add", %[[V0]] : vector<20xi32> into i32
|
||||
// CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : i32] : vector<6xi32>
|
||||
// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32>
|
||||
// CHECK: %[[V1R:.+]] = vector.reduction "add", %[[V1]] : vector<20xi32> into i32
|
||||
// CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : i32] : vector<6xi32>
|
||||
// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32>
|
||||
// CHECK: %[[V2R:.+]] = vector.reduction "add", %[[V2]] : vector<20xi32> into i32
|
||||
// CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : i32] : vector<6xi32>
|
||||
// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32>
|
||||
// CHECK: %[[V3R:.+]] = vector.reduction "add", %[[V3]] : vector<20xi32> into i32
|
||||
// CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : i32] : vector<6xi32>
|
||||
// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32>
|
||||
// CHECK: %[[V4R:.+]] = vector.reduction "add", %[[V4]] : vector<20xi32> into i32
|
||||
// CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : i32] : vector<6xi32>
|
||||
// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32>
|
||||
// CHECK: %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32
|
||||
// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32>
|
||||
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
|
||||
func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
|
||||
%0 = vector.multi_reduction #vector.kind<add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
|
||||
return %0 : vector<2x5xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vector_multi_reduction_transposed
|
||||
// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32>
|
||||
// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
|
||||
// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
|
||||
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
|
||||
// CHECK: return %[[RESULT]]
|
|
@ -376,6 +376,19 @@ struct TestVectorTransferLoweringPatterns
|
|||
}
|
||||
};
|
||||
|
||||
struct TestVectorMultiReductionLoweringPatterns
|
||||
: public PassWrapper<TestVectorMultiReductionLoweringPatterns,
|
||||
FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<memref::MemRefDialect>();
|
||||
}
|
||||
void runOnFunction() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorMultiReductionLoweringPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
struct TestProgressiveVectorToSCFLoweringPatterns
|
||||
: public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns,
|
||||
FunctionPass> {
|
||||
|
@ -439,6 +452,12 @@ void registerTestVectorConversions() {
|
|||
PassRegistration<TestProgressiveVectorToSCFLoweringPatterns> transferOpToSCF(
|
||||
"test-progressive-convert-vector-to-scf",
|
||||
"Test conversion patterns to progressively lower transfer ops to SCF");
|
||||
|
||||
PassRegistration<TestVectorMultiReductionLoweringPatterns>
|
||||
multiDimReductionOpLoweringPass(
|
||||
"test-vector-multi-reduction-lowering-patterns",
|
||||
"Test conversion patterns to lower vector.multi_reduction to other "
|
||||
"vector ops");
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
Loading…
Reference in New Issue