forked from OSchip/llvm-project
When cleaning up after a failed legalization pattern, make sure to remove any newly created value mappings.
PiperOrigin-RevId: 251658984
This commit is contained in:
parent
08d407f243
commit
0d2492eb2e
|
@ -95,6 +95,20 @@ constexpr StringLiteral ArgConverter::kCastName;
|
|||
// DialectConversionRewriter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class contains a snapshot of the current conversion rewriter state.
|
||||
/// This is useful when saving and undoing a set of rewrites.
|
||||
struct RewriterState {
|
||||
RewriterState(unsigned numCreatedOperations, unsigned numReplacements)
|
||||
: numCreatedOperations(numCreatedOperations),
|
||||
numReplacements(numReplacements) {}
|
||||
|
||||
/// The current number of created operations.
|
||||
unsigned numCreatedOperations;
|
||||
|
||||
/// The current number of replacements queued.
|
||||
unsigned numReplacements;
|
||||
};
|
||||
|
||||
/// This class implements a pattern rewriter for ConversionPattern
|
||||
/// patterns. It automatically performs remapping of replaced operation values.
|
||||
struct DialectConversionRewriter final : public PatternRewriter {
|
||||
|
@ -112,6 +126,24 @@ struct DialectConversionRewriter final : public PatternRewriter {
|
|||
: PatternRewriter(region), argConverter(region.getContext()) {}
|
||||
~DialectConversionRewriter() = default;
|
||||
|
||||
/// Return the current state of the rewriter.
|
||||
RewriterState getCurrentState() {
|
||||
return RewriterState(createdOps.size(), replacements.size());
|
||||
}
|
||||
|
||||
/// Reset the state of the rewriter to a previously saved point.
|
||||
void resetState(RewriterState state) {
|
||||
// Reset any replaced operations and undo any saved mappings.
|
||||
for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
|
||||
for (auto *result : repl.op->getResults())
|
||||
mapping.erase(result);
|
||||
replacements.resize(state.numReplacements);
|
||||
|
||||
// Pop all of the newly created operations.
|
||||
while (createdOps.size() != state.numCreatedOperations)
|
||||
createdOps.pop_back_val()->erase();
|
||||
}
|
||||
|
||||
/// Cleanup and destroy any generated rewrite operations. This method is
|
||||
/// invoked when the conversion process fails.
|
||||
void discardRewrites() {
|
||||
|
@ -354,13 +386,10 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
|
|||
return failure();
|
||||
}
|
||||
|
||||
auto curOpCount = rewriter.createdOps.size();
|
||||
auto curReplCount = rewriter.replacements.size();
|
||||
RewriterState curState = rewriter.getCurrentState();
|
||||
auto cleanupFailure = [&] {
|
||||
// Pop all of the newly created operations and replacements.
|
||||
while (rewriter.createdOps.size() != curOpCount)
|
||||
rewriter.createdOps.pop_back_val()->erase();
|
||||
rewriter.replacements.resize(curReplCount);
|
||||
// Reset the rewriter state and pop this pattern.
|
||||
rewriter.resetState(curState);
|
||||
appliedPatterns.erase(pattern);
|
||||
return failure();
|
||||
};
|
||||
|
@ -373,11 +402,13 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
|
|||
}
|
||||
|
||||
// Recursively legalize each of the new operations.
|
||||
for (unsigned i = curOpCount, e = rewriter.createdOps.size(); i != e; ++i) {
|
||||
if (succeeded(legalize(rewriter.createdOps[i], rewriter)))
|
||||
continue;
|
||||
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
|
||||
return cleanupFailure();
|
||||
for (unsigned i = curState.numCreatedOperations,
|
||||
e = rewriter.createdOps.size();
|
||||
i != e; ++i) {
|
||||
if (failed(legalize(rewriter.createdOps[i], rewriter))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
|
||||
return cleanupFailure();
|
||||
}
|
||||
}
|
||||
|
||||
appliedPatterns.erase(pattern);
|
||||
|
|
Loading…
Reference in New Issue