Change the `notifyRootUpdated` API to be transaction based.

This means that in-place, or root, updates need to use explicit calls to `startRootUpdate`, `finalizeRootUpdate`, and `cancelRootUpdate`. The major benefit of this change is that it enables in-place updates in DialectConversion, which simplifies the FuncOp pattern for example. The major downside to this is that the cases that *may* modify an operation in-place will need an explicit cancel on the failure branches(assuming that they started an update before attempting the transformation).

PiperOrigin-RevId: 286933674
This commit is contained in:
River Riddle 2019-12-23 13:05:38 -08:00 committed by A. Unique TensorFlower
parent a5d5d29125
commit 5d5bd2e1da
10 changed files with 199 additions and 87 deletions

View File

@ -61,6 +61,7 @@ class SuccessorRange final
public:
using RangeBaseT::RangeBaseT;
SuccessorRange(Block *block);
SuccessorRange(Operation *term);
private:
/// See `detail::indexed_accessor_range_base` for details.

View File

@ -385,6 +385,12 @@ public:
return {getTrailingObjects<BlockOperand>(), numSuccs};
}
// Successor iteration.
using succ_iterator = SuccessorRange::iterator;
succ_iterator successor_begin() { return getSuccessors().begin(); }
succ_iterator successor_end() { return getSuccessors().end(); }
SuccessorRange getSuccessors() { return SuccessorRange(this); }
/// Return the operands of this operation that are *not* successor arguments.
operand_range getNonSuccessorOperands();

View File

@ -361,15 +361,31 @@ public:
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);
/// This method is used as the final notification hook for patterns that end
/// up modifying the pattern root in place, by changing its operands. This is
/// a minor efficiency win (it avoids creating a new operation and removing
/// the old one) but also often allows simpler code in the client.
///
/// The valuesToRemoveIfDead list is an optional list of values that the
/// rewriter should remove if they are dead at this point.
///
void updatedRootInPlace(Operation *op, ValueRange valuesToRemoveIfDead = {});
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
/// This is a minor efficiency win (it avoids creating a new operation and
/// removing the old one) but also often allows simpler code in the client.
virtual void startRootUpdate(Operation *op) {}
/// This method is used to signal the end of a root update on the given
/// operation. This can only be called on operations that were provided to a
/// call to `startRootUpdate`.
virtual void finalizeRootUpdate(Operation *op) {}
/// This method cancels a pending root update. This can only be called on
/// operations that were provided to a call to `startRootUpdate`.
virtual void cancelRootUpdate(Operation *op) {}
/// This method is a utility wrapper around a root update of an operation. It
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
/// callable.
template <typename CallableT>
void updateRootInPlace(Operation *root, CallableT &&callable) {
startRootUpdate(root);
callable();
finalizeRootUpdate(root);
}
protected:
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
@ -378,10 +394,6 @@ protected:
// These are the callback methods that subclasses can choose to implement if
// they would like to be notified about certain types of mutations.
/// Notify the pattern rewriter that the specified operation has been mutated
/// in place. This is called after the mutation is done.
virtual void notifyRootUpdated(Operation *op) {}
/// Notify the pattern rewriter that the specified operation is about to be
/// replaced with another set of operations. This is called before the uses
/// of the operation have been changed.

View File

@ -365,7 +365,16 @@ public:
Operation *insert(Operation *op) override;
/// PatternRewriter hook for updating the root operation in-place.
void notifyRootUpdated(Operation *op) override;
/// Note: These methods only track updates to the top-level operation itself,
/// and not nested regions. Updates to regions will still require notification
/// through other more specific hooks above.
void startRootUpdate(Operation *op) override;
/// PatternRewriter hook for updating the root operation in-place.
void finalizeRootUpdate(Operation *op) override;
/// PatternRewriter hook for updating the root operation in-place.
void cancelRootUpdate(Operation *op) override;
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();

View File

@ -54,13 +54,12 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands,
signatureConverter.addInputs(argType.index(), convertedType);
}
}
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
newFuncOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
rewriter.updateRootInPlace(funcOp, [&] {
funcOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
});
return matchSuccess();
}

View File

