forked from OSchip/llvm-project
[mlir][linalg] Add generalization to CodegenStrategy.
Add a generalization pass and integrate it with CodegenStrategy. This patch depends on https://reviews.llvm.org/D110728. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D110746
This commit is contained in:
parent
4e9dbee1a3
commit
1ebd197bc5
|
@ -95,6 +95,12 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPromotePass(
|
|||
linalg::LinalgTransformationFilter filter =
|
||||
linalg::LinalgTransformationFilter());
|
||||
|
||||
/// Create a LinalgStrategyGeneralizePass.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createLinalgStrategyGeneralizePass(StringRef opName = "",
|
||||
linalg::LinalgTransformationFilter filter =
|
||||
linalg::LinalgTransformationFilter());
|
||||
|
||||
/// Create a LinalgStrategyVectorizePass.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createLinalgStrategyVectorizePass(StringRef opName = "",
|
||||
|
|
|
@ -255,6 +255,19 @@ def LinalgStrategyPromotePass
|
|||
];
|
||||
}
|
||||
|
||||
def LinalgStrategyGeneralizePass
|
||||
: FunctionPass<"linalg-strategy-generalize-pass"> {
|
||||
let summary = "Configurable pass to apply pattern-based generalization.";
|
||||
let constructor = "mlir::createLinalgStrategyGeneralizePass()";
|
||||
let dependentDialects = ["linalg::LinalgDialect"];
|
||||
let options = [
|
||||
Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
|
||||
"Which func op is the anchor to latch on.">,
|
||||
Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
|
||||
"Which linalg op within the func is the anchor to latch on.">,
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgStrategyVectorizePass
|
||||
: FunctionPass<"linalg-strategy-vectorize-pass"> {
|
||||
let summary = "Configurable pass to apply pattern-based linalg vectorization.";
|
||||
|
|
|
@ -62,6 +62,21 @@ private:
|
|||
linalg::LinalgPromotionOptions options;
|
||||
};
|
||||
|
||||
/// Represent one application of createLinalgStrategyGeneralizePass.
|
||||
struct Generalize : public Transformation {
|
||||
explicit Generalize(StringRef name,
|
||||
LinalgTransformationFilter::FilterFunction f = nullptr)
|
||||
: Transformation(f), opName(name) {}
|
||||
|
||||
void addToPassPipeline(OpPassManager &pm,
|
||||
LinalgTransformationFilter m) const override {
|
||||
pm.addPass(createLinalgStrategyGeneralizePass(opName, m));
|
||||
}
|
||||
|
||||
private:
|
||||
std::string opName;
|
||||
};
|
||||
|
||||
/// Represent one application of createLinalgStrategyVectorizePass.
|
||||
struct Vectorize : public Transformation {
|
||||
explicit Vectorize(linalg::LinalgVectorizationOptions options,
|
||||
|
@ -117,6 +132,21 @@ struct CodegenStrategy {
|
|||
return b ? promote(opName, options, f) : *this;
|
||||
return *this;
|
||||
}
|
||||
/// Append a pattern to generalize named operations.
|
||||
CodegenStrategy &
|
||||
generalize(StringRef opName,
|
||||
LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
transformationSequence.emplace_back(
|
||||
std::make_unique<Generalize>(opName, f));
|
||||
return *this;
|
||||
}
|
||||
/// Conditionally append a pattern to generalize named operations.
|
||||
CodegenStrategy &
|
||||
generalizeIf(bool b, StringRef opName,
|
||||
LinalgTransformationFilter::FilterFunction f = nullptr) {
|
||||
return b ? generalize(opName, f) : *this;
|
||||
return *this;
|
||||
}
|
||||
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
|
||||
CodegenStrategy &
|
||||
vectorize(StringRef opName,
|
||||
|
|
|
@ -68,6 +68,39 @@ struct LinalgStrategyTilePass
|
|||
LinalgTransformationFilter filter;
|
||||
};
|
||||
|
||||
/// Configurable pass to apply pattern-based linalg generalization.
|
||||
struct LinalgStrategyGeneralizePass
|
||||
: public LinalgStrategyGeneralizePassBase<LinalgStrategyGeneralizePass> {
|
||||
|
||||
LinalgStrategyGeneralizePass() = default;
|
||||
|
||||
LinalgStrategyGeneralizePass(StringRef opName,
|
||||
LinalgTransformationFilter filter)
|
||||
: filter(filter) {
|
||||
this->anchorOpName.setValue(opName.str());
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
auto funcOp = getFunction();
|
||||
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
|
||||
return;
|
||||
|
||||
RewritePatternSet generalizationPattern(funcOp.getContext());
|
||||
if (!anchorOpName.empty()) {
|
||||
generalizationPattern.add<LinalgGeneralizationPattern>(
|
||||
anchorOpName, funcOp.getContext(), filter);
|
||||
} else {
|
||||
generalizationPattern.add<LinalgGeneralizationPattern>(
|
||||
funcOp.getContext(), filter);
|
||||
}
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp,
|
||||
std::move(generalizationPattern))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
LinalgTransformationFilter filter;
|
||||
};
|
||||
|
||||
/// Configurable pass to apply pattern-based linalg promotion.
|
||||
struct LinalgStrategyPromotePass
|
||||
: public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
|
||||
|
@ -233,6 +266,13 @@ mlir::createLinalgStrategyPromotePass(StringRef opName,
|
|||
return std::make_unique<LinalgStrategyPromotePass>(opName, opt, filter);
|
||||
}
|
||||
|
||||
/// Create a LinalgStrategyGeneralizePass.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::createLinalgStrategyGeneralizePass(StringRef opName,
|
||||
LinalgTransformationFilter filter) {
|
||||
return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
|
||||
}
|
||||
|
||||
/// Create a LinalgStrategyVectorizePass.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::createLinalgStrategyVectorizePass(StringRef opName,
|
||||
|
|
|
@ -4,9 +4,12 @@
|
|||
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
|
||||
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
|
||||
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize" | FileCheck %s --check-prefix=GENER
|
||||
|
||||
|
||||
// CHECK-LABEL: func @matmul(
|
||||
// OUTER-LABEL: func @matmul(
|
||||
// GENER-LABEL: func @matmul(
|
||||
func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
|
||||
linalg.matmul
|
||||
ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
|
||||
|
@ -17,6 +20,7 @@ func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<15
|
|||
// CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
|
||||
|
||||
// OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
|
||||
// GENER: linalg.generic
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -86,6 +86,9 @@ struct TestLinalgCodegenStrategy
|
|||
*this, "register-promote-full-tile-pad",
|
||||
llvm::cl::desc("Pad the small aligned memory buffer to the tile sizes."),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> generalize{*this, "generalize",
|
||||
llvm::cl::desc("Generalize named operations."),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> vectorize{
|
||||
*this, "vectorize",
|
||||
llvm::cl::desc("Rewrite the linalg op as a vector operation."),
|
||||
|
@ -133,6 +136,7 @@ void TestLinalgCodegenStrategy::runStrategy(
|
|||
vector::VectorTransferSplit vectorTransferSplit) {
|
||||
assert(!anchorOpName.empty());
|
||||
CodegenStrategy strategy;
|
||||
StringRef genericOpName = GenericOp::getOperationName();
|
||||
strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions)
|
||||
.promoteIf(promote, anchorOpName,
|
||||
LinalgPromotionOptions()
|
||||
|
@ -143,7 +147,8 @@ void TestLinalgCodegenStrategy::runStrategy(
|
|||
LinalgPromotionOptions()
|
||||
.setAlignment(16)
|
||||
.setUseFullTileBuffersByDefault(registerPromoteFullTile))
|
||||
.vectorizeIf(vectorize, anchorOpName)
|
||||
.generalizeIf(generalize, anchorOpName)
|
||||
.vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
|
||||
.setEnableVectorTransferPartialRewrite(true)
|
||||
.setEnableVectorContractLowering(true)
|
||||
.setEnableVectorToSCFConversion(true)
|
||||
|
|
Loading…
Reference in New Issue