forked from OSchip/llvm-project
[mlir][Linalg] NFC - Modernize padding pattern
Differential Revision: https://reviews.llvm.org/D116739
This commit is contained in:
parent
43c5e61b55
commit
2c4a56c418
|
@ -688,7 +688,7 @@ struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
|
|||
/// Apply the `padding` transformation as a pattern.
|
||||
/// `filter` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `padding` for more details.
|
||||
struct LinalgPaddingPattern : public RewritePattern {
|
||||
struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
|
||||
// Entry point to match any LinalgOp OpInterface.
|
||||
LinalgPaddingPattern(
|
||||
MLIRContext *context,
|
||||
|
@ -701,7 +701,7 @@ struct LinalgPaddingPattern : public RewritePattern {
|
|||
LinalgPaddingOptions options = LinalgPaddingOptions(),
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
LogicalResult matchAndRewrite(LinalgOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -489,23 +489,24 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
|
|||
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
|
||||
MLIRContext *context, LinalgPaddingOptions options,
|
||||
LinalgTransformationFilter filter, PatternBenefit benefit)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
|
||||
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
|
||||
filter(std::move(filter)), options(std::move(options)) {}
|
||||
|
||||
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
|
||||
StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
|
||||
LinalgTransformationFilter filter, PatternBenefit benefit)
|
||||
: RewritePattern(opName, benefit, context, {}), filter(std::move(filter)),
|
||||
options(std::move(options)) {}
|
||||
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
|
||||
filter(std::move(filter)), options(std::move(options)) {
|
||||
this->filter.addFilter([opName](Operation *op) {
|
||||
return success(op->getName().getStringRef() == opName);
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
|
||||
if (!linalgOp)
|
||||
return failure();
|
||||
LinalgOp linalgOp, PatternRewriter &rewriter) const {
|
||||
if (!linalgOp.hasTensorSemantics())
|
||||
return failure();
|
||||
if (failed(filter.checkAndNotify(rewriter, op)))
|
||||
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
|
||||
// Pad the operation.
|
||||
|
@ -538,7 +539,7 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
|
|||
}
|
||||
|
||||
// Replace the original operation to pad.
|
||||
rewriter.replaceOp(op, newResults.getValue());
|
||||
rewriter.replaceOp(linalgOp, newResults.getValue());
|
||||
filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue