diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 1deedc1520c9..333214efcc05 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -95,6 +95,20 @@ constexpr StringLiteral ArgConverter::kCastName; // DialectConversionRewriter //===----------------------------------------------------------------------===// +/// This class contains a snapshot of the current conversion rewriter state. +/// This is useful when saving and undoing a set of rewrites. +struct RewriterState { + RewriterState(unsigned numCreatedOperations, unsigned numReplacements) + : numCreatedOperations(numCreatedOperations), + numReplacements(numReplacements) {} + + /// The current number of created operations. + unsigned numCreatedOperations; + + /// The current number of replacements queued. + unsigned numReplacements; +}; + /// This class implements a pattern rewriter for ConversionPattern /// patterns. It automatically performs remapping of replaced operation values. struct DialectConversionRewriter final : public PatternRewriter { @@ -112,6 +126,24 @@ struct DialectConversionRewriter final : public PatternRewriter { : PatternRewriter(region), argConverter(region.getContext()) {} ~DialectConversionRewriter() = default; + /// Return the current state of the rewriter. + RewriterState getCurrentState() { + return RewriterState(createdOps.size(), replacements.size()); + } + + /// Reset the state of the rewriter to a previously saved point. + void resetState(RewriterState state) { + // Reset any replaced operations and undo any saved mappings. + for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) + for (auto *result : repl.op->getResults()) + mapping.erase(result); + replacements.resize(state.numReplacements); + + // Pop all of the newly created operations. + while (createdOps.size() != state.numCreatedOperations) + createdOps.pop_back_val()->erase(); + } + /// Cleanup and destroy any generated rewrite operations. This method is /// invoked when the conversion process fails. void discardRewrites() { @@ -354,13 +386,10 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, return failure(); } - auto curOpCount = rewriter.createdOps.size(); - auto curReplCount = rewriter.replacements.size(); + RewriterState curState = rewriter.getCurrentState(); auto cleanupFailure = [&] { - // Pop all of the newly created operations and replacements. - while (rewriter.createdOps.size() != curOpCount) - rewriter.createdOps.pop_back_val()->erase(); - rewriter.replacements.resize(curReplCount); + // Reset the rewriter state and pop this pattern. + rewriter.resetState(curState); appliedPatterns.erase(pattern); return failure(); }; @@ -373,11 +402,13 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, } // Recursively legalize each of the new operations. - for (unsigned i = curOpCount, e = rewriter.createdOps.size(); i != e; ++i) { - if (succeeded(legalize(rewriter.createdOps[i], rewriter))) - continue; - LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n"); - return cleanupFailure(); + for (unsigned i = curState.numCreatedOperations, + e = rewriter.createdOps.size(); + i != e; ++i) { + if (failed(legalize(rewriter.createdOps[i], rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n"); + return cleanupFailure(); + } } appliedPatterns.erase(pattern);