When cleaning up after a failed legalization pattern, make sure to remove any newly created value mappings.

PiperOrigin-RevId: 251658984
This commit is contained in:
River Riddle 2019-06-05 09:36:32 -07:00 committed by Mehdi Amini
parent 08d407f243
commit 0d2492eb2e
1 changed files with 42 additions and 11 deletions

View File

@ -95,6 +95,20 @@ constexpr StringLiteral ArgConverter::kCastName;
// DialectConversionRewriter
//===----------------------------------------------------------------------===//
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOperations, unsigned numReplacements)
: numCreatedOperations(numCreatedOperations),
numReplacements(numReplacements) {}
/// The current number of created operations.
unsigned numCreatedOperations;
/// The current number of replacements queued.
unsigned numReplacements;
};
/// This class implements a pattern rewriter for ConversionPattern
/// patterns. It automatically performs remapping of replaced operation values.
struct DialectConversionRewriter final : public PatternRewriter {
@ -112,6 +126,24 @@ struct DialectConversionRewriter final : public PatternRewriter {
: PatternRewriter(region), argConverter(region.getContext()) {}
~DialectConversionRewriter() = default;
/// Return the current state of the rewriter.
RewriterState getCurrentState() {
return RewriterState(createdOps.size(), replacements.size());
}
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state) {
// Reset any replaced operations and undo any saved mappings.
for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
for (auto *result : repl.op->getResults())
mapping.erase(result);
replacements.resize(state.numReplacements);
// Pop all of the newly created operations.
while (createdOps.size() != state.numCreatedOperations)
createdOps.pop_back_val()->erase();
}
/// Cleanup and destroy any generated rewrite operations. This method is
/// invoked when the conversion process fails.
void discardRewrites() {
@ -354,13 +386,10 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
return failure();
}
auto curOpCount = rewriter.createdOps.size();
auto curReplCount = rewriter.replacements.size();
RewriterState curState = rewriter.getCurrentState();
auto cleanupFailure = [&] {
// Pop all of the newly created operations and replacements.
while (rewriter.createdOps.size() != curOpCount)
rewriter.createdOps.pop_back_val()->erase();
rewriter.replacements.resize(curReplCount);
// Reset the rewriter state and pop this pattern.
rewriter.resetState(curState);
appliedPatterns.erase(pattern);
return failure();
};
@ -373,11 +402,13 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
}
// Recursively legalize each of the new operations.
for (unsigned i = curOpCount, e = rewriter.createdOps.size(); i != e; ++i) {
if (succeeded(legalize(rewriter.createdOps[i], rewriter)))
continue;
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
return cleanupFailure();
for (unsigned i = curState.numCreatedOperations,
e = rewriter.createdOps.size();
i != e; ++i) {
if (failed(legalize(rewriter.createdOps[i], rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
return cleanupFailure();
}
}
appliedPatterns.erase(pattern);