[mlir][Linalg] Introduce a helper function for staged pattern application

Summary:
This revision introduces a helper function to allow applying rewrite patterns, interleaved with more global transformations, in a staged fashion:
1. the first stage consists of an OwningRewritePatternList. The RewritePattern in this list are applied once and in order.
2. the second stage consists of a single OwningRewritePattern that is applied greedily until convergence.
3. the third stage consists of applying a lambda, generally used for non-local transformation effects.

This allows creating custom fused transformations where patterns can be ordered and applied at a finer granularity than a sequence of traditional compiler passes.

A test that exercises these behaviors is added.

Differential Revision: https://reviews.llvm.org/D79518
This commit is contained in:
Nicolas Vasilache 2020-05-11 16:05:39 -04:00
parent cd7cb1f4ce
commit d12d05a731
5 changed files with 154 additions and 4 deletions

View File

@ -367,6 +367,23 @@ private:
LinalgLoweringType loweringType;
};
//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
/// Helper function to allow applying rewrite patterns, interleaved with more
/// global transformations, in a staged fashion:
/// 1. the first stage consists of a list of OwningRewritePatternList. Each
/// OwningRewritePatternList in this list is applied once, in order.
/// 2. the second stage consists of a single OwningRewritePattern that is
/// applied greedily until convergence.
/// 3. the third stage consists of applying a lambda, generally used for
/// non-local transformation effects. This allows creating custom fused
/// transformations where patterns can be ordered and applied at a finer
/// granularity than a sequence of traditional compiler passes.
LogicalResult applyStagedPatterns(
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
const OwningRewritePatternList &stage2Patterns,
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
} // namespace linalg
} // namespace mlir

View File

@ -388,6 +388,15 @@ class OwningRewritePatternList {
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
OwningRewritePatternList() = default;
/// Construct a OwningRewritePatternList populated with the pattern `t` of
/// type `T`.
template <typename T>
OwningRewritePatternList(T &&t) {
patterns.emplace_back(std::make_unique<T>(t));
}
PatternListT::iterator begin() { return patterns.begin(); }
PatternListT::iterator end() { return patterns.end(); }
PatternListT::const_iterator begin() const { return patterns.begin(); }
@ -399,12 +408,13 @@ public:
//===--------------------------------------------------------------------===//
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
/// the given arguments.
/// the given arguments. Return a reference to `this` for chaining insertions.
/// Note: ConstructorArg is necessary here to separate the two variadic lists.
template <typename... Ts, typename ConstructorArg,
typename... ConstructorArgs,
typename = std::enable_if_t<sizeof...(Ts) != 0>>
void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
OwningRewritePatternList &insert(ConstructorArg &&arg,
ConstructorArgs &&... args) {
// The following expands a call to emplace_back for each of the pattern
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
@ -412,6 +422,7 @@ public:
using dummy = int[];
(void)dummy{
0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
return *this;
}
private:

View File

@ -198,3 +198,24 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
rewriter.eraseOp(op);
return success();
}
LogicalResult mlir::linalg::applyStagedPatterns(
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
const OwningRewritePatternList &stage2Patterns,
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
for (const auto &patterns : stage1Patterns) {
if (!applyPatternsAndFoldGreedily(op, patterns)) {
llvm::dbgs() << "Underlying first stage rewrite did not converge";
return failure();
}
if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
llvm::dbgs() << "Underlying second stage rewrite did not converge";
return failure();
}
if (stage3Lambda) {
if (failed(stage3Lambda(op)))
return failure();
}
}
return success();
}

View File

@ -0,0 +1,34 @@
// TODO: this needs a fix to land before being reactivated.
// RUN: ls
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} :
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
return
}
// CHECK-LABEL:func @matmul
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
// CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
//
// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
// CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
//
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
// CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
//
// CHECK: linalg.copy
// CHECK: linalg.copy
// CHECK: linalg.copy
//
// CHECK: vector.contract
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
//
// CHECK: linalg.copy

View File

@ -33,6 +33,18 @@ struct TestLinalgTransforms
Option<bool> testPatterns{*this, "test-patterns",
llvm::cl::desc("Test a mixed set of patterns"),
llvm::cl::init(false)};
Option<bool> testMatmulToVectorPatterns1dTiling{
*this, "test-matmul-to-vector-patterns-tile-1d",
llvm::cl::desc(
"Test a fused pass that applies patterns from matmul to vectors via "
"1-d tiling"),
llvm::cl::init(false)};
Option<bool> testMatmulToVectorPatterns2dTiling{
*this, "test-matmul-to-vector-patterns-tile-2d",
llvm::cl::desc(
"Test a fused pass that applies patterns from matmul to vectors via "
"2-d tiling"),
llvm::cl::init(false)};
};
} // end anonymous namespace
@ -137,10 +149,65 @@ static void applyPatterns(FuncOp funcOp) {
});
}
OwningRewritePatternList
getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) {
OwningRewritePatternList patterns;
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
AffineMinOp::getCanonicalizationPatterns(patterns, context);
AffineMaxOp::getCanonicalizationPatterns(patterns, context);
AllocOp::getCanonicalizationPatterns(patterns, context);
SubViewOp::getCanonicalizationPatterns(patterns, context);
ViewOp::getCanonicalizationPatterns(patterns, context);
MatmulOp::getCanonicalizationPatterns(patterns, context);
return patterns;
}
void fillL1TilingAndMatmulToVectorPatterns(
MLIRContext *context, StringRef startMarker,
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
context,
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
LinalgMarker({startMarker}, "L1")));
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC")));
patternsVector.emplace_back(
LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
patternsVector.back()
.insert<LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<CopyOp>>(context);
}
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
if (testPatterns)
return applyPatterns(getFunction());
if (testPatterns) {
applyPatterns(getFunction());
} else {
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
if (testMatmulToVectorPatterns1dTiling) {
fillL1TilingAndMatmulToVectorPatterns(&getContext(), "START",
stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
stage1Patterns.emplace_back(
LinalgTilingPattern<MatmulOp>(&getContext(),
LinalgTilingOptions()
.setTileSizes({768, 264, 768})
.setInterchange({1, 2, 0}),
LinalgMarker({"START"}, "L2")));
fillL1TilingAndMatmulToVectorPatterns(&getContext(), "L2",
stage1Patterns);
}
OwningRewritePatternList stage2Patterns =
getMatmulToVectorCanonicalizationPatterns(&getContext());
applyStagedPatterns(getFunction(), stage1Patterns, stage2Patterns);
}
// Drop the marker.
getFunction().walk([](LinalgOp op) {
op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
});
}
namespace mlir {