forked from OSchip/llvm-project
[mlir][DialectConversion] Add support for properly tracking replaceUsesOfBlockArgument
The current implementation of this method performs the replacement directly, and thus doesn't support proper back tracking. Differential Revision: https://reviews.llvm.org/D78790
This commit is contained in:
parent
4de60d955a
commit
0816de167a
|
@ -145,6 +145,11 @@ public:
|
|||
replaceAllUsesExcept(Value newValue,
|
||||
const SmallPtrSetImpl<Operation *> &exceptions) const;
|
||||
|
||||
/// Replace all uses of 'this' value with 'newValue' if the given callback
|
||||
/// returns true.
|
||||
void replaceUsesWithIf(Value newValue,
|
||||
function_ref<bool(OpOperand &)> shouldReplace);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Uses
|
||||
|
||||
|
|
|
@ -125,6 +125,15 @@ void Value::replaceAllUsesExcept(
|
|||
}
|
||||
}
|
||||
|
||||
/// Replace all uses of 'this' value with 'newValue' if the given callback
|
||||
/// returns true.
|
||||
void Value::replaceUsesWithIf(Value newValue,
|
||||
function_ref<bool(OpOperand &)> shouldReplace) {
|
||||
for (OpOperand &use : llvm::make_early_inc_range(getUses()))
|
||||
if (shouldReplace(use))
|
||||
use.set(newValue);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Uses
|
||||
|
||||
|
|
|
@ -197,8 +197,6 @@ struct ArgConverter {
|
|||
|
||||
/// Fully replace uses of the old arguments with the new, materializing cast
|
||||
/// operations as necessary.
|
||||
// FIXME(riverriddle) The 'mapping' parameter is only necessary because the
|
||||
// implementation of replaceUsesOfBlockArgument is buggy.
|
||||
void applyRewrites(ConversionValueMapping &mapping);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -436,9 +434,10 @@ namespace {
|
|||
/// This is useful when saving and undoing a set of rewrites.
|
||||
struct RewriterState {
|
||||
RewriterState(unsigned numCreatedOps, unsigned numReplacements,
|
||||
unsigned numBlockActions, unsigned numIgnoredOperations,
|
||||
unsigned numRootUpdates)
|
||||
unsigned numArgReplacements, unsigned numBlockActions,
|
||||
unsigned numIgnoredOperations, unsigned numRootUpdates)
|
||||
: numCreatedOps(numCreatedOps), numReplacements(numReplacements),
|
||||
numArgReplacements(numArgReplacements),
|
||||
numBlockActions(numBlockActions),
|
||||
numIgnoredOperations(numIgnoredOperations),
|
||||
numRootUpdates(numRootUpdates) {}
|
||||
|
@ -449,6 +448,9 @@ struct RewriterState {
|
|||
/// The current number of replacements queued.
|
||||
unsigned numReplacements;
|
||||
|
||||
/// The current number of argument replacements queued.
|
||||
unsigned numArgReplacements;
|
||||
|
||||
/// The current number of block actions performed.
|
||||
unsigned numBlockActions;
|
||||
|
||||
|
@ -624,6 +626,9 @@ struct ConversionPatternRewriterImpl {
|
|||
/// Ordered vector of any requested operation replacements.
|
||||
SmallVector<OpReplacement, 4> replacements;
|
||||
|
||||
/// Ordered vector of any requested block argument replacements.
|
||||
SmallVector<BlockArgument, 4> argReplacements;
|
||||
|
||||
/// Ordered list of block operations (creations, splits, motions).
|
||||
SmallVector<BlockAction, 4> blockActions;
|
||||
|
||||
|
@ -654,8 +659,8 @@ struct ConversionPatternRewriterImpl {
|
|||
|
||||
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
|
||||
return RewriterState(createdOps.size(), replacements.size(),
|
||||
blockActions.size(), ignoredOps.size(),
|
||||
rootUpdates.size());
|
||||
argReplacements.size(), blockActions.size(),
|
||||
ignoredOps.size(), rootUpdates.size());
|
||||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
|
||||
|
@ -664,6 +669,12 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
|
|||
rootUpdates[i].resetOperation();
|
||||
rootUpdates.resize(state.numRootUpdates);
|
||||
|
||||
// Reset any replaced arguments.
|
||||
for (BlockArgument replacedArg :
|
||||
llvm::drop_begin(argReplacements, state.numArgReplacements))
|
||||
mapping.erase(replacedArg);
|
||||
argReplacements.resize(state.numArgReplacements);
|
||||
|
||||
// Undo any block actions.
|
||||
undoBlockActions(state.numBlockActions);
|
||||
|
||||
|
@ -753,6 +764,25 @@ void ConversionPatternRewriterImpl::applyRewrites() {
|
|||
argConverter.notifyOpRemoved(repl.op);
|
||||
}
|
||||
|
||||
// Apply all of the requested argument replacements.
|
||||
for (BlockArgument arg : argReplacements) {
|
||||
Value repl = mapping.lookupOrDefault(arg);
|
||||
if (repl.isa<BlockArgument>()) {
|
||||
arg.replaceAllUsesWith(repl);
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the replacement value is an operation, we check to make sure that we
|
||||
// don't replace uses that are within the parent operation of the
|
||||
// replacement value.
|
||||
Operation *replOp = repl.cast<OpResult>().getOwner();
|
||||
Block *replBlock = replOp->getBlock();
|
||||
arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
|
||||
Operation *user = operand.getOwner();
|
||||
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
|
||||
});
|
||||
}
|
||||
|
||||
// In a second pass, erase all of the replaced operations in reverse. This
|
||||
// allows processing nested operations before their parent region is
|
||||
// destroyed.
|
||||
|
@ -907,11 +937,13 @@ Block *ConversionPatternRewriter::applySignatureConversion(
|
|||
|
||||
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
|
||||
Value to) {
|
||||
for (auto &u : from.getUses()) {
|
||||
if (u.getOwner() == to.getDefiningOp())
|
||||
continue;
|
||||
u.getOwner()->replaceUsesOfWith(from, to);
|
||||
}
|
||||
LLVM_DEBUG({
|
||||
Operation *parentOp = from.getOwner()->getParentOp();
|
||||
impl->logger.startLine() << "** Replace Argument : '" << from
|
||||
<< "'(in region of '" << parentOp->getName()
|
||||
<< "'(" << from.getOwner()->getParentOp() << ")\n";
|
||||
});
|
||||
impl->argReplacements.push_back(from);
|
||||
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
|
||||
}
|
||||
|
||||
|
|
|
@ -197,3 +197,17 @@ func @create_illegal_block() {
|
|||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @undo_block_arg_replace
|
||||
func @undo_block_arg_replace() {
|
||||
"test.undo_block_arg_replace"() ({
|
||||
^bb0(%arg0: i32):
|
||||
// CHECK: ^bb0(%[[ARG:.*]]: i32):
|
||||
// CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
|
||||
|
||||
"test.return"(%arg0) : (i32) -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -238,6 +238,24 @@ struct TestCreateIllegalBlock : public RewritePattern {
|
|||
}
|
||||
};
|
||||
|
||||
/// A simple pattern that tests the undo mechanism when replacing the uses of a
|
||||
/// block argument.
|
||||
struct TestUndoBlockArgReplace : public ConversionPattern {
|
||||
TestUndoBlockArgReplace(MLIRContext *ctx)
|
||||
: ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto illegalOp =
|
||||
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
|
||||
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
|
||||
illegalOp);
|
||||
rewriter.updateRootInPlace(op, [] {});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type-Conversion Rewrite Testing
|
||||
|
||||
|
@ -449,12 +467,14 @@ struct TestLegalizePatternDriver
|
|||
TestTypeConverter converter;
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
populateWithGenerated(&getContext(), &patterns);
|
||||
patterns.insert<
|
||||
TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
|
||||
TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
|
||||
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
|
||||
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
|
||||
TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
|
||||
patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
|
||||
TestCreateBlock, TestCreateIllegalBlock,
|
||||
TestUndoBlockArgReplace, TestPassthroughInvalidOp,
|
||||
TestSplitReturnType, TestChangeProducerTypeI32ToF32,
|
||||
TestChangeProducerTypeF32ToF64,
|
||||
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
|
||||
TestNonRootReplacement, TestBoundedRecursiveRewrite>(
|
||||
&getContext());
|
||||
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
|
||||
converter);
|
||||
|
|
Loading…
Reference in New Issue