[mlir][DialectConversion] Add support for properly tracking replaceUsesOfBlockArgument

The current implementation of this method performs the replacement directly, and thus doesn't support proper back tracking.

Differential Revision: https://reviews.llvm.org/D78790
This commit is contained in:
River Riddle 2020-04-24 12:25:05 -07:00
parent 4de60d955a
commit 0816de167a
5 changed files with 97 additions and 17 deletions

View File

@ -145,6 +145,11 @@ public:
replaceAllUsesExcept(Value newValue,
const SmallPtrSetImpl<Operation *> &exceptions) const;
/// Replace all uses of 'this' value with 'newValue' if the given callback
/// returns true.
void replaceUsesWithIf(Value newValue,
function_ref<bool(OpOperand &)> shouldReplace);
//===--------------------------------------------------------------------===//
// Uses

View File

@ -125,6 +125,15 @@ void Value::replaceAllUsesExcept(
}
}
/// Replace all uses of 'this' value with 'newValue' if the given callback
/// returns true.
void Value::replaceUsesWithIf(Value newValue,
function_ref<bool(OpOperand &)> shouldReplace) {
for (OpOperand &use : llvm::make_early_inc_range(getUses()))
if (shouldReplace(use))
use.set(newValue);
}
//===--------------------------------------------------------------------===//
// Uses

View File

@ -197,8 +197,6 @@ struct ArgConverter {
/// Fully replace uses of the old arguments with the new, materializing cast
/// operations as necessary.
// FIXME(riverriddle) The 'mapping' parameter is only necessary because the
// implementation of replaceUsesOfBlockArgument is buggy.
void applyRewrites(ConversionValueMapping &mapping);
//===--------------------------------------------------------------------===//
@ -436,9 +434,10 @@ namespace {
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numReplacements,
unsigned numBlockActions, unsigned numIgnoredOperations,
unsigned numRootUpdates)
unsigned numArgReplacements, unsigned numBlockActions,
unsigned numIgnoredOperations, unsigned numRootUpdates)
: numCreatedOps(numCreatedOps), numReplacements(numReplacements),
numArgReplacements(numArgReplacements),
numBlockActions(numBlockActions),
numIgnoredOperations(numIgnoredOperations),
numRootUpdates(numRootUpdates) {}
@ -449,6 +448,9 @@ struct RewriterState {
/// The current number of replacements queued.
unsigned numReplacements;
/// The current number of argument replacements queued.
unsigned numArgReplacements;
/// The current number of block actions performed.
unsigned numBlockActions;
@ -624,6 +626,9 @@ struct ConversionPatternRewriterImpl {
/// Ordered vector of any requested operation replacements.
SmallVector<OpReplacement, 4> replacements;
/// Ordered vector of any requested block argument replacements.
SmallVector<BlockArgument, 4> argReplacements;
/// Ordered list of block operations (creations, splits, motions).
SmallVector<BlockAction, 4> blockActions;
@ -654,8 +659,8 @@ struct ConversionPatternRewriterImpl {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
blockActions.size(), ignoredOps.size(),
rootUpdates.size());
argReplacements.size(), blockActions.size(),
ignoredOps.size(), rootUpdates.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@ -664,6 +669,12 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
rootUpdates[i].resetOperation();
rootUpdates.resize(state.numRootUpdates);
// Reset any replaced arguments.
for (BlockArgument replacedArg :
llvm::drop_begin(argReplacements, state.numArgReplacements))
mapping.erase(replacedArg);
argReplacements.resize(state.numArgReplacements);
// Undo any block actions.
undoBlockActions(state.numBlockActions);
@ -753,6 +764,25 @@ void ConversionPatternRewriterImpl::applyRewrites() {
argConverter.notifyOpRemoved(repl.op);
}
// Apply all of the requested argument replacements.
for (BlockArgument arg : argReplacements) {
Value repl = mapping.lookupOrDefault(arg);
if (repl.isa<BlockArgument>()) {
arg.replaceAllUsesWith(repl);
continue;
}
// If the replacement value is an operation, we check to make sure that we
// don't replace uses that are within the parent operation of the
// replacement value.
Operation *replOp = repl.cast<OpResult>().getOwner();
Block *replBlock = replOp->getBlock();
arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
});
}
// In a second pass, erase all of the replaced operations in reverse. This
// allows processing nested operations before their parent region is
// destroyed.
@ -907,11 +937,13 @@ Block *ConversionPatternRewriter::applySignatureConversion(
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
for (auto &u : from.getUses()) {
if (u.getOwner() == to.getDefiningOp())
continue;
u.getOwner()->replaceUsesOfWith(from, to);
}
LLVM_DEBUG({
Operation *parentOp = from.getOwner()->getParentOp();
impl->logger.startLine() << "** Replace Argument : '" << from
<< "'(in region of '" << parentOp->getName()
<< "'(" << from.getOwner()->getParentOp() << ")\n";
});
impl->argReplacements.push_back(from);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}

View File

@ -197,3 +197,17 @@ func @create_illegal_block() {
}) : () -> ()
return
}
// -----
// CHECK-LABEL: @undo_block_arg_replace
func @undo_block_arg_replace() {
"test.undo_block_arg_replace"() ({
^bb0(%arg0: i32):
// CHECK: ^bb0(%[[ARG:.*]]: i32):
// CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
"test.return"(%arg0) : (i32) -> ()
}) : () -> ()
return
}

View File

@ -238,6 +238,24 @@ struct TestCreateIllegalBlock : public RewritePattern {
}
};
/// A simple pattern that tests the undo mechanism when replacing the uses of a
/// block argument.
struct TestUndoBlockArgReplace : public ConversionPattern {
TestUndoBlockArgReplace(MLIRContext *ctx)
: ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto illegalOp =
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
illegalOp);
rewriter.updateRootInPlace(op, [] {});
return success();
}
};
//===----------------------------------------------------------------------===//
// Type-Conversion Rewrite Testing
@ -449,12 +467,14 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
patterns.insert<
TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
TestCreateBlock, TestCreateIllegalBlock,
TestUndoBlockArgReplace, TestPassthroughInvalidOp,
TestSplitReturnType, TestChangeProducerTypeI32ToF32,
TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite>(
&getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);