forked from OSchip/llvm-project
[mlir][Linalg] Introduce a helper function for staged pattern application
Summary: This revision introduces a helper function to allow applying rewrite patterns, interleaved with more global transformations, in a staged fashion: 1. the first stage consists of an OwningRewritePatternList. The RewritePattern in this list are applied once and in order. 2. the second stage consists of a single OwningRewritePattern that is applied greedily until convergence. 3. the third stage consists of applying a lambda, generally used for non-local transformation effects. This allows creating custom fused transformations where patterns can be ordered and applied at a finer granularity than a sequence of traditional compiler passes. A test that exercises these behaviors is added. Differential Revision: https://reviews.llvm.org/D79518
This commit is contained in:
parent
cd7cb1f4ce
commit
d12d05a731
|
@ -367,6 +367,23 @@ private:
|
|||
LinalgLoweringType loweringType;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Support for staged pattern application.
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Helper function to allow applying rewrite patterns, interleaved with more
|
||||
/// global transformations, in a staged fashion:
|
||||
/// 1. the first stage consists of a list of OwningRewritePatternList. Each
|
||||
/// OwningRewritePatternList in this list is applied once, in order.
|
||||
/// 2. the second stage consists of a single OwningRewritePattern that is
|
||||
/// applied greedily until convergence.
|
||||
/// 3. the third stage consists of applying a lambda, generally used for
|
||||
/// non-local transformation effects. This allows creating custom fused
|
||||
/// transformations where patterns can be ordered and applied at a finer
|
||||
/// granularity than a sequence of traditional compiler passes.
|
||||
LogicalResult applyStagedPatterns(
|
||||
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
|
||||
const OwningRewritePatternList &stage2Patterns,
|
||||
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -388,6 +388,15 @@ class OwningRewritePatternList {
|
|||
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
|
||||
public:
|
||||
OwningRewritePatternList() = default;
|
||||
|
||||
/// Construct a OwningRewritePatternList populated with the pattern `t` of
|
||||
/// type `T`.
|
||||
template <typename T>
|
||||
OwningRewritePatternList(T &&t) {
|
||||
patterns.emplace_back(std::make_unique<T>(t));
|
||||
}
|
||||
|
||||
PatternListT::iterator begin() { return patterns.begin(); }
|
||||
PatternListT::iterator end() { return patterns.end(); }
|
||||
PatternListT::const_iterator begin() const { return patterns.begin(); }
|
||||
|
@ -399,12 +408,13 @@ public:
|
|||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
|
||||
/// the given arguments.
|
||||
/// the given arguments. Return a reference to `this` for chaining insertions.
|
||||
/// Note: ConstructorArg is necessary here to separate the two variadic lists.
|
||||
template <typename... Ts, typename ConstructorArg,
|
||||
typename... ConstructorArgs,
|
||||
typename = std::enable_if_t<sizeof...(Ts) != 0>>
|
||||
void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
|
||||
OwningRewritePatternList &insert(ConstructorArg &&arg,
|
||||
ConstructorArgs &&... args) {
|
||||
// The following expands a call to emplace_back for each of the pattern
|
||||
// types 'Ts'. This magic is necessary due to a limitation in the places
|
||||
// that a parameter pack can be expanded in c++11.
|
||||
|
@ -412,6 +422,7 @@ public:
|
|||
using dummy = int[];
|
||||
(void)dummy{
|
||||
0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -198,3 +198,24 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
|
|||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::applyStagedPatterns(
|
||||
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
|
||||
const OwningRewritePatternList &stage2Patterns,
|
||||
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
|
||||
for (const auto &patterns : stage1Patterns) {
|
||||
if (!applyPatternsAndFoldGreedily(op, patterns)) {
|
||||
llvm::dbgs() << "Underlying first stage rewrite did not converge";
|
||||
return failure();
|
||||
}
|
||||
if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
|
||||
llvm::dbgs() << "Underlying second stage rewrite did not converge";
|
||||
return failure();
|
||||
}
|
||||
if (stage3Lambda) {
|
||||
if (failed(stage3Lambda(op)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
// TODO: this needs a fix to land before being reactivated.
|
||||
// RUN: ls
|
||||
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
|
||||
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
|
||||
|
||||
func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||
%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
|
||||
linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} :
|
||||
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL:func @matmul
|
||||
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
|
||||
// CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
|
||||
//
|
||||
// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
|
||||
// CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
|
||||
//
|
||||
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
|
||||
// CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
|
||||
//
|
||||
// CHECK: linalg.copy
|
||||
// CHECK: linalg.copy
|
||||
// CHECK: linalg.copy
|
||||
//
|
||||
// CHECK: vector.contract
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
|
||||
// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
|
||||
//
|
||||
// CHECK: linalg.copy
|
|
@ -33,6 +33,18 @@ struct TestLinalgTransforms
|
|||
Option<bool> testPatterns{*this, "test-patterns",
|
||||
llvm::cl::desc("Test a mixed set of patterns"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testMatmulToVectorPatterns1dTiling{
|
||||
*this, "test-matmul-to-vector-patterns-tile-1d",
|
||||
llvm::cl::desc(
|
||||
"Test a fused pass that applies patterns from matmul to vectors via "
|
||||
"1-d tiling"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testMatmulToVectorPatterns2dTiling{
|
||||
*this, "test-matmul-to-vector-patterns-tile-2d",
|
||||
llvm::cl::desc(
|
||||
"Test a fused pass that applies patterns from matmul to vectors via "
|
||||
"2-d tiling"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -137,10 +149,65 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
});
|
||||
}
|
||||
|
||||
OwningRewritePatternList
|
||||
getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) {
|
||||
OwningRewritePatternList patterns;
|
||||
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
|
||||
AffineMinOp::getCanonicalizationPatterns(patterns, context);
|
||||
AffineMaxOp::getCanonicalizationPatterns(patterns, context);
|
||||
AllocOp::getCanonicalizationPatterns(patterns, context);
|
||||
SubViewOp::getCanonicalizationPatterns(patterns, context);
|
||||
ViewOp::getCanonicalizationPatterns(patterns, context);
|
||||
MatmulOp::getCanonicalizationPatterns(patterns, context);
|
||||
return patterns;
|
||||
}
|
||||
|
||||
void fillL1TilingAndMatmulToVectorPatterns(
|
||||
MLIRContext *context, StringRef startMarker,
|
||||
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
|
||||
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
|
||||
context,
|
||||
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
|
||||
LinalgMarker({startMarker}, "L1")));
|
||||
|
||||
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
|
||||
context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC")));
|
||||
|
||||
patternsVector.emplace_back(
|
||||
LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
|
||||
patternsVector.back()
|
||||
.insert<LinalgVectorizationPattern<FillOp>,
|
||||
LinalgVectorizationPattern<CopyOp>>(context);
|
||||
}
|
||||
|
||||
/// Apply transformations specified as patterns.
|
||||
void TestLinalgTransforms::runOnFunction() {
|
||||
if (testPatterns)
|
||||
return applyPatterns(getFunction());
|
||||
if (testPatterns) {
|
||||
applyPatterns(getFunction());
|
||||
} else {
|
||||
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
|
||||
if (testMatmulToVectorPatterns1dTiling) {
|
||||
fillL1TilingAndMatmulToVectorPatterns(&getContext(), "START",
|
||||
stage1Patterns);
|
||||
} else if (testMatmulToVectorPatterns2dTiling) {
|
||||
stage1Patterns.emplace_back(
|
||||
LinalgTilingPattern<MatmulOp>(&getContext(),
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({768, 264, 768})
|
||||
.setInterchange({1, 2, 0}),
|
||||
LinalgMarker({"START"}, "L2")));
|
||||
fillL1TilingAndMatmulToVectorPatterns(&getContext(), "L2",
|
||||
stage1Patterns);
|
||||
}
|
||||
OwningRewritePatternList stage2Patterns =
|
||||
getMatmulToVectorCanonicalizationPatterns(&getContext());
|
||||
applyStagedPatterns(getFunction(), stage1Patterns, stage2Patterns);
|
||||
}
|
||||
|
||||
// Drop the marker.
|
||||
getFunction().walk([](LinalgOp op) {
|
||||
op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
|
||||
});
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
|
Loading…
Reference in New Issue