[mlir][linalg] Add generalization to CodegenStrategy.

Add a generalization pass and integrate it with CodegenStrategy.

This patch depends on https://reviews.llvm.org/D110728.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110746
This commit is contained in:
Tobias Gysi 2021-10-08 06:06:12 +00:00
parent 4e9dbee1a3
commit 1ebd197bc5
6 changed files with 99 additions and 1 deletion

View File

@@ -95,6 +95,12 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPromotePass(
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter());
/// Create a LinalgStrategyGeneralizePass.
/// `opName` optionally anchors the pass on a single named linalg op (an empty
/// string matches any linalg op), and `filter` further restricts which ops
/// are rewritten; both default to "match everything".
std::unique_ptr<OperationPass<FuncOp>>
createLinalgStrategyGeneralizePass(StringRef opName = "",
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter());
/// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<FuncOp>>
createLinalgStrategyVectorizePass(StringRef opName = "",

View File

@@ -255,6 +255,19 @@ def LinalgStrategyPromotePass
];
}
// Function pass wrapping pattern-based linalg generalization for use in a
// CodegenStrategy pipeline; constructed via
// mlir::createLinalgStrategyGeneralizePass().
def LinalgStrategyGeneralizePass
: FunctionPass<"linalg-strategy-generalize-pass"> {
let summary = "Configurable pass to apply pattern-based generalization.";
let constructor = "mlir::createLinalgStrategyGeneralizePass()";
// The rewrite creates linalg ops, so the linalg dialect must be loaded.
let dependentDialects = ["linalg::LinalgDialect"];
let options = [
// Anchors restrict the pass to one function and/or one linalg op kind;
// empty values match everything.
Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
"Which func op is the anchor to latch on.">,
Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
"Which linalg op within the func is the anchor to latch on.">,
];
}
def LinalgStrategyVectorizePass
: FunctionPass<"linalg-strategy-vectorize-pass"> {
let summary = "Configurable pass to apply pattern-based linalg vectorization.";

View File

@@ -62,6 +62,21 @@ private:
linalg::LinalgPromotionOptions options;
};
/// Represent one application of createLinalgStrategyGeneralizePass.
struct Generalize : public Transformation {
explicit Generalize(StringRef name,
LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(name) {}
void addToPassPipeline(OpPassManager &pm,
LinalgTransformationFilter m) const override {
pm.addPass(createLinalgStrategyGeneralizePass(opName, m));
}
private:
std::string opName;
};
/// Represent one application of createLinalgStrategyVectorizePass.
struct Vectorize : public Transformation {
explicit Vectorize(linalg::LinalgVectorizationOptions options,
@@ -117,6 +132,21 @@ struct CodegenStrategy {
return b ? promote(opName, options, f) : *this;
return *this;
}
/// Append a pattern to generalize named operations.
CodegenStrategy &
generalize(StringRef opName,
           LinalgTransformationFilter::FilterFunction f = nullptr) {
  // Queue one generalization application; it is materialized as a pass when
  // the strategy is lowered to a pass pipeline.
  auto transform = std::make_unique<Generalize>(opName, f);
  transformationSequence.push_back(std::move(transform));
  return *this;
}
/// Conditionally append a pattern to generalize named operations; returns
/// the strategy unchanged when `b` is false.
CodegenStrategy &
generalizeIf(bool b, StringRef opName,
             LinalgTransformationFilter::FilterFunction f = nullptr) {
  return b ? generalize(opName, f) : *this;
}
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
CodegenStrategy &
vectorize(StringRef opName,

View File

@@ -68,6 +68,39 @@ struct LinalgStrategyTilePass
LinalgTransformationFilter filter;
};
/// Configurable pass to apply pattern-based linalg generalization.
struct LinalgStrategyGeneralizePass
    : public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {

  LinalgStrategyGeneralizePass() = default;

  /// Construct the pass anchored on `opName` and constrained by `filter`.
  LinalgStrategyGeneralizePass(StringRef opName,
                               LinalgTransformationFilter filter)
      : filter(filter) {
    this->anchorOpName.setValue(opName.str());
  }

  void runOnFunction() override {
    auto funcOp = getFunction();
    // Skip functions other than the anchor, when one was requested.
    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
      return;

    // Anchor the generalization pattern on a single op name when provided;
    // otherwise let it match any op the filter accepts.
    RewritePatternSet patterns(funcOp.getContext());
    if (anchorOpName.empty())
      patterns.add<LinalgGeneralizationPattern>(funcOp.getContext(), filter);
    else
      patterns.add<LinalgGeneralizationPattern>(anchorOpName,
                                                funcOp.getContext(), filter);

    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      signalPassFailure();
  }

  /// Filter controlling which linalg ops are generalized.
  LinalgTransformationFilter filter;
};
/// Configurable pass to apply pattern-based linalg promotion.
struct LinalgStrategyPromotePass
: public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
@@ -233,6 +266,13 @@ mlir::createLinalgStrategyPromotePass(StringRef opName,
return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
}
/// Create a LinalgStrategyGeneralizePass anchored on `opName` and
/// constrained by `filter`.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyGeneralizePass(StringRef opName,
                                         LinalgTransformationFilter filter) {
  auto pass = std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
  return pass;
}
/// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyVectorizePass(StringRef opName,

View File

@@ -4,9 +4,12 @@
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize" | FileCheck %s --check-prefix=GENER
// CHECK-LABEL: func @matmul(
// OUTER-LABEL: func @matmul(
// GENER-LABEL: func @matmul(
func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
linalg.matmul
ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
@@ -17,6 +20,7 @@ func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<15
// CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
// OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
// GENER: linalg.generic
return
}

View File

@@ -86,6 +86,9 @@ struct TestLinalgCodegenStrategy
*this, "register-promote-full-tile-pad",
llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
llvm::cl::init(false)};
Option<bool> generalize{*this, "generalize",
llvm::cl::desc("Generalize named operations."),
llvm::cl::init(false)};
Option<bool> vectorize{
*this, "vectorize",
llvm::cl::desc("Rewrite the linalg op as a vector operation."),
@@ -133,6 +136,7 @@ void TestLinalgCodegenStrategy::runStrategy(
vector::VectorTransferSplit vectorTransferSplit) {
assert(!anchorOpName.empty());
CodegenStrategy strategy;
StringRef genericOpName = GenericOp::getOperationName();
strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions)
.promoteIf(promote, anchorOpName,
LinalgPromotionOptions()
@@ -143,7 +147,8 @@
LinalgPromotionOptions()
.setAlignment(16)
.setUseFullTileBuffersByDefault(registerPromoteFullTile))
.vectorizeIf(vectorize, anchorOpName)
.generalizeIf(generalize, anchorOpName)
.vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
.setEnableVectorTransferPartialRewrite(true)
.setEnableVectorContractLowering(true)
.setEnableVectorToSCFConversion(true)