[mlir][linalg] Add loop interchange to CodegenStrategy.

Add a loop interchange pass and integrate it with CodegenStrategy.

This patch depends on https://reviews.llvm.org/D110728 and https://reviews.llvm.org/D110746.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110748
This commit is contained in:
Tobias Gysi 2021-10-08 06:39:22 +00:00
parent b84d9d299e
commit 23800b05be
6 changed files with 96 additions and 3 deletions

View File

@ -101,6 +101,12 @@ createLinalgStrategyGeneralizePass(StringRef opName = "",
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter());
/// Create a LinalgStrategyInterchangePass.
std::unique_ptr<OperationPass<FuncOp>>
createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange = {},
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter());
/// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<FuncOp>>
createLinalgStrategyVectorizePass(StringRef opName = "",

View File

@ -268,6 +268,17 @@ def LinalgStrategyGeneralizePass
];
}
def LinalgStrategyInterchangePass
: FunctionPass<"linalg-strategy-interchange-pass"> {
let summary = "Configurable pass to apply pattern-based iterator interchange.";
let constructor = "mlir::createLinalgStrategyInterchangePass()";
let dependentDialects = ["linalg::LinalgDialect"];
let options = [
Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
"Which func op is the anchor to latch on.">,
];
}
def LinalgStrategyVectorizePass
: FunctionPass<"linalg-strategy-vectorize-pass"> {
let summary = "Configurable pass to apply pattern-based linalg vectorization.";

View File

@ -77,6 +77,22 @@ private:
std::string opName;
};
/// Represent one application of createLinalgStrategyInterchangePass.
struct Interchange : public Transformation {
explicit Interchange(ArrayRef<int64_t> iteratorInterchange,
LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), iteratorInterchange(iteratorInterchange.begin(),
iteratorInterchange.end()) {}
void addToPassPipeline(OpPassManager &pm,
LinalgTransformationFilter m) const override {
pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m));
}
private:
SmallVector<int64_t> iteratorInterchange;
};
/// Represent one application of createLinalgStrategyVectorizePass.
struct Vectorize : public Transformation {
explicit Vectorize(linalg::LinalgVectorizationOptions options,
@ -147,6 +163,21 @@ struct CodegenStrategy {
return b ? generalize(opName, f) : *this;
return *this;
}
/// Append a pattern to interchange iterators.
CodegenStrategy &
interchange(ArrayRef<int64_t> iteratorInterchange,
LinalgTransformationFilter::FilterFunction f = nullptr) {
transformationSequence.emplace_back(
std::make_unique<Interchange>(iteratorInterchange, f));
return *this;
}
/// Conditionally append a pattern to interchange iterators.
CodegenStrategy &
interchangeIf(bool b, ArrayRef<int64_t> iteratorInterchange,
LinalgTransformationFilter::FilterFunction f = nullptr) {
return b ? interchange(iteratorInterchange, f) : *this;
return *this;
}
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
CodegenStrategy &
vectorize(StringRef opName,

View File

@ -101,6 +101,37 @@ struct LinalgStrategyGeneralizePass
LinalgTransformationFilter filter;
};
/// Configurable pass to apply pattern-based linalg generalization.
struct LinalgStrategyInterchangePass
: public LinalgStrategyInterchangePassBase<LinalgStrategyInterchangePass> {
LinalgStrategyInterchangePass() = default;
LinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
LinalgTransformationFilter filter)
: iteratorInterchange(iteratorInterchange.begin(),
iteratorInterchange.end()),
filter(filter) {}
void runOnFunction() override {
auto funcOp = getFunction();
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
return;
SmallVector<unsigned> interchangeVector(iteratorInterchange.begin(),
iteratorInterchange.end());
RewritePatternSet interchangePattern(funcOp.getContext());
interchangePattern.add<GenericOpInterchangePattern>(
funcOp.getContext(), interchangeVector, filter);
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(interchangePattern))))
signalPassFailure();
}
SmallVector<int64_t> iteratorInterchange;
LinalgTransformationFilter filter;
};
/// Configurable pass to apply pattern-based linalg promotion.
struct LinalgStrategyPromotePass
: public LinalgStrategyPromotePassBase<LinalgStrategyPromotePass> {
@ -273,6 +304,14 @@ mlir::createLinalgStrategyGeneralizePass(StringRef opName,
return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
}
/// Create a LinalgStrategyInterchangePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyInterchangePass(ArrayRef<int64_t> iteratorInterchange,
LinalgTransformationFilter filter) {
return std::make_unique<LinalgStrategyInterchangePass>(iteratorInterchange,
filter);
}
/// Create a LinalgStrategyVectorizePass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgStrategyVectorizePass(StringRef opName,

View File

@ -4,7 +4,7 @@
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" | FileCheck %s
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" | FileCheck %s --check-prefix=OUTER
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize" | FileCheck %s --check-prefix=GENER
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 generalize iterator-interchange=0,2,1" | FileCheck %s --check-prefix=GENER
// CHECK-LABEL: func @matmul(
@ -19,8 +19,10 @@ func @matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<15
// CHECK-SAME: {lhs_columns = 8 : i32, lhs_rows = 2 : i32, rhs_columns = 4 : i32}
// CHECK-SAME: (vector<16xf32>, vector<32xf32>) -> vector<8xf32>
// OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
// GENER: linalg.generic
// OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
// GENER: linalg.generic
// GENER-SAME: iterator_types = ["parallel", "reduction", "parallel"]
return
}

View File

@ -89,6 +89,9 @@ struct TestLinalgCodegenStrategy
Option<bool> generalize{*this, "generalize",
llvm::cl::desc("Generalize named operations."),
llvm::cl::init(false)};
ListOption<int64_t> iteratorInterchange{
*this, "iterator-interchange", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Specifies the iterator interchange.")};
Option<bool> vectorize{
*this, "vectorize",
llvm::cl::desc("Rewrite the linalg op as a vector operation."),
@ -148,6 +151,7 @@ void TestLinalgCodegenStrategy::runStrategy(
.setAlignment(16)
.setUseFullTileBuffersByDefault(registerPromoteFullTile))
.generalizeIf(generalize, anchorOpName)
.interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
.vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
.setEnableVectorTransferPartialRewrite(true)
.setEnableVectorContractLowering(true)