@ -472,7 +472,8 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
PatternMatchResult matchAndRewrite(LaunchOp launchOp,
PatternRewriter &rewriter) const override {
auto origInsertionPoint = rewriter.saveInsertionPoint();
rewriter.startRootUpdate(launchOp);
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&launchOp.body().front());
// Traverse operands passed to kernel and check if some of them are known
@ -480,31 +481,29 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
// and use it instead of passing the value from the parent region. Perform
// the traversal in the inverse order to simplify index arithmetics when
// dropping arguments.
SmallVector<ValuePtr, 8> operands(launchOp.getKernelOperandValues().begin(),
launchOp.getKernelOperandValues().end());
SmallVector<ValuePtr, 8> kernelArgs(launchOp.getKernelArguments().begin(),
launchOp.getKernelArguments().end());
auto operands = launchOp.getKernelOperandValues();
auto kernelArgs = launchOp.getKernelArguments();
bool found = false;
for (unsigned i = operands.size(); i > 0; --i) {
unsigned index = i - 1;
ValuePtr operand = operands[index];
if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
Value operand = operands[index];
if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp()))
continue;
}
found = true;
ValuePtr internalConstant =
Value internalConstant =
rewriter.clone(*operand->getDefiningOp())->getResult(0);
ValuePtr kernelArg = kernelArgs[index];
Value kernelArg = *std::next(kernelArgs.begin(), index);
kernelArg->replaceAllUsesWith(internalConstant);
launchOp.eraseKernelArgument(index);
}
rewriter.restoreInsertionPoint(origInsertionPoint);
if (!found)
if (!found) {
rewriter.cancelRootUpdate(launchOp);
return matchFailure();
}
rewriter.updatedRootInPlace(launchOp);
rewriter.finalizeRootUpdate(launchOp);
return matchSuccess();
}
};

View File

@ -197,13 +197,11 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands,
}
// Creates a new function with the update signature.
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
newFuncOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
rewriter.eraseOp(funcOp.getOperation());
rewriter.updateRootInPlace(funcOp, [&] {
funcOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
});
return matchSuccess();
}

View File

@ -267,3 +267,8 @@ SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) {
if ((count = term->getNumSuccessors()))
base = term->getBlockOperands().data();
}
SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) {
if ((count = term->getNumSuccessors()))
base = term->getBlockOperands().data();
}

View File

@ -170,23 +170,6 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
/// This method is used as the final notification hook for patterns that end
/// up modifying the pattern root in place, by changing its operands. This is
/// a minor efficiency win (it avoids creating a new operation and removing
/// the old one) but also often allows simpler code in the client.
///
/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
/// should remove if they are dead at this point.
///
void PatternRewriter::updatedRootInPlace(Operation *op,
ValueRange valuesToRemoveIfDead) {
// Notify the rewriter subclass that we're about to replace this root.
notifyRootUpdated(op);
// TODO: Process the valuesToRemoveIfDead list, removing things and calling
// the notifyOperationRemoved hook in the process.
}
//===----------------------------------------------------------------------===//
// PatternMatcher implementation
//===----------------------------------------------------------------------===//

View File

