forked from OSchip/llvm-project
[mlir][Linalg] NFC - Make markers use Identifier instead of StringRef
Summary: This removes string ownership worries by putting everything into the context and allows more constructing identifiers programmatically. Reviewers: ftynse Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul Tags: #mlir Differential Revision: https://reviews.llvm.org/D81027
This commit is contained in:
parent
c5468253aa
commit
e349fb70a2
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
|
||||
|
@ -206,15 +207,16 @@ struct LinalgTransforms {
|
|||
|
||||
/// Helper class to control common attribute matching and setting behavior.
|
||||
struct LinalgMarker {
|
||||
LinalgMarker(ArrayRef<StringRef> matchDisjunction = {},
|
||||
Optional<StringRef> replacement = None);
|
||||
LinalgMarker(ArrayRef<StringRef> matchDisjunction, StringRef replacement);
|
||||
explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
|
||||
Optional<Identifier> replacement = None);
|
||||
LinalgMarker(LinalgMarker &&) = default;
|
||||
LinalgMarker(const LinalgMarker &) = default;
|
||||
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
|
||||
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
|
||||
|
||||
private:
|
||||
SmallVector<StringRef, 4> matchDisjunction;
|
||||
Optional<StringRef> replacement;
|
||||
SmallVector<Identifier, 4> matchDisjunction;
|
||||
Optional<Identifier> replacement;
|
||||
};
|
||||
|
||||
///
|
||||
|
|
|
@ -459,8 +459,8 @@ class RewritePatternList<OpTy, OpTypes...> {
|
|||
public:
|
||||
static void insert(OwningRewritePatternList &patterns,
|
||||
const LinalgTilingOptions &options, MLIRContext *ctx) {
|
||||
patterns.insert<LinalgTilingPattern<OpTy>>(ctx, options,
|
||||
LinalgMarker({}, "tiled"));
|
||||
patterns.insert<LinalgTilingPattern<OpTy>>(
|
||||
ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx)));
|
||||
RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -46,15 +46,11 @@ using llvm::dbgs;
|
|||
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
||||
"__internal_linalg_transform__";
|
||||
|
||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
|
||||
Optional<StringRef> replacement)
|
||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
|
||||
Optional<Identifier> replacement)
|
||||
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
|
||||
replacement(replacement) {}
|
||||
|
||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
|
||||
StringRef replacement)
|
||||
: LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
|
||||
|
||||
LogicalResult
|
||||
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
||||
Operation *op) const {
|
||||
|
@ -66,12 +62,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
|||
if (matchDisjunction.empty())
|
||||
return success();
|
||||
|
||||
// 2. Has no marker and matchDisjuntion matches the no-moarker case.
|
||||
for (auto marker : matchDisjunction)
|
||||
if (marker.empty())
|
||||
return success();
|
||||
|
||||
// 3. Has no marker but was expecting a marker.
|
||||
// 2. Has no marker but was expecting a marker.
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << " does not have any marker from list: ";
|
||||
interleaveComma(matchDisjunction, diag);
|
||||
|
|
|
@ -14,9 +14,10 @@
|
|||
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%y: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%v: memref<f32>) {
|
||||
linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>
|
||||
linalg.dot(%x, %y, %v) { __internal_linalg_transform__ = "MEM" } :
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot
|
||||
|
@ -35,9 +36,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
|||
func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%y: memref<?xf32, offset: ?, strides: [1]>) {
|
||||
linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>
|
||||
linalg.matvec(%A, %x, %y) :
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matvec
|
||||
|
@ -51,9 +53,10 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||
linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "MEM" } :
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matmul
|
||||
|
|
|
@ -66,26 +66,29 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
|
||||
LinalgMarker({"MEM", {}}, "L3"));
|
||||
LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
|
||||
LinalgMarker({"L3"}, "L2"));
|
||||
LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
LinalgMarker({"L2"}, "L1"));
|
||||
LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
|
||||
LinalgMarker({"L1"}, "REG"));
|
||||
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
||||
ctx,
|
||||
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
|
||||
LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgMarker({}, "L1"));
|
||||
LinalgMarker({}, Identifier::get("L1", ctx)));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<DotOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes(8000),
|
||||
LinalgMarker({"MEM", "L3", "L2", {}}, "REG"));
|
||||
LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
|
||||
Identifier::get("L3", ctx),
|
||||
Identifier::get("L2", ctx)},
|
||||
Identifier::get("REG", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg tiling and permutation patterns.
|
||||
|
@ -95,20 +98,24 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
LinalgTilingOptions()
|
||||
.setTileSizes({2000, 3000, 4000})
|
||||
.setInterchange({1, 2, 0}),
|
||||
LinalgMarker({"__with_perm__"}, "L2__with_perm__"));
|
||||
LinalgMarker(Identifier::get("__with_perm__", ctx),
|
||||
Identifier::get("L2__with_perm__", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({200, 300, 400})
|
||||
.setInterchange({1, 0, 2}),
|
||||
LinalgMarker({"L2__with_perm__"}, "L1__with_perm__"));
|
||||
LinalgMarker(Identifier::get("L2__with_perm__", ctx),
|
||||
Identifier::get("L1__with_perm__", ctx)));
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
LinalgMarker({"L1__with_perm__"}, "REG__with_perm__"));
|
||||
LinalgMarker(Identifier::get("L1__with_perm__", ctx),
|
||||
Identifier::get("REG__with_perm__", ctx)));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
|
||||
LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
|
||||
LinalgMarker(Identifier::get("__with_perm__", ctx),
|
||||
Identifier::get("L1__with_perm__", ctx)));
|
||||
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
|
@ -116,14 +123,16 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
.setTileSizes({16, 8, 4})
|
||||
.setInterchange({1, 2, 0})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgMarker({"par__with_perm__"}, "after_par__with_perm__"));
|
||||
LinalgMarker(Identifier::get("par__with_perm__", ctx),
|
||||
Identifier::get("after_par__with_perm__", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg to loops patterns.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgLoweringPattern<DotOp>>(
|
||||
ctx,
|
||||
/*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"}));
|
||||
/*loweringType=*/LinalgLoweringType::Loops,
|
||||
LinalgMarker(Identifier::get("REG", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg to vector contraction patterns.
|
||||
|
@ -131,7 +140,7 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
patterns.insert<LinalgVectorizationPattern<MatmulOp>,
|
||||
LinalgVectorizationPattern<FillOp>,
|
||||
LinalgVectorizationPattern<GenericOp>>(
|
||||
ctx, LinalgMarker({"VECTORIZE"}));
|
||||
ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg generic permutation patterns.
|
||||
|
@ -139,31 +148,34 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
patterns.insert<LinalgInterchangePattern<GenericOp>>(
|
||||
ctx,
|
||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||
LinalgMarker({}, "PERMUTED"));
|
||||
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
|
||||
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
|
||||
ctx,
|
||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||
LinalgMarker({}, "PERMUTED"));
|
||||
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg subview operands promotion.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
|
||||
LinalgMarker({"_promote_views_"}, "_views_promoted_"));
|
||||
LinalgMarker(Identifier::get("_promote_views_", ctx),
|
||||
Identifier::get("_views_promoted_", ctx)));
|
||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||
ctx,
|
||||
LinalgPromotionOptions()
|
||||
.setOperandsToPromote({0})
|
||||
.useFullTileBuffersByDefault(),
|
||||
LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
|
||||
LinalgMarker(Identifier::get("_promote_first_view_", ctx),
|
||||
Identifier::get("_first_view_promoted_", ctx)));
|
||||
patterns.insert<LinalgPromotionPattern<FillOp>>(
|
||||
ctx,
|
||||
LinalgPromotionOptions()
|
||||
.setOperandsToPromote({0})
|
||||
.setUseFullTileBuffers({true})
|
||||
.setAlignment(32),
|
||||
LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
|
||||
LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
|
||||
Identifier::get("_views_aligned_promoted_", ctx)));
|
||||
|
||||
applyPatternsAndFoldGreedily(funcOp, patterns);
|
||||
|
||||
|
@ -176,21 +188,22 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
static void fillL1TilingAndMatmulToVectorPatterns(
|
||||
FuncOp funcOp, StringRef startMarker,
|
||||
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
|
||||
MLIRContext *context = funcOp.getContext();
|
||||
MLIRContext *ctx = funcOp.getContext();
|
||||
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
|
||||
context,
|
||||
ctx,
|
||||
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
|
||||
LinalgMarker({startMarker}, "L1")));
|
||||
LinalgMarker(Identifier::get(startMarker, ctx),
|
||||
Identifier::get("L1", ctx))));
|
||||
|
||||
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
|
||||
context, LinalgPromotionOptions().useFullTileBuffersByDefault(),
|
||||
LinalgMarker({"L1"}, "VEC")));
|
||||
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
|
||||
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx))));
|
||||
|
||||
patternsVector.emplace_back(
|
||||
LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
|
||||
patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>(
|
||||
ctx, LinalgMarker(Identifier::get("VEC", ctx))));
|
||||
patternsVector.back()
|
||||
.insert<LinalgVectorizationPattern<FillOp>,
|
||||
LinalgVectorizationPattern<CopyOp>>(context);
|
||||
LinalgVectorizationPattern<CopyOp>>(ctx);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -231,13 +244,14 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
|
|||
return success();
|
||||
}
|
||||
|
||||
void fillPromotionCallBackPatterns(MLIRContext *context,
|
||||
void fillPromotionCallBackPatterns(MLIRContext *ctx,
|
||||
OwningRewritePatternList &patterns) {
|
||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||
context, LinalgTilingOptions().setTileSizes({16, 16, 16}),
|
||||
LinalgMarker({"START"}, "PROMOTE"));
|
||||
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
|
||||
LinalgMarker(Identifier::get("START", ctx),
|
||||
Identifier::get("PROMOTE", ctx)));
|
||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||
context,
|
||||
ctx,
|
||||
LinalgPromotionOptions()
|
||||
.setOperandsToPromote({0, 2})
|
||||
.setUseFullTileBuffers({false, false})
|
||||
|
@ -251,7 +265,7 @@ void fillPromotionCallBackPatterns(MLIRContext *context,
|
|||
copyCallBackFn(b, src, dst, true);
|
||||
return success();
|
||||
}),
|
||||
LinalgMarker({"PROMOTE"}));
|
||||
LinalgMarker(Identifier::get("PROMOTE", ctx)));
|
||||
}
|
||||
|
||||
static void
|
||||
|
@ -261,15 +275,18 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
|
|||
MLIRContext *ctx = funcOp.getContext();
|
||||
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
|
||||
if (testMatmulToVectorPatterns1dTiling) {
|
||||
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
|
||||
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
|
||||
stage1Patterns);
|
||||
} else if (testMatmulToVectorPatterns2dTiling) {
|
||||
stage1Patterns.emplace_back(
|
||||
LinalgTilingPattern<MatmulOp>(ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({768, 264, 768})
|
||||
.setInterchange({1, 2, 0}),
|
||||
LinalgMarker({"START"}, "L2")));
|
||||
fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
|
||||
stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>(
|
||||
ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({768, 264, 768})
|
||||
.setInterchange({1, 2, 0}),
|
||||
LinalgMarker(Identifier::get("START", ctx),
|
||||
Identifier::get("L2", ctx))));
|
||||
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
|
||||
stage1Patterns);
|
||||
}
|
||||
OwningRewritePatternList stage2Patterns =
|
||||
getLinalgTilingCanonicalizationPatterns(ctx);
|
||||
|
|
Loading…
Reference in New Issue