[mlir][Linalg] NFC - Expose more options to the CodegenStrategy

This commit is contained in:
Nicolas Vasilache 2021-02-19 14:00:18 +00:00
parent b6db47d7e0
commit 62f5c46eec
2 changed files with 65 additions and 27 deletions

View File

@ -66,8 +66,7 @@ void enqueue(OwningRewritePatternList &patternList, OptionsType options,
/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Tile<LinalgOpType>`with the appropriate `options`.
template <typename LinalgOpType>
struct Tile : public Transformation {
template <typename LinalgOpType> struct Tile : public Transformation {
explicit Tile(linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(LinalgOpType::getOperationName()),
@ -93,8 +92,7 @@ private:
/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Promote<LinalgOpType>`with the appropriate `options`.
template <typename LinalgOpType>
struct Promote : public Transformation {
template <typename LinalgOpType> struct Promote : public Transformation {
explicit Promote(
linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
@ -150,6 +148,16 @@ private:
linalg::LinalgVectorizationOptions options;
};
/// Options to control the application of late transformations.
struct LateCodegenStrategyOptions {
bool enableLICM = true;
bool enableHoistRedundantVectorTransfers = true;
bool enableHoistRedundantVectorTransfersOnTensor = true;
bool enableVectorTransferPartialRewrite = true;
bool enableVectorContractLowering = true;
bool enableVectorToSCFConversion = true;
};
/// Codegen strategy controls how a Linalg op is progressively lowered.
/// The application uses a 3-level staged patterns strategy which allows
/// ordering transformations by using the Linalg `applyStagedPatterns`
@ -283,10 +291,32 @@ struct CodegenStrategy {
vectorToSCFOptions = options;
return *this;
}
/// Configure the post staged-patterns late vector.transfer to scf
/// conversion.
CodegenStrategy &setHoistInvariantCode(bool enableLICM) {
this->enableLICM = enableLICM;
///
/// Configure the application of late transformations.
///
CodegenStrategy &setEnableLICM(bool val) {
this->lateCodegenStrategyOptions.enableLICM = val;
return *this;
}
CodegenStrategy &setEnableHoistRedundantVectorTransfers(bool val) {
this->lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers = val;
return *this;
}
CodegenStrategy &setEnableHoistRedundantVectorTransfersOnTensor(bool val) {
this->lateCodegenStrategyOptions
.enableHoistRedundantVectorTransfersOnTensor = val;
return *this;
}
CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
return *this;
}
CodegenStrategy &setEnableVectorContractLowering(bool val) {
this->lateCodegenStrategyOptions.enableVectorContractLowering = val;
return *this;
}
CodegenStrategy &setEnableVectorToSCFConversion(bool val) {
this->lateCodegenStrategyOptions.enableVectorToSCFConversion = val;
return *this;
}
@ -300,7 +330,7 @@ private:
vector::VectorTransformsOptions vectorTransformsOptions;
VectorTransferToSCFOptions vectorToSCFOptions;
SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
bool enableLICM = true;
LateCodegenStrategyOptions lateCodegenStrategyOptions;
};
} // namespace linalg

View File

@ -53,7 +53,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
// Some of these may be too aggressive as a stage 3 that is applied on each
// stage 1 application and may have to be split out to post staged patterns
// application (in which case they could just be passes, TBD).
if (enableLICM) {
if (lateCodegenStrategyOptions.enableLICM) {
op->walk([&](LoopLikeOpInterface loopLike) {
LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n"));
if (failed(moveLoopInvariantCode(loopLike)))
@ -62,8 +62,10 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
}
promoteSingleIterationLoops(cast<FuncOp>(op));
hoistViewAllocOps(cast<FuncOp>(op));
hoistRedundantVectorTransfers(cast<FuncOp>(op));
hoistRedundantVectorTransfersOnTensor(cast<FuncOp>(op));
if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers)
hoistRedundantVectorTransfers(cast<FuncOp>(op));
if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfersOnTensor)
hoistRedundantVectorTransfersOnTensor(cast<FuncOp>(op));
return success();
};
(void)linalg::applyStagedPatterns(
@ -74,25 +76,31 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
//===--------------------------------------------------------------------===//
// Programmatic splitting of slow/fast path vector transfers.
OwningRewritePatternList patterns;
patterns.insert<vector::VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
if (lateCodegenStrategyOptions.enableVectorTransferPartialRewrite) {
OwningRewritePatternList patterns;
patterns.insert<vector::VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
// Programmatic controlled lowering of vector.contract only.
OwningRewritePatternList vectorContractLoweringPatterns;
vectorContractLoweringPatterns
.insert<ContractionOpToOuterProductOpLowering,
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
vectorTransformsOptions, context);
(void)applyPatternsAndFoldGreedily(func,
std::move(vectorContractLoweringPatterns));
if (lateCodegenStrategyOptions.enableVectorContractLowering) {
OwningRewritePatternList vectorContractLoweringPatterns;
vectorContractLoweringPatterns
.insert<ContractionOpToOuterProductOpLowering,
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
vectorTransformsOptions, context);
(void)applyPatternsAndFoldGreedily(
func, std::move(vectorContractLoweringPatterns));
}
// Programmatic controlled lowering of vector.transfer only.
OwningRewritePatternList vectorToLoopsPatterns;
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
vectorToSCFOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
if (lateCodegenStrategyOptions.enableVectorToSCFConversion) {
OwningRewritePatternList vectorToLoopsPatterns;
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
vectorToSCFOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
}
// Ensure we drop the marker in the end.
func.walk([](LinalgOp op) {