@ -406,14 +406,16 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
unsigned numBlockActions, unsigned numIgnoredOperations)
: numCreatedOperations(numCreatedOperations),
numReplacements(numReplacements), numBlockActions(numBlockActions),
numIgnoredOperations(numIgnoredOperations) {}
RewriterState(unsigned numCreatedOps, unsigned numReplacements,
unsigned numBlockActions, unsigned numIgnoredOperations,
unsigned numRootUpdates)
: numCreatedOps(numCreatedOps), numReplacements(numReplacements),
numBlockActions(numBlockActions),
numIgnoredOperations(numIgnoredOperations),
numRootUpdates(numRootUpdates) {}
/// The current number of created operations.
unsigned numCreatedOperations;
unsigned numCreatedOps;
/// The current number of replacements queued.
unsigned numReplacements;
@ -423,6 +425,41 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
/// The current number of operations that were updated in place.
unsigned numRootUpdates;
};
/// The state of an operation that was updated by a pattern in-place. This
/// contains all of the necessary information to reconstruct an operation that
/// was updated in place.
class OperationTransactionState {
public:
OperationTransactionState() = default;
OperationTransactionState(Operation *op)
: op(op), loc(op->getLoc()), attrs(op->getAttrList()),
operands(op->operand_begin(), op->operand_end()),
successors(op->successor_begin(), op->successor_end()) {}
/// Discard the transaction state and reset the state of the original
/// operation.
void resetOperation() const {
op->setLoc(loc);
op->setAttrs(attrs);
op->setOperands(operands);
for (auto it : llvm::enumerate(successors))
op->setSuccessor(it.value(), it.index());
}
/// Return the original operation of this state.
Operation *getOperation() const { return op; }
private:
Operation *op;
LocationAttr loc;
NamedAttributeList attrs;
SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;
};
} // end anonymous namespace
@ -567,16 +604,32 @@ struct ConversionPatternRewriterImpl {
/// the others. This simplifies the amount of memory needed as we can query if
/// the parent operation was ignored.
llvm::SetVector<Operation *> ignoredOps;
/// A transaction state for each of operations that were updated in-place.
SmallVector<OperationTransactionState, 4> rootUpdates;
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
/// verification.
SmallPtrSet<Operation *, 1> pendingRootUpdates;
#endif
};
} // end namespace detail
} // end namespace mlir
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
blockActions.size(), ignoredOps.size());
blockActions.size(), ignoredOps.size(),
rootUpdates.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
// Reset any operations that were updated in place.
for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
rootUpdates[i].resetOperation();
rootUpdates.resize(state.numRootUpdates);
// Undo any block actions.
undoBlockActions(state.numBlockActions);
@ -587,7 +640,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
replacements.resize(state.numReplacements);
// Pop all of the newly created operations.
while (createdOps.size() != state.numCreatedOperations) {
while (createdOps.size() != state.numCreatedOps) {
createdOps.back()->erase();
createdOps.pop_back();
}
@ -640,6 +693,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
}
void ConversionPatternRewriterImpl::discardRewrites() {
// Reset any operations that were updated in place.
for (auto &state : rootUpdates)
state.resetOperation();
undoBlockActions();
// Remove any newly created ops.
@ -867,11 +924,34 @@ Operation *ConversionPatternRewriter::insert(Operation *op) {
}
/// PatternRewriter hook for updating the root operation in-place.
void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
// The rewriter caches changes to the IR to allow for operating in-place and
// backtracking. The rewriter is currently not capable of backtracking
// in-place modifications.
llvm_unreachable("in-place operation updates are not supported");
void ConversionPatternRewriter::startRootUpdate(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
#endif
impl->rootUpdates.emplace_back(op);
}
/// PatternRewriter hook for updating the root operation in-place.
void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
#endif
}
/// PatternRewriter hook for updating the root operation in-place.
void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
#endif
// Erase the last update for this operation.
auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
auto &rootUpdates = impl->rootUpdates;
auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
}
/// Return a reference to the internal implementation.
@ -1059,8 +1139,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (unsigned i = curState.numCreatedOperations,
e = rewriterImpl.createdOps.size();
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
i != e; ++i) {
Operation *cstOp = rewriterImpl.createdOps[i];
if (failed(legalize(cstOp, rewriter))) {
@ -1102,7 +1181,12 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
// Try to rewrite with the given pattern.
rewriter.setInsertionPoint(op);
if (!pattern->matchAndRewrite(op, rewriter)) {
auto matchedPattern = pattern->matchAndRewrite(op, rewriter);
#ifndef NDEBUG
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
#endif
if (!matchedPattern) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
return cleanupFailure();
}
@ -1139,12 +1223,32 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
else
rewriterImpl.ignoredOps.insert(replacedOp);
}
assert(replacedRoot && "expected pattern to replace the root operation");
// Check that the root was either updated or replace.
auto updatedRootInPlace = [&] {
return llvm::any_of(
llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates),
[op](auto &state) { return state.getOperation() == op; });
};
(void)replacedRoot;
(void)updatedRootInPlace;
assert((replacedRoot || updatedRootInPlace()) &&
"expected pattern to replace the root operation");
// Recursively legalize each of the operations updated in place.
for (unsigned i = curState.numRootUpdates,
e = rewriterImpl.rootUpdates.size();
i != e; ++i) {
auto &state = rewriterImpl.rootUpdates[i];
if (failed(legalize(state.getOperation(), rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '"
<< op->getName() << "' was illegal.\n");
return cleanupFailure();
}
}
// Recursively legalize each of the new operations.
for (unsigned i = curState.numCreatedOperations,
e = rewriterImpl.createdOps.size();
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
i != e; ++i) {
Operation *op = rewriterImpl.createdOps[i];
if (failed(legalize(op, rewriter))) {
@ -1534,16 +1638,12 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
return matchFailure();
// Create a new function with an updated signature.
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
convertedResults, funcOp.getContext()));
// Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
rewriter.eraseOp(funcOp);
// Update the function signature in-place.
rewriter.updateRootInPlace(funcOp, [&] {
funcOp.setType(FunctionType::get(result.getConvertedTypes(),
convertedResults, funcOp.getContext()));
rewriter.applySignatureConversion(&funcOp.getBody(), result);
});
return matchSuccess();
}