[mlir][Linalg] NFC - Make markers use Identifier instead of StringRef

Summary: This removes string ownership worries by putting everything into the context and allows more constructing identifiers programmatically.

Reviewers: ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul

Tags: #mlir

Differential Revision: https://reviews.llvm.org/D81027
This commit is contained in:
Nicolas Vasilache 2020-06-02 15:14:32 -04:00
parent c5468253aa
commit e349fb70a2
5 changed files with 80 additions and 67 deletions

View File

@ -11,6 +11,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallBitVector.h"
@ -206,15 +207,16 @@ struct LinalgTransforms {
/// Helper class to control common attribute matching and setting behavior.
struct LinalgMarker {
LinalgMarker(ArrayRef<StringRef> matchDisjunction = {},
Optional<StringRef> replacement = None);
LinalgMarker(ArrayRef<StringRef> matchDisjunction, StringRef replacement);
explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
Optional<Identifier> replacement = None);
LinalgMarker(LinalgMarker &&) = default;
LinalgMarker(const LinalgMarker &) = default;
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
private:
SmallVector<StringRef, 4> matchDisjunction;
Optional<StringRef> replacement;
SmallVector<Identifier, 4> matchDisjunction;
Optional<Identifier> replacement;
};
///

View File

@ -459,8 +459,8 @@ class RewritePatternList<OpTy, OpTypes...> {
public:
static void insert(OwningRewritePatternList &patterns,
const LinalgTilingOptions &options, MLIRContext *ctx) {
patterns.insert<LinalgTilingPattern<OpTy>>(ctx, options,
LinalgMarker({}, "tiled"));
patterns.insert<LinalgTilingPattern<OpTy>>(
ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx)));
RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
}
};

View File

@ -46,15 +46,11 @@ using llvm::dbgs;
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
Optional<StringRef> replacement)
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
Optional<Identifier> replacement)
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
replacement(replacement) {}
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
StringRef replacement)
: LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
LogicalResult
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
Operation *op) const {
@ -66,12 +62,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
if (matchDisjunction.empty())
return success();
// 2. Has no marker and matchDisjuntion matches the no-moarker case.
for (auto marker : matchDisjunction)
if (marker.empty())
return success();
// 3. Has no marker but was expecting a marker.
// 2. Has no marker but was expecting a marker.
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << " does not have any marker from list: ";
interleaveComma(matchDisjunction, diag);

View File

@ -14,9 +14,10 @@
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>,
%v: memref<f32>) {
linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>,
memref<?xf32, offset: ?, strides: [1]>,
memref<f32>
linalg.dot(%x, %y, %v) { __internal_linalg_transform__ = "MEM" } :
memref<?xf32, offset: ?, strides: [1]>,
memref<?xf32, offset: ?, strides: [1]>,
memref<f32>
return
}
// CHECK-LABEL: func @dot
@ -35,9 +36,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>) {
linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?xf32, offset: ?, strides: [1]>,
memref<?xf32, offset: ?, strides: [1]>
linalg.matvec(%A, %x, %y) :
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?xf32, offset: ?, strides: [1]>,
memref<?xf32, offset: ?, strides: [1]>
return
}
// CHECK-LABEL: func @matvec
@ -51,9 +53,10 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>
linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "MEM" } :
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>
return
}
// CHECK-LABEL: func @matmul

View File

