forked from OSchip/llvm-project
[mlir][Linalg] Introduce a helper function for staged pattern application
Summary: This revision introduces a helper function to allow applying rewrite patterns, interleaved with more global transformations, in a staged fashion: 1. the first stage consists of an OwningRewritePatternList. The RewritePattern in this list are applied once and in order. 2. the second stage consists of a single OwningRewritePattern that is applied greedily until convergence. 3. the third stage consists of applying a lambda, generally used for non-local transformation effects. This allows creating custom fused transformations where patterns can be ordered and applied at a finer granularity than a sequence of traditional compiler passes. A test that exercises these behaviors is added. Differential Revision: https://reviews.llvm.org/D79518
This commit is contained in:
parent
cd7cb1f4ce
commit
d12d05a731
|
@ -367,6 +367,23 @@ private:
|
||||||
LinalgLoweringType loweringType;
|
LinalgLoweringType loweringType;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Support for staged pattern application.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
/// Helper function to allow applying rewrite patterns, interleaved with more
|
||||||
|
/// global transformations, in a staged fashion:
|
||||||
|
/// 1. the first stage consists of a list of OwningRewritePatternList. Each
|
||||||
|
/// OwningRewritePatternList in this list is applied once, in order.
|
||||||
|
/// 2. the second stage consists of a single OwningRewritePattern that is
|
||||||
|
/// applied greedily until convergence.
|
||||||
|
/// 3. the third stage consists of applying a lambda, generally used for
|
||||||
|
/// non-local transformation effects. This allows creating custom fused
|
||||||
|
/// transformations where patterns can be ordered and applied at a finer
|
||||||
|
/// granularity than a sequence of traditional compiler passes.
|
||||||
|
LogicalResult applyStagedPatterns(
|
||||||
|
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
|
||||||
|
const OwningRewritePatternList &stage2Patterns,
|
||||||
|
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
|
||||||
} // namespace linalg
|
} // namespace linalg
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -388,6 +388,15 @@ class OwningRewritePatternList {
|
||||||
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
|
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
OwningRewritePatternList() = default;
|
||||||
|
|
||||||
|
/// Construct a OwningRewritePatternList populated with the pattern `t` of
|
||||||
|
/// type `T`.
|
||||||
|
template <typename T>
|
||||||
|
OwningRewritePatternList(T &&t) {
|
||||||
|
patterns.emplace_back(std::make_unique<T>(t));
|
||||||
|
}
|
||||||
|
|
||||||
PatternListT::iterator begin() { return patterns.begin(); }
|
PatternListT::iterator begin() { return patterns.begin(); }
|
||||||
PatternListT::iterator end() { return patterns.end(); }
|
PatternListT::iterator end() { return patterns.end(); }
|
||||||
PatternListT::const_iterator begin() const { return patterns.begin(); }
|
PatternListT::const_iterator begin() const { return patterns.begin(); }
|
||||||
|
@ -399,12 +408,13 @@ public:
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
|
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
|
||||||
/// the given arguments.
|
/// the given arguments. Return a reference to `this` for chaining insertions.
|
||||||
/// Note: ConstructorArg is necessary here to separate the two variadic lists.
|
/// Note: ConstructorArg is necessary here to separate the two variadic lists.
|
||||||
template <typename... Ts, typename ConstructorArg,
|
template <typename... Ts, typename ConstructorArg,
|
||||||
typename... ConstructorArgs,
|
typename... ConstructorArgs,
|
||||||
typename = std::enable_if_t<sizeof...(Ts) != 0>>
|
typename = std::enable_if_t<sizeof...(Ts) != 0>>
|
||||||
void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
|
OwningRewritePatternList &insert(ConstructorArg &&arg,
|
||||||
|
ConstructorArgs &&... args) {
|
||||||
// The following expands a call to emplace_back for each of the pattern
|
// The following expands a call to emplace_back for each of the pattern
|
||||||
// types 'Ts'. This magic is necessary due to a limitation in the places
|
// types 'Ts'. This magic is necessary due to a limitation in the places
|
||||||
// that a parameter pack can be expanded in c++11.
|
// that a parameter pack can be expanded in c++11.
|
||||||
|
@ -412,6 +422,7 @@ public:
|
||||||
using dummy = int[];
|
using dummy = int[];
|
||||||
(void)dummy{
|
(void)dummy{
|
||||||
0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
|
0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
|
||||||
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -198,3 +198,24 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult mlir::linalg::applyStagedPatterns(
|
||||||
|
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
|
||||||
|
const OwningRewritePatternList &stage2Patterns,
|
||||||
|
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
|
||||||
|
for (const auto &patterns : stage1Patterns) {
|
||||||
|
if (!applyPatternsAndFoldGreedily(op, patterns)) {
|
||||||
|
llvm::dbgs() << "Underlying first stage rewrite did not converge";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
|
||||||
|
llvm::dbgs() << "Underlying second stage rewrite did not converge";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (stage3Lambda) {
|
||||||
|
if (failed(stage3Lambda(op)))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
// TODO: this needs a fix to land before being reactivated.
|
||||||
|
// RUN: ls
|
||||||
|
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
|
||||||
|
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
|
||||||
|
|
||||||
|
func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||||
|
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||||
|
%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
|
||||||
|
linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} :
|
||||||
|
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||||
|
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||||
|
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL:func @matmul
|
||||||
|
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
|
||||||
|
// CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
|
||||||
|
//
|
||||||
|
// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
|
||||||
|
// CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
|
||||||
|
//
|
||||||
|
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
|
||||||
|
// CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
|
||||||
|
//
|
||||||
|
// CHECK: linalg.copy
|
||||||
|
// CHECK: linalg.copy
|
||||||
|
// CHECK: linalg.copy
|
||||||
|
//
|
||||||
|
// CHECK: vector.contract
|
||||||
|
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
|
||||||
|
// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
|
||||||
|
//
|
||||||
|
// CHECK: linalg.copy
|
|
@ -33,6 +33,18 @@ struct TestLinalgTransforms
|
||||||
Option<bool> testPatterns{*this, "test-patterns",
|
Option<bool> testPatterns{*this, "test-patterns",
|
||||||
llvm::cl::desc("Test a mixed set of patterns"),
|
llvm::cl::desc("Test a mixed set of patterns"),
|
||||||
llvm::cl::init(false)};
|
llvm::cl::init(false)};
|
||||||
|
Option<bool> testMatmulToVectorPatterns1dTiling{
|
||||||
|
*this, "test-matmul-to-vector-patterns-tile-1d",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Test a fused pass that applies patterns from matmul to vectors via "
|
||||||
|
"1-d tiling"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
|
Option<bool> testMatmulToVectorPatterns2dTiling{
|
||||||
|
*this, "test-matmul-to-vector-patterns-tile-2d",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Test a fused pass that applies patterns from matmul to vectors via "
|
||||||
|
"2-d tiling"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
@ -137,10 +149,65 @@ static void applyPatterns(FuncOp funcOp) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OwningRewritePatternList
|
||||||
|
getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
AffineMinOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
AffineMaxOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
AllocOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
SubViewOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
ViewOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
MatmulOp::getCanonicalizationPatterns(patterns, context);
|
||||||
|
return patterns;
|
||||||
|
}
|
||||||
|
|
||||||
|
void fillL1TilingAndMatmulToVectorPatterns(
|
||||||
|
MLIRContext *context, StringRef startMarker,
|
||||||
|
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
|
||||||
|
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
|
||||||
|
context,
|
||||||
|
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
|
||||||
|
LinalgMarker({startMarker}, "L1")));
|
||||||
|
|
||||||
|
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
|
||||||
|
context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC")));
|
||||||
|
|
||||||
|
patternsVector.emplace_back(
|
||||||
|
LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
|
||||||
|
patternsVector.back()
|
||||||
|
.insert<LinalgVectorizationPattern<FillOp>,
|
||||||
|
LinalgVectorizationPattern<CopyOp>>(context);
|
||||||
|
}
|
||||||
|
|
||||||
/// Apply transformations specified as patterns.
|
/// Apply transformations specified as patterns.
|
||||||
void TestLinalgTransforms::runOnFunction() {
|
void TestLinalgTransforms::runOnFunction() {
|
||||||
if (testPatterns)
|
if (testPatterns) {
|
||||||
return applyPatterns(getFunction());
|
applyPatterns(getFunction());
|
||||||
|
} else {
|
||||||
|
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
|
||||||
|
if (testMatmulToVectorPatterns1dTiling) {
|
||||||
|
fillL1TilingAndMatmulToVectorPatterns(&getContext(), "START",
|
||||||
|
stage1Patterns);
|
||||||
|
} else if (testMatmulToVectorPatterns2dTiling) {
|
||||||
|
stage1Patterns.emplace_back(
|
||||||
|
LinalgTilingPattern<MatmulOp>(&getContext(),
|
||||||
|
LinalgTilingOptions()
|
||||||
|
.setTileSizes({768, 264, 768})
|
||||||
|
.setInterchange({1, 2, 0}),
|
||||||
|
LinalgMarker({"START"}, "L2")));
|
||||||
|
fillL1TilingAndMatmulToVectorPatterns(&getContext(), "L2",
|
||||||
|
stage1Patterns);
|
||||||
|
}
|
||||||
|
OwningRewritePatternList stage2Patterns =
|
||||||
|
getMatmulToVectorCanonicalizationPatterns(&getContext());
|
||||||
|
applyStagedPatterns(getFunction(), stage1Patterns, stage2Patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop the marker.
|
||||||
|
getFunction().walk([](LinalgOp op) {
|
||||||
|
op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
Loading…
Reference in New Issue