[mlir][Linalg] NFC - Modernize APIs and get rid of unnecessary tiling patterns.

Tiling patterns can be reduced to a single pattern by using interface-based patterns.

Differential Revision: https://reviews.llvm.org/D116733
Nicolas Vasilache 2022-01-06 05:58:45 -05:00
parent 75d65293ca
commit 4a661602ef
8 changed files with 201 additions and 218 deletions
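
A minimal sketch of how pattern registration changes with the single interface-based pattern described above (mirroring the pass and test updates in the diff below; the helper function name and option values are illustrative assumptions):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;
using namespace mlir::linalg;

// Illustrative helper: register tiling patterns for MatmulOp. The Linalg op
// definitions (e.g. MatmulOp) are assumed to be visible through the include.
static void addMatmulTilingPatterns(RewritePatternSet &patterns,
                                    const LinalgTilingOptions &options,
                                    const LinalgTransformationFilter &filter) {
  MLIRContext *ctx = patterns.getContext();
  // Before this commit: one template instantiation per concrete op.
  //   patterns.add<LinalgTilingPattern<MatmulOp>>(ctx, options, filter);
  // After: a single interface-based pattern, optionally restricted to one op.
  patterns.add<LinalgTilingPattern>(MatmulOp::getOperationName(), ctx, options,
                                    filter);
  // Or match any op implementing the LinalgOp interface.
  patterns.add<LinalgTilingPattern>(ctx, options, filter);
}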


@ -169,9 +169,14 @@ struct TiledLinalgOp {
SmallVector<Operation *, 8> loops;
SmallVector<Value, 4> tensorResults;
};
FailureOr<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
const LinalgTilingOptions &options);
/// Peel the loops of a TiledLinalgOp.
void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
ArrayRef<int64_t> peeledLoops,
LinalgTilingLoopType loopType);
/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
/// proceeds as follows:
/// - Find outer parallel loops in these ops that can be fused.
@ -594,24 +599,35 @@ struct LinalgTilingOptions {
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// Base pattern that applies the tiling transformation specified by `options`.
/// Abort and return failure in 2 cases:
/// 1. if the tiling specification is invalid and tiling fails to occur.
/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set
/// and some operand shape cannot be bounded statically.
struct LinalgBaseTilingPattern : public RewritePattern {
// Entry point to match any LinalgOp OpInterface.
LinalgBaseTilingPattern(
///
/// Linalg tiling pattern.
///
/// Apply the `tiling` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `tiling` for more details.
// TODO: TiledOpInterface
struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
/// Construct a generic pattern applied to all LinalgOp that verify `f`.
LinalgTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
// Entry point to match a specific Linalg op.
LinalgBaseTilingPattern(
/// Construct a pattern specifically applied to `opName`.
LinalgTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
TiledLinalgOp &result) const;
/// `matchAndRewrite` implementation that returns the significant transformed
/// pieces of IR.
FailureOr<TiledLinalgOp>
returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(op, rewriter);
}
private:
/// LinalgTransformMarker handles special attribute manipulations.
@ -620,68 +636,6 @@ private:
LinalgTilingOptions options;
};
template <typename OpTy>
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
/// SFINAE: This constructor can only trigger for concrete ops that have a
/// static `getOperationName` method.
template <typename ConcreateOpTy = OpTy>
LinalgTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseTilingPattern(ConcreateOpTy::getOperationName(), context,
options, filter, benefit) {}
/// This constructor is available to anyone.
LinalgTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TiledLinalgOp tiledLinalgOp;
if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
tiledLinalgOp)))
return failure();
if (tiledLinalgOp.tensorResults.empty())
rewriter.eraseOp(op);
else
rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
return success();
}
};
struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
/// Entry point to match any LinalgOp OpInterface.
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
LinalgGenericTilingPattern(
MLIRContext *context, LinalgTransformationFilter filter,
LinalgTilingOptions options = LinalgTilingOptions(),
PatternBenefit benefit = 1)
: LinalgBaseTilingPattern(context, options, filter, benefit) {}
/// Entry point to match a specific Linalg op.
LinalgGenericTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TiledLinalgOp tiledLinalgOp;
if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
tiledLinalgOp)))
return failure();
if (tiledLinalgOp.tensorResults.empty())
rewriter.eraseOp(op);
else
rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
return success();
}
};
///
/// Linalg padding pattern.
///
@ -1395,6 +1349,32 @@ struct ExtractSliceOfPadTensorSwapPattern
PatternRewriter &rewriter) const override;
};
//===----------------------------------------------------------------------===//
// Helper classes for type list expansion.
//===----------------------------------------------------------------------===//
template <typename... OpTypes>
class TilingPatterns;
template <>
class TilingPatterns<> {
public:
static void insert(RewritePatternSet &patterns,
const LinalgTilingOptions &options,
const LinalgTransformationFilter &f) {}
};
template <typename OpTy, typename... OpTypes>
class TilingPatterns<OpTy, OpTypes...> {
public:
static void insert(RewritePatternSet &patterns,
const LinalgTilingOptions &options,
const LinalgTransformationFilter &f) {
patterns.add<LinalgTilingPattern>(OpTy::getOperationName(),
patterns.getContext(), options, f);
TilingPatterns<OpTypes...>::insert(patterns, options, f);
}
};
} // namespace linalg
} // namespace mlir
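
For reference, a hedged sketch of how the new TilingPatterns helper might be driven (this mirrors the pass and test updates further down; the tile sizes and the "tiled" marker are illustrative, and the mlir and mlir::linalg namespaces are assumed to be in scope):

// Illustrative: populate tiling patterns for a fixed list of ops in one call.
static void populateExampleTilingPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  LinalgTilingOptions options = LinalgTilingOptions().setTileSizes({8, 16, 32});
  LinalgTransformationFilter f(ArrayRef<StringAttr>{},
                               StringAttr::get(ctx, "tiled"));
  TilingPatterns<MatmulOp, GenericOp>::insert(patterns, options, f);
}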


@ -784,7 +784,9 @@ tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
tileSizes[i] = zero;
LinalgTilingOptions tileFusedLoopsOptions = options;
tileFusedLoopsOptions.setTileSizes(tileSizes);
return tileLinalgOp(b, op, tileFusedLoopsOptions);
// TODO: Propagate RewriterBase everywhere.
IRRewriter rewriter(b);
return tileLinalgOp(rewriter, op, tileFusedLoopsOptions);
}
/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected


@ -283,10 +283,14 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
tileInterchange.begin(), tileInterchange.end()))
.setTileSizes(tileSizes)
.setLoopType(LinalgTilingLoopType::Loops);
Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
// TODO: Propagate RewriterBase everywhere.
IRRewriter rewriter(b);
FailureOr<TiledLinalgOp> tiledRootOp =
tileLinalgOp(rewriter, rootOp, tilingOptions);
// Exit if tiling the root operation fails.
if (!tiledRootOp.hasValue())
if (failed(tiledRootOp))
return failure();
// Replace all uses of the root operation if it has been tiled before. All


@ -1,4 +1,4 @@
//===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
//===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -93,14 +93,13 @@ struct LinalgStrategyTilePass
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
return;
RewritePatternSet tilingPattern(funcOp.getContext());
if (!anchorOpName.empty()) {
tilingPattern.add<LinalgGenericTilingPattern>(
anchorOpName, funcOp.getContext(), options, filter);
} else {
tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
options);
}
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet tilingPattern(ctx);
if (!anchorOpName.empty())
tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
filter);
else
tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}


@ -51,7 +51,7 @@ using LoopIndexToRangeIndexMap = DenseMap<int, int>;
// a map from loop indices of the LinalgOp to the corresponding non-empty range
// indices of newly created loops.
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
ValueRange allShapeSizes, ValueRange allTileSizes) {
assert(allTileSizes.size() == map.getNumResults());
// Apply `map` to get shape sizes in loop order.
@ -129,7 +129,7 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
// TODO: Investigate whether mixing implicit and explicit indices
// does not lead to losing information.
static void
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
for (auto &en : enumerate(allIvs)) {
@ -144,7 +144,7 @@ transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as size of tile) is taken from a given
// ExtractSliceOp `sliceOp`.
static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
tensor::ExtractSliceOp sliceOp, Value source,
Value dest) {
return b.create<tensor::InsertSliceOp>(
@ -155,7 +155,7 @@ static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
const LinalgTilingOptions &options) {
auto nLoops = op.getNumLoops();
// Initial tile sizes may be too big, only take the first nLoops.
@ -216,7 +216,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
LinalgOp res = op;
SmallVector<Value, 4> ivs, tensorResults;
auto tiledLoopBodyBuilder =
[&](OpBuilder &b, Location loc, ValueRange localIvs,
[&](OpBuilder &builder, Location loc, ValueRange localIvs,
ValueRange operandValuesToUse) -> scf::ValueVector {
ivs.assign(localIvs.begin(), localIvs.end());
@ -255,9 +255,12 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
// TODO: use an interface/adaptor to avoid leaking position in
// `tiledOperands`.
Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
// TODO: Propagate RewriterBase everywhere.
IRRewriter rewriter(b);
if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
tensorResults.push_back(insertSliceIntoTensor(
b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
res->getResult(resultIdx),
sliceOp.source()));
} else {
tensorResults.push_back(res->getResult(resultIdx));
}
@ -299,7 +302,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
template <typename LoopTy>
FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@ -321,7 +324,7 @@ FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
}
FailureOr<TiledLinalgOp>
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
const LinalgTilingOptions &options) {
switch (options.loopType) {
case LinalgTilingLoopType::Loops:
@ -338,7 +341,7 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
/// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp`
/// and `loopNest` are output parameters that return the new (tiled) PadTensorOp
/// and the loop nest.
static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
static LogicalResult tilePadTensorOp(RewriterBase &builder, PadTensorOp op,
PadTensorOp &newPadOp, LoopNest &loopNest,
const LinalgTilingOptions &options) {
Location loc = op.getLoc();
@ -384,8 +387,10 @@ static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
assert(sliceOp && "expected ExtractSliceOp");
// Insert the tile into the output tensor.
// TODO: Propagate RewriterBase everywhere.
IRRewriter rewriter(b);
Value yieldValue =
insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]);
insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
return scf::ValueVector({yieldValue});
});
return success();
@ -434,31 +439,6 @@ public:
CanonicalizationPatternList<OpTypes...>::insert(patterns);
}
};
/// Helper classes for type list expansion.
template <typename... OpTypes>
class RewritePatternList;
template <>
class RewritePatternList<> {
public:
static void insert(RewritePatternSet &patterns,
const LinalgTilingOptions &options) {}
};
template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
static void insert(RewritePatternSet &patterns,
const LinalgTilingOptions &options) {
auto *ctx = patterns.getContext();
patterns.add<LinalgTilingPattern<OpTy>>(
ctx, options,
LinalgTransformationFilter(ArrayRef<StringAttr>{},
StringAttr::get(ctx, "tiled")));
RewritePatternList<OpTypes...>::insert(patterns, options);
}
};
} // namespace
RewritePatternSet
@ -500,11 +480,14 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(RewritePatternSet &patterns,
const LinalgTilingOptions &options) {
RewritePatternList<GenericOp,
auto *ctx = patterns.getContext();
LinalgTransformationFilter f(ArrayRef<StringAttr>{},
StringAttr::get(ctx, "tiled"));
TilingPatterns<GenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>::insert(patterns, options);
patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
>::insert(patterns, options, f);
patterns.add<PadTensorOpTilingPattern>(ctx, options);
}
static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {


@ -1,4 +1,4 @@
//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -284,19 +284,6 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
return paddedSubviewResults;
}
/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter, PatternBenefit benefit)
: RewritePattern(opName, benefit, context), filter(std::move(filter)),
options(std::move(options)) {}
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter, PatternBenefit benefit)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
filter(std::move(filter)), options(std::move(options)) {}
/// Try to peel a loop `op` and return the new result.
// TODO: Add support for scf.parallel and affine.for loops.
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
@ -325,14 +312,15 @@ static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
}
/// Peel loops after tiling.
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
const LinalgTilingOptions &options) {
for (int64_t loop : options.peeledLoops) {
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
ArrayRef<int64_t> peeledLoops,
LinalgTilingLoopType loopType) {
for (int64_t loop : peeledLoops) {
assert(loop < static_cast<int64_t>(res.loops.size()) &&
"requested peeling of non-existing loop");
SmallVector<Value, 4> loopResults;
Operation *loopOp = res.loops[loop];
if (options.loopType == LinalgTilingLoopType::TiledLoops) {
if (loopType == LinalgTilingLoopType::TiledLoops) {
assert(llvm::all_of(
res.loops,
[&](Operation *op) { return op == res.loops.front(); }) &&
@ -352,28 +340,6 @@ static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
}
}
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
if (!res)
return failure();
// Clear filter to stop recursive pattern application.
filter.replaceLinalgTransformationFilter(rewriter, res->op);
// Peel loops.
peelLoops(rewriter, *res, options);
result = *res;
return success();
}
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
if (tiledOp.loops.empty())
return tiledOp.op.getOperation()->getResults();
@ -459,9 +425,9 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
})) {
LinalgTilingOptions unfusedTilingOptions = tilingOptions;
unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
Optional<TiledLinalgOp> unfusedTiledOp =
FailureOr<TiledLinalgOp> unfusedTiledOp =
tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
if (!unfusedTiledOp)
if (failed(unfusedTiledOp))
return failure();
rewriter.replaceOp(tiledAndFusedOps->op,
getTiledOpResult(unfusedTiledOp.getValue()));
@ -485,6 +451,48 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
return success();
}
/// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter f, PatternBenefit benefit)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
filter(std::move(f)), options(std::move(options)) {}
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter f, PatternBenefit benefit)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
filter(std::move(f)), options(std::move(options)) {
this->filter.addFilter([opName](Operation *op) {
return success(op->getName().getStringRef() == opName);
});
}
FailureOr<TiledLinalgOp>
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
LinalgOp op, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
if (failed(res))
return failure();
// Clear filter to stop recursive pattern application.
// This must be done here to properly propagate to peeling branches.
filter.replaceLinalgTransformationFilter(rewriter, res->op);
// Peel the loops of the TiledLinalgOp.
peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
if (res->tensorResults.empty())
rewriter.eraseOp(op);
else
rewriter.replaceOp(op, res->tensorResults);
return res;
}
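
As a hypothetical usage sketch (not part of this commit), the new returningMatchAndRewrite entry point lets a caller, e.g. another pattern that already has a PatternRewriter, run tiling and then inspect the transformed pieces of IR rather than only a success/failure bit:

// Assumes `ctx`, `linalgOp`, and `rewriter` come from the enclosing pattern;
// the tile sizes are illustrative.
LinalgTilingPattern tilingPattern(ctx,
                                  LinalgTilingOptions().setTileSizes({4, 4}));
FailureOr<TiledLinalgOp> tiled =
    tilingPattern.returningMatchAndRewrite(linalgOp, rewriter);
if (failed(tiled))
  return failure();
// `tiled->op` is the tiled LinalgOp, `tiled->loops` the generated loop nest,
// and `tiled->tensorResults` the values that replaced the original results.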
/// Linalg padding pattern.
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
MLIRContext *context, LinalgPaddingOptions options,


@ -1178,8 +1178,9 @@ static void populateVectorizationPatterns(
constexpr static StringRef kTiledMarker = "TILED";
constexpr static StringRef kPromotedMarker = "PROMOTED";
tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
context, LinalgTilingOptions().setTileSizes(tileSizes),
tilingPatterns.add<LinalgTilingPattern>(
ConvOp::getOperationName(), context,
LinalgTilingOptions().setTileSizes(tileSizes),
LinalgTransformationFilter(ArrayRef<StringAttr>{},
StringAttr::get(kTiledMarker, context)));


@ -138,32 +138,36 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
// Linalg tiling patterns.
//===--------------------------------------------------------------------===//
patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
StringAttr::get(ctx, "L3")));
patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions().setTileSizes({200, 300, 400}),
LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
StringAttr::get(ctx, "L2")));
patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
StringAttr::get(ctx, "L1")));
patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions().setTileSizes({2, 3, 4}),
LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
StringAttr::get(ctx, "REG")));
patterns.add<LinalgTilingPattern<MatvecOp>>(
ctx,
patterns.add<LinalgTilingPattern>(
MatvecOp::getOperationName(), ctx,
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
LinalgTransformationFilter(ArrayRef<StringAttr>{},
StringAttr::get(ctx, "L1")));
patterns.add<LinalgTilingPattern<DotOp>>(
ctx, LinalgTilingOptions().setTileSizes(8000),
patterns.add<LinalgTilingPattern>(
DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
LinalgTransformationFilter(
ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
StringAttr::get(ctx, "L3"),
@ -173,32 +177,34 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===--------------------------------------------------------------------===//
patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({2000, 3000, 4000})
.setInterchange({1, 2, 0}),
LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
StringAttr::get(ctx, "L2__with_perm__")));
patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({200, 300, 400})
.setInterchange({1, 0, 2}),
LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
StringAttr::get(ctx, "L1__with_perm__")));
patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions().setTileSizes({20, 30, 40}),
LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
StringAttr::get(ctx, "REG__with_perm__")));
patterns.add<LinalgTilingPattern<MatvecOp>>(
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
patterns.add<LinalgTilingPattern>(
MatvecOp::getOperationName(), ctx,
LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
StringAttr::get(ctx, "L1__with_perm__")));
patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({16, 8, 4})
.setInterchange({1, 2, 0})
@ -274,8 +280,8 @@ static void fillL1TilingAndMatmulToVectorPatterns(
SmallVectorImpl<RewritePatternSet> &patternsVector) {
MLIRContext *ctx = funcOp.getContext();
patternsVector.emplace_back(
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
ctx,
ctx, std::make_unique<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({8, 12, 16})
.setInterchange({1, 0, 2}),
@ -339,8 +345,9 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
static void fillPromotionCallBackPatterns(MLIRContext *ctx,
RewritePatternSet &patterns) {
patterns.add<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions().setTileSizes({16, 16, 16}),
LinalgTransformationFilter(StringAttr::get(ctx, "START"),
StringAttr::get(ctx, "PROMOTE")));
patterns.add<LinalgPromotionPattern<MatmulOp>>(
@ -382,8 +389,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
2, DistributionMethod::CyclicNumProcsEqNumIters);
cyclicNprocsEqNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@ -399,8 +406,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
2, DistributionMethod::CyclicNumProcsGeNumIters);
cyclicNprocsGeNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@ -416,8 +423,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::Cyclic);
cyclicNprocsDefault.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@ -433,8 +440,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::CyclicNumProcsEqNumIters,
DistributionMethod::CyclicNumProcsGeNumIters};
cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@ -450,8 +457,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::CyclicNumProcsGeNumIters,
DistributionMethod::Cyclic};
cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@ -468,8 +475,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::CyclicNumProcsEqNumIters};
cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::ParallelLoops)
@ -485,8 +492,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
DistributionMethod::Cyclic);
cyclicNprocsEqNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
patterns.add<LinalgTilingPattern<MatmulOp>>(
context,
patterns.add<LinalgTilingPattern>(
MatmulOp::getOperationName(), context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::Loops)
@ -507,8 +514,8 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
stage1Patterns.emplace_back(
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
ctx,
ctx, std::make_unique<LinalgTilingPattern>(
MatmulOp::getOperationName(), ctx,
LinalgTilingOptions()
.setTileSizes({768, 264, 768})
.setInterchange({1, 2, 0}),
@ -589,10 +596,9 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
} else {
linalgTilingOptions.setTileSizes(tileSizes);
}
tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
linalg::LinalgTilingPattern<linalg::GenericOp>>(
context, linalgTilingOptions,
linalg::LinalgTransformationFilter(StringAttr::get(context, "tile")));
linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
tilingPattern, linalgTilingOptions, f);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}