forked from OSchip/llvm-project
[mlir][Linalg] NFC - Make markers use Identifier instead of StringRef
Summary: This removes string ownership worries by putting everything into the context and allows more constructing identifiers programmatically. Reviewers: ftynse Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul Tags: #mlir Differential Revision: https://reviews.llvm.org/D81027
This commit is contained in:
parent
c5468253aa
commit
e349fb70a2
|
@ -11,6 +11,7 @@
|
||||||
|
|
||||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
|
#include "mlir/IR/Identifier.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
|
||||||
|
@ -206,15 +207,16 @@ struct LinalgTransforms {
|
||||||
|
|
||||||
/// Helper class to control common attribute matching and setting behavior.
|
/// Helper class to control common attribute matching and setting behavior.
|
||||||
struct LinalgMarker {
|
struct LinalgMarker {
|
||||||
LinalgMarker(ArrayRef<StringRef> matchDisjunction = {},
|
explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
|
||||||
Optional<StringRef> replacement = None);
|
Optional<Identifier> replacement = None);
|
||||||
LinalgMarker(ArrayRef<StringRef> matchDisjunction, StringRef replacement);
|
LinalgMarker(LinalgMarker &&) = default;
|
||||||
|
LinalgMarker(const LinalgMarker &) = default;
|
||||||
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
|
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
|
||||||
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
|
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SmallVector<StringRef, 4> matchDisjunction;
|
SmallVector<Identifier, 4> matchDisjunction;
|
||||||
Optional<StringRef> replacement;
|
Optional<Identifier> replacement;
|
||||||
};
|
};
|
||||||
|
|
||||||
///
|
///
|
||||||
|
|
|
@ -459,8 +459,8 @@ class RewritePatternList<OpTy, OpTypes...> {
|
||||||
public:
|
public:
|
||||||
static void insert(OwningRewritePatternList &patterns,
|
static void insert(OwningRewritePatternList &patterns,
|
||||||
const LinalgTilingOptions &options, MLIRContext *ctx) {
|
const LinalgTilingOptions &options, MLIRContext *ctx) {
|
||||||
patterns.insert<LinalgTilingPattern<OpTy>>(ctx, options,
|
patterns.insert<LinalgTilingPattern<OpTy>>(
|
||||||
LinalgMarker({}, "tiled"));
|
ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx)));
|
||||||
RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
|
RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -46,15 +46,11 @@ using llvm::dbgs;
|
||||||
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
|
||||||
"__internal_linalg_transform__";
|
"__internal_linalg_transform__";
|
||||||
|
|
||||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
|
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
|
||||||
Optional<StringRef> replacement)
|
Optional<Identifier> replacement)
|
||||||
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
|
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
|
||||||
replacement(replacement) {}
|
replacement(replacement) {}
|
||||||
|
|
||||||
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
|
|
||||||
StringRef replacement)
|
|
||||||
: LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
|
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
||||||
Operation *op) const {
|
Operation *op) const {
|
||||||
|
@ -66,12 +62,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
|
||||||
if (matchDisjunction.empty())
|
if (matchDisjunction.empty())
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// 2. Has no marker and matchDisjuntion matches the no-moarker case.
|
// 2. Has no marker but was expecting a marker.
|
||||||
for (auto marker : matchDisjunction)
|
|
||||||
if (marker.empty())
|
|
||||||
return success();
|
|
||||||
|
|
||||||
// 3. Has no marker but was expecting a marker.
|
|
||||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||||
diag << " does not have any marker from list: ";
|
diag << " does not have any marker from list: ";
|
||||||
interleaveComma(matchDisjunction, diag);
|
interleaveComma(matchDisjunction, diag);
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||||
%y: memref<?xf32, offset: ?, strides: [1]>,
|
%y: memref<?xf32, offset: ?, strides: [1]>,
|
||||||
%v: memref<f32>) {
|
%v: memref<f32>) {
|
||||||
linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>,
|
linalg.dot(%x, %y, %v) { __internal_linalg_transform__ = "MEM" } :
|
||||||
memref<?xf32, offset: ?, strides: [1]>,
|
memref<?xf32, offset: ?, strides: [1]>,
|
||||||
memref<f32>
|
memref<?xf32, offset: ?, strides: [1]>,
|
||||||
|
memref<f32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @dot
|
// CHECK-LABEL: func @dot
|
||||||
|
@ -35,9 +36,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||||
func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
%x: memref<?xf32, offset: ?, strides: [1]>,
|
%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||||
%y: memref<?xf32, offset: ?, strides: [1]>) {
|
%y: memref<?xf32, offset: ?, strides: [1]>) {
|
||||||
linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
linalg.matvec(%A, %x, %y) :
|
||||||
memref<?xf32, offset: ?, strides: [1]>,
|
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
memref<?xf32, offset: ?, strides: [1]>
|
memref<?xf32, offset: ?, strides: [1]>,
|
||||||
|
memref<?xf32, offset: ?, strides: [1]>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @matvec
|
// CHECK-LABEL: func @matvec
|
||||||
|
@ -51,9 +53,10 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||||
linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "MEM" } :
|
||||||
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
memref<?x?xf32, offset: ?, strides: [?, 1]>
|
memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||||
|
memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @matmul
|
// CHECK-LABEL: func @matmul
|
||||||
|
|
|
@ -66,26 +66,29 @@ static void applyPatterns(FuncOp funcOp) {
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||||
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
|
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
|
||||||
LinalgMarker({"MEM", {}}, "L3"));
|
LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
|
||||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||||
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
|
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
|
||||||
LinalgMarker({"L3"}, "L2"));
|
LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
|
||||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||||
LinalgMarker({"L2"}, "L1"));
|
LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
|
||||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||||
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
|
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
|
||||||
LinalgMarker({"L1"}, "REG"));
|
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
|
||||||
|
|
||||||
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
||||||
ctx,
|
ctx,
|
||||||
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
|
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
|
||||||
LinalgTilingLoopType::ParallelLoops),
|
LinalgTilingLoopType::ParallelLoops),
|
||||||
LinalgMarker({}, "L1"));
|
LinalgMarker({}, Identifier::get("L1", ctx)));
|
||||||
|
|
||||||
patterns.insert<LinalgTilingPattern<DotOp>>(
|
patterns.insert<LinalgTilingPattern<DotOp>>(
|
||||||
ctx, LinalgTilingOptions().setTileSizes(8000),
|
ctx, LinalgTilingOptions().setTileSizes(8000),
|
||||||
LinalgMarker({"MEM", "L3", "L2", {}}, "REG"));
|
LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
|
||||||
|
Identifier::get("L3", ctx),
|
||||||
|
Identifier::get("L2", ctx)},
|
||||||
|
Identifier::get("REG", ctx)));
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Linalg tiling and permutation patterns.
|
// Linalg tiling and permutation patterns.
|
||||||
|
@ -95,20 +98,24 @@ static void applyPatterns(FuncOp funcOp) {
|
||||||
LinalgTilingOptions()
|
LinalgTilingOptions()
|
||||||
.setTileSizes({2000, 3000, 4000})
|
.setTileSizes({2000, 3000, 4000})
|
||||||
.setInterchange({1, 2, 0}),
|
.setInterchange({1, 2, 0}),
|
||||||
LinalgMarker({"__with_perm__"}, "L2__with_perm__"));
|
LinalgMarker(Identifier::get("__with_perm__", ctx),
|
||||||
|
Identifier::get("L2__with_perm__", ctx)));
|
||||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||||
ctx,
|
ctx,
|
||||||
LinalgTilingOptions()
|
LinalgTilingOptions()
|
||||||
.setTileSizes({200, 300, 400})
|
.setTileSizes({200, 300, 400})
|
||||||
.setInterchange({1, 0, 2}),
|
.setInterchange({1, 0, 2}),
|
||||||
LinalgMarker({"L2__with_perm__"}, "L1__with_perm__"));
|
LinalgMarker(Identifier::get("L2__with_perm__", ctx),
|
||||||
|
Identifier::get("L1__with_perm__", ctx)));
|
||||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||||
LinalgMarker({"L1__with_perm__"}, "REG__with_perm__"));
|
LinalgMarker(Identifier::get("L1__with_perm__", ctx),
|
||||||
|
Identifier::get("REG__with_perm__", ctx)));
|
||||||
|
|
||||||
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
patterns.insert<LinalgTilingPattern<MatvecOp>>(
|
||||||
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
|
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
|
||||||
LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
|
LinalgMarker(Identifier::get("__with_perm__", ctx),
|
||||||
|
Identifier::get("L1__with_perm__", ctx)));
|
||||||
|
|
||||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -116,14 +123,16 @@ static void applyPatterns(FuncOp funcOp) {
|
||||||
.setTileSizes({16, 8, 4})
|
.setTileSizes({16, 8, 4})
|
||||||
.setInterchange({1, 2, 0})
|
.setInterchange({1, 2, 0})
|
||||||
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
.setLoopType(LinalgTilingLoopType::ParallelLoops),
|
||||||
LinalgMarker({"par__with_perm__"}, "after_par__with_perm__"));
|
LinalgMarker(Identifier::get("par__with_perm__", ctx),
|
||||||
|
Identifier::get("after_par__with_perm__", ctx)));
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Linalg to loops patterns.
|
// Linalg to loops patterns.
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
patterns.insert<LinalgLoweringPattern<DotOp>>(
|
patterns.insert<LinalgLoweringPattern<DotOp>>(
|
||||||
ctx,
|
ctx,
|
||||||
/*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"}));
|
/*loweringType=*/LinalgLoweringType::Loops,
|
||||||
|
LinalgMarker(Identifier::get("REG", ctx)));
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Linalg to vector contraction patterns.
|
// Linalg to vector contraction patterns.
|
||||||
|
@ -131,7 +140,7 @@ static void applyPatterns(FuncOp funcOp) {
|
||||||
patterns.insert<LinalgVectorizationPattern<MatmulOp>,
|
patterns.insert<LinalgVectorizationPattern<MatmulOp>,
|
||||||
LinalgVectorizationPattern<FillOp>,
|
LinalgVectorizationPattern<FillOp>,
|
||||||
LinalgVectorizationPattern<GenericOp>>(
|
LinalgVectorizationPattern<GenericOp>>(
|
||||||
ctx, LinalgMarker({"VECTORIZE"}));
|
ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Linalg generic permutation patterns.
|
// Linalg generic permutation patterns.
|
||||||
|
@ -139,31 +148,34 @@ static void applyPatterns(FuncOp funcOp) {
|
||||||
patterns.insert<LinalgInterchangePattern<GenericOp>>(
|
patterns.insert<LinalgInterchangePattern<GenericOp>>(
|
||||||
ctx,
|
ctx,
|
||||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||||
LinalgMarker({}, "PERMUTED"));
|
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
|
||||||
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
|
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
|
||||||
ctx,
|
ctx,
|
||||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||||
LinalgMarker({}, "PERMUTED"));
|
LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Linalg subview operands promotion.
|
// Linalg subview operands promotion.
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||||
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
|
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
|
||||||
LinalgMarker({"_promote_views_"}, "_views_promoted_"));
|
LinalgMarker(Identifier::get("_promote_views_", ctx),
|
||||||
|
Identifier::get("_views_promoted_", ctx)));
|
||||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||||
ctx,
|
ctx,
|
||||||
LinalgPromotionOptions()
|
LinalgPromotionOptions()
|
||||||
.setOperandsToPromote({0})
|
.setOperandsToPromote({0})
|
||||||
.useFullTileBuffersByDefault(),
|
.useFullTileBuffersByDefault(),
|
||||||
LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
|
LinalgMarker(Identifier::get("_promote_first_view_", ctx),
|
||||||
|
Identifier::get("_first_view_promoted_", ctx)));
|
||||||
patterns.insert<LinalgPromotionPattern<FillOp>>(
|
patterns.insert<LinalgPromotionPattern<FillOp>>(
|
||||||
ctx,
|
ctx,
|
||||||
LinalgPromotionOptions()
|
LinalgPromotionOptions()
|
||||||
.setOperandsToPromote({0})
|
.setOperandsToPromote({0})
|
||||||
.setUseFullTileBuffers({true})
|
.setUseFullTileBuffers({true})
|
||||||
.setAlignment(32),
|
.setAlignment(32),
|
||||||
LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
|
LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
|
||||||
|
Identifier::get("_views_aligned_promoted_", ctx)));
|
||||||
|
|
||||||
applyPatternsAndFoldGreedily(funcOp, patterns);
|
applyPatternsAndFoldGreedily(funcOp, patterns);
|
||||||
|
|
||||||
|
@ -176,21 +188,22 @@ static void applyPatterns(FuncOp funcOp) {
|
||||||
static void fillL1TilingAndMatmulToVectorPatterns(
|
static void fillL1TilingAndMatmulToVectorPatterns(
|
||||||
FuncOp funcOp, StringRef startMarker,
|
FuncOp funcOp, StringRef startMarker,
|
||||||
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
|
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
|
||||||
MLIRContext *context = funcOp.getContext();
|
MLIRContext *ctx = funcOp.getContext();
|
||||||
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
|
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
|
||||||
context,
|
ctx,
|
||||||
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
|
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
|
||||||
LinalgMarker({startMarker}, "L1")));
|
LinalgMarker(Identifier::get(startMarker, ctx),
|
||||||
|
Identifier::get("L1", ctx))));
|
||||||
|
|
||||||
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
|
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
|
||||||
context, LinalgPromotionOptions().useFullTileBuffersByDefault(),
|
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
|
||||||
LinalgMarker({"L1"}, "VEC")));
|
LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx))));
|
||||||
|
|
||||||
patternsVector.emplace_back(
|
patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>(
|
||||||
LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
|
ctx, LinalgMarker(Identifier::get("VEC", ctx))));
|
||||||
patternsVector.back()
|
patternsVector.back()
|
||||||
.insert<LinalgVectorizationPattern<FillOp>,
|
.insert<LinalgVectorizationPattern<FillOp>,
|
||||||
LinalgVectorizationPattern<CopyOp>>(context);
|
LinalgVectorizationPattern<CopyOp>>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -231,13 +244,14 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void fillPromotionCallBackPatterns(MLIRContext *context,
|
void fillPromotionCallBackPatterns(MLIRContext *ctx,
|
||||||
OwningRewritePatternList &patterns) {
|
OwningRewritePatternList &patterns) {
|
||||||
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
patterns.insert<LinalgTilingPattern<MatmulOp>>(
|
||||||
context, LinalgTilingOptions().setTileSizes({16, 16, 16}),
|
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
|
||||||
LinalgMarker({"START"}, "PROMOTE"));
|
LinalgMarker(Identifier::get("START", ctx),
|
||||||
|
Identifier::get("PROMOTE", ctx)));
|
||||||
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
|
||||||
context,
|
ctx,
|
||||||
LinalgPromotionOptions()
|
LinalgPromotionOptions()
|
||||||
.setOperandsToPromote({0, 2})
|
.setOperandsToPromote({0, 2})
|
||||||
.setUseFullTileBuffers({false, false})
|
.setUseFullTileBuffers({false, false})
|
||||||
|
@ -251,7 +265,7 @@ void fillPromotionCallBackPatterns(MLIRContext *context,
|
||||||
copyCallBackFn(b, src, dst, true);
|
copyCallBackFn(b, src, dst, true);
|
||||||
return success();
|
return success();
|
||||||
}),
|
}),
|
||||||
LinalgMarker({"PROMOTE"}));
|
LinalgMarker(Identifier::get("PROMOTE", ctx)));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
|
@ -261,15 +275,18 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
|
||||||
MLIRContext *ctx = funcOp.getContext();
|
MLIRContext *ctx = funcOp.getContext();
|
||||||
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
|
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
|
||||||
if (testMatmulToVectorPatterns1dTiling) {
|
if (testMatmulToVectorPatterns1dTiling) {
|
||||||
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
|
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
|
||||||
|
stage1Patterns);
|
||||||
} else if (testMatmulToVectorPatterns2dTiling) {
|
} else if (testMatmulToVectorPatterns2dTiling) {
|
||||||
stage1Patterns.emplace_back(
|
stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>(
|
||||||
LinalgTilingPattern<MatmulOp>(ctx,
|
ctx,
|
||||||
LinalgTilingOptions()
|
LinalgTilingOptions()
|
||||||
.setTileSizes({768, 264, 768})
|
.setTileSizes({768, 264, 768})
|
||||||
.setInterchange({1, 2, 0}),
|
.setInterchange({1, 2, 0}),
|
||||||
LinalgMarker({"START"}, "L2")));
|
LinalgMarker(Identifier::get("START", ctx),
|
||||||
fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
|
Identifier::get("L2", ctx))));
|
||||||
|
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
|
||||||
|
stage1Patterns);
|
||||||
}
|
}
|
||||||
OwningRewritePatternList stage2Patterns =
|
OwningRewritePatternList stage2Patterns =
|
||||||
getLinalgTilingCanonicalizationPatterns(ctx);
|
getLinalgTilingCanonicalizationPatterns(ctx);
|
||||||
|
|
Loading…
Reference in New Issue