@ -66,26 +66,29 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
LinalgMarker({"MEM", {}}, "L3"));
LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
LinalgMarker({"L3"}, "L2"));
LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgMarker({"L2"}, "L1"));
LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
LinalgMarker({"L1"}, "REG"));
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx,
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
LinalgMarker({}, "L1"));
LinalgMarker({}, Identifier::get("L1", ctx)));
patterns.insert<LinalgTilingPattern<DotOp>>(
ctx, LinalgTilingOptions().setTileSizes(8000),
LinalgMarker({"MEM", "L3", "L2", {}}, "REG"));
LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
Identifier::get("L3", ctx),
Identifier::get("L2", ctx)},
Identifier::get("REG", ctx)));
//===--------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
@ -95,20 +98,24 @@ static void applyPatterns(FuncOp funcOp) {
LinalgTilingOptions()
.setTileSizes({2000, 3000, 4000})
.setInterchange({1, 2, 0}),
LinalgMarker({"__with_perm__"}, "L2__with_perm__"));
LinalgMarker(Identifier::get("__with_perm__", ctx),
Identifier::get("L2__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({200, 300, 400})
.setInterchange({1, 0, 2}),
LinalgMarker({"L2__with_perm__"}, "L1__with_perm__"));
LinalgMarker(Identifier::get("L2__with_perm__", ctx),
Identifier::get("L1__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgMarker({"L1__with_perm__"}, "REG__with_perm__"));
LinalgMarker(Identifier::get("L1__with_perm__", ctx),
Identifier::get("REG__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
LinalgMarker(Identifier::get("__with_perm__", ctx),
Identifier::get("L1__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
@ -116,14 +123,16 @@ static void applyPatterns(FuncOp funcOp) {
.setTileSizes({16, 8, 4})
.setInterchange({1, 2, 0})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgMarker({"par__with_perm__"}, "after_par__with_perm__"));
LinalgMarker(Identifier::get("par__with_perm__", ctx),
Identifier::get("after_par__with_perm__", ctx)));
//===--------------------------------------------------------------------===//
// Linalg to loops patterns.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgLoweringPattern<DotOp>>(
ctx,
/*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"}));
/*loweringType=*/LinalgLoweringType::Loops,
LinalgMarker(Identifier::get("REG", ctx)));
//===--------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
@ -131,7 +140,7 @@ static void applyPatterns(FuncOp funcOp) {
patterns.insert<LinalgVectorizationPattern<MatmulOp>,
LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<GenericOp>>(
ctx, LinalgMarker({"VECTORIZE"}));
ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
//===--------------------------------------------------------------------===//
// Linalg generic permutation patterns.
@ -139,31 +148,34 @@ static void applyPatterns(FuncOp funcOp) {
patterns.insert<LinalgInterchangePattern<GenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgMarker({}, "PERMUTED"));
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgMarker({}, "PERMUTED"));
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
//===--------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
LinalgMarker({"_promote_views_"}, "_views_promoted_"));
LinalgMarker(Identifier::get("_promote_views_", ctx),
Identifier::get("_views_promoted_", ctx)));
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
.useFullTileBuffersByDefault(),
LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
LinalgMarker(Identifier::get("_promote_first_view_", ctx),
Identifier::get("_first_view_promoted_", ctx)));
patterns.insert<LinalgPromotionPattern<FillOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
.setUseFullTileBuffers({true})
.setAlignment(32),
LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
Identifier::get("_views_aligned_promoted_", ctx)));
applyPatternsAndFoldGreedily(funcOp, patterns);
@ -176,21 +188,22 @@ static void applyPatterns(FuncOp funcOp) {
static void fillL1TilingAndMatmulToVectorPatterns(
FuncOp funcOp, StringRef startMarker,
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
MLIRContext *context = funcOp.getContext();
MLIRContext *ctx = funcOp.getContext();
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
context,
ctx,
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
LinalgMarker({startMarker}, "L1")));
LinalgMarker(Identifier::get(startMarker, ctx),
Identifier::get("L1", ctx))));
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
context, LinalgPromotionOptions().useFullTileBuffersByDefault(),
LinalgMarker({"L1"}, "VEC")));
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx))));
patternsVector.emplace_back(
LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>(
ctx, LinalgMarker(Identifier::get("VEC", ctx))));
patternsVector.back()
.insert<LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<CopyOp>>(context);
LinalgVectorizationPattern<CopyOp>>(ctx);
}
//===----------------------------------------------------------------------===//
@ -231,13 +244,14 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
return success();
}
void fillPromotionCallBackPatterns(MLIRContext *context,
void fillPromotionCallBackPatterns(MLIRContext *ctx,
OwningRewritePatternList &patterns) {
patterns.insert<LinalgTilingPattern<MatmulOp>>(
context, LinalgTilingOptions().setTileSizes({16, 16, 16}),
LinalgMarker({"START"}, "PROMOTE"));
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
LinalgMarker(Identifier::get("START", ctx),
Identifier::get("PROMOTE", ctx)));
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
context,
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0, 2})
.setUseFullTileBuffers({false, false})
@ -251,7 +265,7 @@ void fillPromotionCallBackPatterns(MLIRContext *context,
copyCallBackFn(b, src, dst, true);
return success();
}),
LinalgMarker({"PROMOTE"}));
LinalgMarker(Identifier::get("PROMOTE", ctx)));
}
static void
@ -261,15 +275,18 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
MLIRContext *ctx = funcOp.getContext();
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
if (testMatmulToVectorPatterns1dTiling) {
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
stage1Patterns.emplace_back(
LinalgTilingPattern<MatmulOp>(ctx,
LinalgTilingOptions()
.setTileSizes({768, 264, 768})
.setInterchange({1, 2, 0}),
LinalgMarker({"START"}, "L2")));
fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>(
ctx,
LinalgTilingOptions()
.setTileSizes({768, 264, 768})
.setInterchange({1, 2, 0}),
LinalgMarker(Identifier::get("START", ctx),
Identifier::get("L2", ctx))));
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
stage1Patterns);
}
OwningRewritePatternList stage2Patterns =
getLinalgTilingCanonicalizationPatterns(ctx);