From fbceee2d63bd47f94d4b2b519e28184dddca90e2 Mon Sep 17 00:00:00 2001 From: Lex Augusteijn Date: Mon, 16 Nov 2020 22:47:42 +0000 Subject: [PATCH] Add an optional argument for pattern rewriter max iteration count (NFC) Some rewriters take more iterations to converge, add a parameter to overwrite the built-in maximum iteration count. Fix PR48073. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D91553 --- .../Transforms/GreedyPatternRewriteDriver.h | 26 ++++++++++++++++--- .../Utils/GreedyPatternRewriteDriver.cpp | 20 +++++++++++--- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index 9d08ad9fa05d..4a084c57b6ed 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -24,20 +24,38 @@ namespace mlir { /// Rewrite the regions of the specified operation, which must be isolated from /// above, by repeatedly applying the highest benefit patterns in a greedy -/// work-list driven manner. Return success if no more patterns can be matched -/// in the result operation regions. -/// Note: This does not apply patterns to the top-level operation itself. Note: +/// work-list driven manner. +/// This variant may stop after a predefined number of iterations, see the +/// alternative below to provide a specific number of iterations before stopping +/// in absence of convergence. +/// Return success if the iterative process converged and no more patterns can +/// be matched in the result operation regions. +/// Note: This does not apply patterns to the top-level operation itself. /// These methods also perform folding and simple dead-code elimination /// before attempting to match any of the provided patterns. -/// LogicalResult applyPatternsAndFoldGreedily(Operation *op, const FrozenRewritePatternList &patterns); + +/// Rewrite the regions of the specified operation, with a user-provided limit +/// on iterations to attempt before reaching convergence. +LogicalResult +applyPatternsAndFoldGreedily(Operation *op, + const FrozenRewritePatternList &patterns, + unsigned maxIterations); + /// Rewrite the given regions, which must be isolated from above. LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef regions, const FrozenRewritePatternList &patterns); +/// Rewrite the given regions, with a user-provided limit on iterations to +/// attempt before reaching convergence. +LogicalResult +applyPatternsAndFoldGreedily(MutableArrayRef regions, + const FrozenRewritePatternList &patterns, + unsigned maxIterations); + /// Applies the specified patterns on `op` alone while also trying to fold it, /// by selecting the highest benefits patterns in a greedy manner. Returns /// success if no more patterns can be matched. `erased` is set to true if `op` diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index bbe3ac57d91c..170f882c02d0 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -220,12 +220,26 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, LogicalResult mlir::applyPatternsAndFoldGreedily(Operation *op, const FrozenRewritePatternList &patterns) { - return applyPatternsAndFoldGreedily(op->getRegions(), patterns); + return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations); +} +LogicalResult +mlir::applyPatternsAndFoldGreedily(Operation *op, + const FrozenRewritePatternList &patterns, + unsigned maxIterations) { + return applyPatternsAndFoldGreedily(op->getRegions(), patterns, + maxIterations); } /// Rewrite the given regions, which must be isolated from above. LogicalResult mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, const FrozenRewritePatternList &patterns) { + return applyPatternsAndFoldGreedily(regions, patterns, + maxPatternMatchIterations); +} +LogicalResult +mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, + const FrozenRewritePatternList &patterns, + unsigned maxIterations) { if (regions.empty()) return success(); @@ -241,10 +255,10 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, // Start the pattern driver. GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns); - bool converged = driver.simplify(regions, maxPatternMatchIterations); + bool converged = driver.simplify(regions, maxIterations); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " - << maxPatternMatchIterations << " times"; + << maxIterations << " times"; }); return success(converged); }