forked from OSchip/llvm-project
[mlir][Linalg] NFC - Modernize APIs and get rid of unnecessary tiling paterns.
Tiling patterns can be reduced to a single pattern by using interface-based patterns. Differential Revision: https://reviews.llvm.org/D116733
This commit is contained in:
parent
75d65293ca
commit
4a661602ef
|
@ -169,9 +169,14 @@ struct TiledLinalgOp {
|
|||
SmallVector<Operation *, 8> loops;
|
||||
SmallVector<Value, 4> tensorResults;
|
||||
};
|
||||
FailureOr<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||
FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
|
||||
const LinalgTilingOptions &options);
|
||||
|
||||
/// Peel the loops of a TiledLinalgOp.
|
||||
void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
|
||||
ArrayRef<int64_t> peeledLoops,
|
||||
LinalgTilingLoopType loopType);
|
||||
|
||||
/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
|
||||
/// proceeds as follows:
|
||||
/// - Find outer parallel loops in these ops that can be fused.
|
||||
|
@ -594,24 +599,35 @@ struct LinalgTilingOptions {
|
|||
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
|
||||
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Base pattern that applies the tiling transformation specified by `options`.
|
||||
/// Abort and return failure in 2 cases:
|
||||
/// 1. if the tiling specification is invalid and tiling fails to occur.
|
||||
/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set
|
||||
/// and some operand shape cannot be bounded statically.
|
||||
struct LinalgBaseTilingPattern : public RewritePattern {
|
||||
// Entry point to match any LinalgOp OpInterface.
|
||||
LinalgBaseTilingPattern(
|
||||
///
|
||||
/// Linalg tiling pattern.
|
||||
///
|
||||
/// Apply the `tiling` transformation as a pattern.
|
||||
/// `filter` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `tiling` for more details.
|
||||
// TODO: TiledOpInterface
|
||||
struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
|
||||
/// Construct a generic pattern applied to all LinalgOp that verify `f`.
|
||||
LinalgTilingPattern(
|
||||
MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
LinalgTransformationFilter f = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
// Entry point to match a specific Linalg op.
|
||||
LinalgBaseTilingPattern(
|
||||
|
||||
/// Construct a pattern specifically applied to `opName`.
|
||||
LinalgTilingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
LinalgTransformationFilter f = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
|
||||
TiledLinalgOp &result) const;
|
||||
|
||||
/// `matchAndRewrite` implementation that returns the significant transformed
|
||||
/// pieces of IR.
|
||||
FailureOr<TiledLinalgOp>
|
||||
returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;
|
||||
|
||||
LogicalResult matchAndRewrite(LinalgOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
return returningMatchAndRewrite(op, rewriter);
|
||||
}
|
||||
|
||||
private:
|
||||
/// LinalgTransformMarker handles special attribute manipulations.
|
||||
|
@ -620,68 +636,6 @@ private:
|
|||
LinalgTilingOptions options;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
|
||||
/// SFINAE: This constructor can only trigger for concrete ops that have a
|
||||
/// static `getOperationName` method.
|
||||
template <typename ConcreateOpTy = OpTy>
|
||||
LinalgTilingPattern(
|
||||
MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseTilingPattern(ConcreateOpTy::getOperationName(), context,
|
||||
options, filter, benefit) {}
|
||||
|
||||
/// This constructor is available to anyone.
|
||||
LinalgTilingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TiledLinalgOp tiledLinalgOp;
|
||||
if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
|
||||
tiledLinalgOp)))
|
||||
return failure();
|
||||
if (tiledLinalgOp.tensorResults.empty())
|
||||
rewriter.eraseOp(op);
|
||||
else
|
||||
rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
|
||||
/// Entry point to match any LinalgOp OpInterface.
|
||||
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
|
||||
LinalgGenericTilingPattern(
|
||||
MLIRContext *context, LinalgTransformationFilter filter,
|
||||
LinalgTilingOptions options = LinalgTilingOptions(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseTilingPattern(context, options, filter, benefit) {}
|
||||
/// Entry point to match a specific Linalg op.
|
||||
LinalgGenericTilingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TiledLinalgOp tiledLinalgOp;
|
||||
if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
|
||||
tiledLinalgOp)))
|
||||
return failure();
|
||||
if (tiledLinalgOp.tensorResults.empty())
|
||||
rewriter.eraseOp(op);
|
||||
else
|
||||
rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
///
|
||||
/// Linalg padding pattern.
|
||||
///
|
||||
|
@ -1395,6 +1349,32 @@ struct ExtractSliceOfPadTensorSwapPattern
|
|||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper classes for type list expansion.
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <typename... OpTypes>
|
||||
class TilingPatterns;
|
||||
|
||||
template <>
|
||||
class TilingPatterns<> {
|
||||
public:
|
||||
static void insert(RewritePatternSet &patterns,
|
||||
const LinalgTilingOptions &options,
|
||||
const LinalgTransformationFilter &f) {}
|
||||
};
|
||||
|
||||
template <typename OpTy, typename... OpTypes>
|
||||
class TilingPatterns<OpTy, OpTypes...> {
|
||||
public:
|
||||
static void insert(RewritePatternSet &patterns,
|
||||
const LinalgTilingOptions &options,
|
||||
const LinalgTransformationFilter &f) {
|
||||
patterns.add<LinalgTilingPattern>(OpTy::getOperationName(),
|
||||
patterns.getContext(), options, f);
|
||||
TilingPatterns<OpTypes...>::insert(patterns, options, f);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -784,7 +784,9 @@ tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
|
|||
tileSizes[i] = zero;
|
||||
LinalgTilingOptions tileFusedLoopsOptions = options;
|
||||
tileFusedLoopsOptions.setTileSizes(tileSizes);
|
||||
return tileLinalgOp(b, op, tileFusedLoopsOptions);
|
||||
// TODO: Propagate RewriterBase everywhere.
|
||||
IRRewriter rewriter(b);
|
||||
return tileLinalgOp(rewriter, op, tileFusedLoopsOptions);
|
||||
}
|
||||
|
||||
/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
|
||||
|
|
|
@ -283,10 +283,14 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
|
|||
tileInterchange.begin(), tileInterchange.end()))
|
||||
.setTileSizes(tileSizes)
|
||||
.setLoopType(LinalgTilingLoopType::Loops);
|
||||
Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
|
||||
|
||||
// TODO: Propagate RewriterBase everywhere.
|
||||
IRRewriter rewriter(b);
|
||||
FailureOr<TiledLinalgOp> tiledRootOp =
|
||||
tileLinalgOp(rewriter, rootOp, tilingOptions);
|
||||
|
||||
// Exit if tiling the root operation fails.
|
||||
if (!tiledRootOp.hasValue())
|
||||
if (failed(tiledRootOp))
|
||||
return failure();
|
||||
|
||||
// Replace all uses of the root operation if it has been tiled before. All
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
|
||||
//===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -93,14 +93,13 @@ struct LinalgStrategyTilePass
|
|||
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
|
||||
return;
|
||||
|
||||
RewritePatternSet tilingPattern(funcOp.getContext());
|
||||
if (!anchorOpName.empty()) {
|
||||
tilingPattern.add<LinalgGenericTilingPattern>(
|
||||
anchorOpName, funcOp.getContext(), options, filter);
|
||||
} else {
|
||||
tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
|
||||
options);
|
||||
}
|
||||
MLIRContext *ctx = funcOp.getContext();
|
||||
RewritePatternSet tilingPattern(ctx);
|
||||
if (!anchorOpName.empty())
|
||||
tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
|
||||
filter);
|
||||
else
|
||||
tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
|
||||
}
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ using LoopIndexToRangeIndexMap = DenseMap<int, int>;
|
|||
// a map from loop indices of the LinalgOp to the corresponding non-empty range
|
||||
// indices of newly created loops.
|
||||
static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
|
||||
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
|
||||
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
|
||||
ValueRange allShapeSizes, ValueRange allTileSizes) {
|
||||
assert(allTileSizes.size() == map.getNumResults());
|
||||
// Apply `map` to get shape sizes in loop order.
|
||||
|
@ -129,7 +129,7 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
|
|||
// TODO: Investigate whether mixing implicit and explicit indices
|
||||
// does not lead to losing information.
|
||||
static void
|
||||
transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
|
||||
transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
|
||||
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
|
||||
SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
|
||||
for (auto &en : enumerate(allIvs)) {
|
||||
|
@ -144,7 +144,7 @@ transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
|
|||
// Insert a tile `source` into the destination tensor `dest`. The position at
|
||||
// which the tile is inserted (as well as size of tile) is taken from a given
|
||||
// ExtractSliceOp `sliceOp`.
|
||||
static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
|
||||
static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
|
||||
tensor::ExtractSliceOp sliceOp, Value source,
|
||||
Value dest) {
|
||||
return b.create<tensor::InsertSliceOp>(
|
||||
|
@ -155,7 +155,7 @@ static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
|
|||
|
||||
template <typename LoopTy>
|
||||
static FailureOr<TiledLinalgOp>
|
||||
tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
|
||||
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
|
||||
const LinalgTilingOptions &options) {
|
||||
auto nLoops = op.getNumLoops();
|
||||
// Initial tile sizes may be too big, only take the first nLoops.
|
||||
|
@ -216,7 +216,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
|
|||
LinalgOp res = op;
|
||||
SmallVector<Value, 4> ivs, tensorResults;
|
||||
auto tiledLoopBodyBuilder =
|
||||
[&](OpBuilder &b, Location loc, ValueRange localIvs,
|
||||
[&](OpBuilder &builder, Location loc, ValueRange localIvs,
|
||||
ValueRange operandValuesToUse) -> scf::ValueVector {
|
||||
ivs.assign(localIvs.begin(), localIvs.end());
|
||||
|
||||
|
@ -255,9 +255,12 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
|
|||
// TODO: use an interface/adaptor to avoid leaking position in
|
||||
// `tiledOperands`.
|
||||
Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
|
||||
// TODO: Propagate RewriterBase everywhere.
|
||||
IRRewriter rewriter(b);
|
||||
if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
tensorResults.push_back(insertSliceIntoTensor(
|
||||
b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
|
||||
tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
|
||||
res->getResult(resultIdx),
|
||||
sliceOp.source()));
|
||||
} else {
|
||||
tensorResults.push_back(res->getResult(resultIdx));
|
||||
}
|
||||
|
@ -299,7 +302,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
|
|||
|
||||
template <typename LoopTy>
|
||||
FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
|
||||
OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
|
||||
RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
b.setInsertionPoint(op);
|
||||
|
||||
|
@ -321,7 +324,7 @@ FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
|
|||
}
|
||||
|
||||
FailureOr<TiledLinalgOp>
|
||||
mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
|
||||
mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
|
||||
const LinalgTilingOptions &options) {
|
||||
switch (options.loopType) {
|
||||
case LinalgTilingLoopType::Loops:
|
||||
|
@ -338,7 +341,7 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
|
|||
/// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp`
|
||||
/// and `loopNest` are output parameters that return the new (tiled) PadTensorOp
|
||||
/// and the loop nest.
|
||||
static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
|
||||
static LogicalResult tilePadTensorOp(RewriterBase &builder, PadTensorOp op,
|
||||
PadTensorOp &newPadOp, LoopNest &loopNest,
|
||||
const LinalgTilingOptions &options) {
|
||||
Location loc = op.getLoc();
|
||||
|
@ -384,8 +387,10 @@ static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
|
|||
auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
|
||||
assert(sliceOp && "expected ExtractSliceOp");
|
||||
// Insert the tile into the output tensor.
|
||||
// TODO: Propagate RewriterBase everywhere.
|
||||
IRRewriter rewriter(b);
|
||||
Value yieldValue =
|
||||
insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]);
|
||||
insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
|
||||
return scf::ValueVector({yieldValue});
|
||||
});
|
||||
return success();
|
||||
|
@ -434,31 +439,6 @@ public:
|
|||
CanonicalizationPatternList<OpTypes...>::insert(patterns);
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper classes for type list expansion.
|
||||
template <typename... OpTypes>
|
||||
class RewritePatternList;
|
||||
|
||||
template <>
|
||||
class RewritePatternList<> {
|
||||
public:
|
||||
static void insert(RewritePatternSet &patterns,
|
||||
const LinalgTilingOptions &options) {}
|
||||
};
|
||||
|
||||
template <typename OpTy, typename... OpTypes>
|
||||
class RewritePatternList<OpTy, OpTypes...> {
|
||||
public:
|
||||
static void insert(RewritePatternSet &patterns,
|
||||
const LinalgTilingOptions &options) {
|
||||
auto *ctx = patterns.getContext();
|
||||
patterns.add<LinalgTilingPattern<OpTy>>(
|
||||
ctx, options,
|
||||
LinalgTransformationFilter(ArrayRef<StringAttr>{},
|
||||
StringAttr::get(ctx, "tiled")));
|
||||
RewritePatternList<OpTypes...>::insert(patterns, options);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
RewritePatternSet
|
||||
|
@ -500,11 +480,14 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
|
|||
/// Populate the given list with patterns that apply Linalg tiling.
|
||||
static void insertTilingPatterns(RewritePatternSet &patterns,
|
||||
const LinalgTilingOptions &options) {
|
||||
RewritePatternList<GenericOp,
|
||||
auto *ctx = patterns.getContext();
|
||||
LinalgTransformationFilter f(ArrayRef<StringAttr>{},
|
||||
StringAttr::get(ctx, "tiled"));
|
||||
TilingPatterns<GenericOp,
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
|
||||
>::insert(patterns, options);
|
||||
patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
|
||||
>::insert(patterns, options, f);
|
||||
patterns.add<PadTensorOpTilingPattern>(ctx, options);
|
||||
}
|
||||
|
||||
static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
|
||||
//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -284,19 +284,6 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
|
|||
return paddedSubviewResults;
|
||||
}
|
||||
|
||||
/// Linalg base tiling pattern.
|
||||
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter filter, PatternBenefit benefit)
|
||||
: RewritePattern(opName, benefit, context), filter(std::move(filter)),
|
||||
options(std::move(options)) {}
|
||||
|
||||
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
|
||||
MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter filter, PatternBenefit benefit)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
|
||||
filter(std::move(filter)), options(std::move(options)) {}
|
||||
|
||||
/// Try to peel a loop `op` and return the new result.
|
||||
// TODO: Add support for scf.parallel and affine.for loops.
|
||||
static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
|
||||
|
@ -325,14 +312,15 @@ static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
|
|||
}
|
||||
|
||||
/// Peel loops after tiling.
|
||||
static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
|
||||
const LinalgTilingOptions &options) {
|
||||
for (int64_t loop : options.peeledLoops) {
|
||||
void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
|
||||
ArrayRef<int64_t> peeledLoops,
|
||||
LinalgTilingLoopType loopType) {
|
||||
for (int64_t loop : peeledLoops) {
|
||||
assert(loop < static_cast<int64_t>(res.loops.size()) &&
|
||||
"requested peeling of non-existing loop");
|
||||
SmallVector<Value, 4> loopResults;
|
||||
Operation *loopOp = res.loops[loop];
|
||||
if (options.loopType == LinalgTilingLoopType::TiledLoops) {
|
||||
if (loopType == LinalgTilingLoopType::TiledLoops) {
|
||||
assert(llvm::all_of(
|
||||
res.loops,
|
||||
[&](Operation *op) { return op == res.loops.front(); }) &&
|
||||
|
@ -352,28 +340,6 @@ static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
|
|||
}
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
|
||||
Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
|
||||
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
|
||||
Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
|
||||
|
||||
if (!res)
|
||||
return failure();
|
||||
// Clear filter to stop recursive pattern application.
|
||||
filter.replaceLinalgTransformationFilter(rewriter, res->op);
|
||||
|
||||
// Peel loops.
|
||||
peelLoops(rewriter, *res, options);
|
||||
|
||||
result = *res;
|
||||
return success();
|
||||
}
|
||||
|
||||
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
|
||||
if (tiledOp.loops.empty())
|
||||
return tiledOp.op.getOperation()->getResults();
|
||||
|
@ -459,9 +425,9 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
|
|||
})) {
|
||||
LinalgTilingOptions unfusedTilingOptions = tilingOptions;
|
||||
unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
|
||||
Optional<TiledLinalgOp> unfusedTiledOp =
|
||||
FailureOr<TiledLinalgOp> unfusedTiledOp =
|
||||
tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
|
||||
if (!unfusedTiledOp)
|
||||
if (failed(unfusedTiledOp))
|
||||
return failure();
|
||||
rewriter.replaceOp(tiledAndFusedOps->op,
|
||||
getTiledOpResult(unfusedTiledOp.getValue()));
|
||||
|
@ -485,6 +451,48 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Linalg tiling pattern.
|
||||
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
|
||||
MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter f, PatternBenefit benefit)
|
||||
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
|
||||
filter(std::move(f)), options(std::move(options)) {}
|
||||
|
||||
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
|
||||
LinalgTransformationFilter f, PatternBenefit benefit)
|
||||
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
|
||||
filter(std::move(f)), options(std::move(options)) {
|
||||
this->filter.addFilter([opName](Operation *op) {
|
||||
return success(op->getName().getStringRef() == opName);
|
||||
});
|
||||
}
|
||||
|
||||
FailureOr<TiledLinalgOp>
|
||||
mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
|
||||
LinalgOp op, PatternRewriter &rewriter) const {
|
||||
if (failed(filter.checkAndNotify(rewriter, op)))
|
||||
return failure();
|
||||
|
||||
FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
|
||||
if (failed(res))
|
||||
return failure();
|
||||
|
||||
// Clear filter to stop recursive pattern application.
|
||||
// This must be done here to properly propagate to peeling branches.
|
||||
filter.replaceLinalgTransformationFilter(rewriter, res->op);
|
||||
|
||||
// Peel the loops of the TiledLinalgOp.
|
||||
peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
|
||||
|
||||
if (res->tensorResults.empty())
|
||||
rewriter.eraseOp(op);
|
||||
else
|
||||
rewriter.replaceOp(op, res->tensorResults);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Linalg padding pattern.
|
||||
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
|
||||
MLIRContext *context, LinalgPaddingOptions options,
|
||||
|
|
|
@ -1178,8 +1178,9 @@ static void populateVectorizationPatterns(
|
|||
|
||||
constexpr static StringRef kTiledMarker = "TILED";
|
||||
constexpr static StringRef kPromotedMarker = "PROMOTED";
|
||||
tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
|
||||
context, LinalgTilingOptions().setTileSizes(tileSizes),
|
||||
tilingPatterns.add<LinalgTilingPattern>(
|
||||
ConvOp::getOperationName(), context,
|
||||
LinalgTilingOptions().setTileSizes(tileSizes),
|
||||
LinalgTransformationFilter(ArrayRef<StringAttr>{},
|
||||
StringAttr::get(kTiledMarker, context)));
|
||||
|
||||
|
|
|
@ -138,32 +138,36 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
//===--------------------------------------------------------------------===//
|
||||
// Linalg tiling patterns.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
|
||||
LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
|
||||
StringAttr::get(ctx, "L3")));
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions().setTileSizes({200, 300, 400}),
|
||||
LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
|
||||
StringAttr::get(ctx, "L2")));
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
|
||||
StringAttr::get(ctx, "L1")));
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions().setTileSizes({2, 3, 4}),
|
||||
LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
|
||||
StringAttr::get(ctx, "REG")));
|
||||
|
||||
patterns.add<LinalgTilingPattern<MatvecOp>>(
|
||||
ctx,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatvecOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
|
||||
LinalgTilingLoopType::ParallelLoops),
|
||||
LinalgTransformationFilter(ArrayRef<StringAttr>{},
|
||||
StringAttr::get(ctx, "L1")));
|
||||
|
||||
patterns.add<LinalgTilingPattern<DotOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes(8000),
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
|
||||
LinalgTransformationFilter(
|
||||
ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
|
||||
StringAttr::get(ctx, "L3"),
|
||||
|
@ -173,32 +177,34 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
//===--------------------------------------------------------------------===//
|
||||
// Linalg tiling and permutation patterns.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({2000, 3000, 4000})
|
||||
.setInterchange({1, 2, 0}),
|
||||
LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
|
||||
StringAttr::get(ctx, "L2__with_perm__")));
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({200, 300, 400})
|
||||
.setInterchange({1, 0, 2}),
|
||||
LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
|
||||
StringAttr::get(ctx, "L1__with_perm__")));
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions().setTileSizes({20, 30, 40}),
|
||||
LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
|
||||
StringAttr::get(ctx, "REG__with_perm__")));
|
||||
|
||||
patterns.add<LinalgTilingPattern<MatvecOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatvecOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
|
||||
LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
|
||||
StringAttr::get(ctx, "L1__with_perm__")));
|
||||
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({16, 8, 4})
|
||||
.setInterchange({1, 2, 0})
|
||||
|
@ -274,8 +280,8 @@ static void fillL1TilingAndMatmulToVectorPatterns(
|
|||
SmallVectorImpl<RewritePatternSet> &patternsVector) {
|
||||
MLIRContext *ctx = funcOp.getContext();
|
||||
patternsVector.emplace_back(
|
||||
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
ctx, std::make_unique<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({8, 12, 16})
|
||||
.setInterchange({1, 0, 2}),
|
||||
|
@ -339,8 +345,9 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
|
|||
|
||||
static void fillPromotionCallBackPatterns(MLIRContext *ctx,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions().setTileSizes({16, 16, 16}),
|
||||
LinalgTransformationFilter(StringAttr::get(ctx, "START"),
|
||||
StringAttr::get(ctx, "PROMOTE")));
|
||||
patterns.add<LinalgPromotionPattern<MatmulOp>>(
|
||||
|
@ -382,8 +389,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
2, DistributionMethod::CyclicNumProcsEqNumIters);
|
||||
cyclicNprocsEqNiters.procInfo =
|
||||
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
context,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), context,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
|
@ -399,8 +406,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
2, DistributionMethod::CyclicNumProcsGeNumIters);
|
||||
cyclicNprocsGeNiters.procInfo =
|
||||
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
context,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), context,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
|
@ -416,8 +423,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
DistributionMethod::Cyclic);
|
||||
cyclicNprocsDefault.procInfo =
|
||||
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
context,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), context,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
|
@ -433,8 +440,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
DistributionMethod::CyclicNumProcsEqNumIters,
|
||||
DistributionMethod::CyclicNumProcsGeNumIters};
|
||||
cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
context,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), context,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
|
@ -450,8 +457,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
DistributionMethod::CyclicNumProcsGeNumIters,
|
||||
DistributionMethod::Cyclic};
|
||||
cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
context,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), context,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
|
@ -468,8 +475,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
DistributionMethod::CyclicNumProcsEqNumIters};
|
||||
cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
|
||||
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
context,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), context,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::ParallelLoops)
|
||||
|
@ -485,8 +492,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
DistributionMethod::Cyclic);
|
||||
cyclicNprocsEqNiters.procInfo =
|
||||
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
|
||||
patterns.add<LinalgTilingPattern<MatmulOp>>(
|
||||
context,
|
||||
patterns.add<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), context,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({8, 8, 4})
|
||||
.setLoopType(LinalgTilingLoopType::Loops)
|
||||
|
@ -507,8 +514,8 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
|
|||
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
|
||||
} else if (testMatmulToVectorPatterns2dTiling) {
|
||||
stage1Patterns.emplace_back(
|
||||
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
|
||||
ctx,
|
||||
ctx, std::make_unique<LinalgTilingPattern>(
|
||||
MatmulOp::getOperationName(), ctx,
|
||||
LinalgTilingOptions()
|
||||
.setTileSizes({768, 264, 768})
|
||||
.setInterchange({1, 2, 0}),
|
||||
|
@ -589,10 +596,9 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
|
|||
} else {
|
||||
linalgTilingOptions.setTileSizes(tileSizes);
|
||||
}
|
||||
tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
|
||||
linalg::LinalgTilingPattern<linalg::GenericOp>>(
|
||||
context, linalgTilingOptions,
|
||||
linalg::LinalgTransformationFilter(StringAttr::get(context, "tile")));
|
||||
linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
|
||||
TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
|
||||
tilingPattern, linalgTilingOptions, f);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue