forked from OSchip/llvm-project
Change the `notifyRootUpdated` API to be transaction based.
This means that in-place, or root, updates need to use explicit calls to `startRootUpdate`, `finalizeRootUpdate`, and `cancelRootUpdate`. The major benefit of this change is that it enables in-place updates in DialectConversion, which simplifies the FuncOp pattern for example. The major downside to this is that the cases that *may* modify an operation in-place will need an explicit cancel on the failure branches(assuming that they started an update before attempting the transformation). PiperOrigin-RevId: 286933674
This commit is contained in:
parent
a5d5d29125
commit
5d5bd2e1da
|
@ -61,6 +61,7 @@ class SuccessorRange final
|
|||
public:
|
||||
using RangeBaseT::RangeBaseT;
|
||||
SuccessorRange(Block *block);
|
||||
SuccessorRange(Operation *term);
|
||||
|
||||
private:
|
||||
/// See `detail::indexed_accessor_range_base` for details.
|
||||
|
|
|
@ -385,6 +385,12 @@ public:
|
|||
return {getTrailingObjects<BlockOperand>(), numSuccs};
|
||||
}
|
||||
|
||||
// Successor iteration.
|
||||
using succ_iterator = SuccessorRange::iterator;
|
||||
succ_iterator successor_begin() { return getSuccessors().begin(); }
|
||||
succ_iterator successor_end() { return getSuccessors().end(); }
|
||||
SuccessorRange getSuccessors() { return SuccessorRange(this); }
|
||||
|
||||
/// Return the operands of this operation that are *not* successor arguments.
|
||||
operand_range getNonSuccessorOperands();
|
||||
|
||||
|
|
|
@ -361,15 +361,31 @@ public:
|
|||
/// block into a new block, and return it.
|
||||
virtual Block *splitBlock(Block *block, Block::iterator before);
|
||||
|
||||
/// This method is used as the final notification hook for patterns that end
|
||||
/// up modifying the pattern root in place, by changing its operands. This is
|
||||
/// a minor efficiency win (it avoids creating a new operation and removing
|
||||
/// the old one) but also often allows simpler code in the client.
|
||||
///
|
||||
/// The valuesToRemoveIfDead list is an optional list of values that the
|
||||
/// rewriter should remove if they are dead at this point.
|
||||
///
|
||||
void updatedRootInPlace(Operation *op, ValueRange valuesToRemoveIfDead = {});
|
||||
/// This method is used to notify the rewriter that an in-place operation
|
||||
/// modification is about to happen. A call to this function *must* be
|
||||
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
|
||||
/// This is a minor efficiency win (it avoids creating a new operation and
|
||||
/// removing the old one) but also often allows simpler code in the client.
|
||||
virtual void startRootUpdate(Operation *op) {}
|
||||
|
||||
/// This method is used to signal the end of a root update on the given
|
||||
/// operation. This can only be called on operations that were provided to a
|
||||
/// call to `startRootUpdate`.
|
||||
virtual void finalizeRootUpdate(Operation *op) {}
|
||||
|
||||
/// This method cancels a pending root update. This can only be called on
|
||||
/// operations that were provided to a call to `startRootUpdate`.
|
||||
virtual void cancelRootUpdate(Operation *op) {}
|
||||
|
||||
/// This method is a utility wrapper around a root update of an operation. It
|
||||
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
|
||||
/// callable.
|
||||
template <typename CallableT>
|
||||
void updateRootInPlace(Operation *root, CallableT &&callable) {
|
||||
startRootUpdate(root);
|
||||
callable();
|
||||
finalizeRootUpdate(root);
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
|
||||
|
@ -378,10 +394,6 @@ protected:
|
|||
// These are the callback methods that subclasses can choose to implement if
|
||||
// they would like to be notified about certain types of mutations.
|
||||
|
||||
/// Notify the pattern rewriter that the specified operation has been mutated
|
||||
/// in place. This is called after the mutation is done.
|
||||
virtual void notifyRootUpdated(Operation *op) {}
|
||||
|
||||
/// Notify the pattern rewriter that the specified operation is about to be
|
||||
/// replaced with another set of operations. This is called before the uses
|
||||
/// of the operation have been changed.
|
||||
|
|
|
@ -365,7 +365,16 @@ public:
|
|||
Operation *insert(Operation *op) override;
|
||||
|
||||
/// PatternRewriter hook for updating the root operation in-place.
|
||||
void notifyRootUpdated(Operation *op) override;
|
||||
/// Note: These methods only track updates to the top-level operation itself,
|
||||
/// and not nested regions. Updates to regions will still require notification
|
||||
/// through other more specific hooks above.
|
||||
void startRootUpdate(Operation *op) override;
|
||||
|
||||
/// PatternRewriter hook for updating the root operation in-place.
|
||||
void finalizeRootUpdate(Operation *op) override;
|
||||
|
||||
/// PatternRewriter hook for updating the root operation in-place.
|
||||
void cancelRootUpdate(Operation *op) override;
|
||||
|
||||
/// Return a reference to the internal implementation.
|
||||
detail::ConversionPatternRewriterImpl &getImpl();
|
||||
|
|
|
@ -54,13 +54,12 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands,
|
|||
signatureConverter.addInputs(argType.index(), convertedType);
|
||||
}
|
||||
}
|
||||
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
newFuncOp.setType(rewriter.getFunctionType(
|
||||
signatureConverter.getConvertedTypes(), llvm::None));
|
||||
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
|
||||
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
|
||||
|
||||
rewriter.updateRootInPlace(funcOp, [&] {
|
||||
funcOp.setType(rewriter.getFunctionType(
|
||||
signatureConverter.getConvertedTypes(), llvm::None));
|
||||
rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
|
||||
});
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
|
|
|
@ -472,7 +472,8 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
|
|||
|
||||
PatternMatchResult matchAndRewrite(LaunchOp launchOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto origInsertionPoint = rewriter.saveInsertionPoint();
|
||||
rewriter.startRootUpdate(launchOp);
|
||||
PatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&launchOp.body().front());
|
||||
|
||||
// Traverse operands passed to kernel and check if some of them are known
|
||||
|
@ -480,31 +481,29 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
|
|||
// and use it instead of passing the value from the parent region. Perform
|
||||
// the traversal in the inverse order to simplify index arithmetics when
|
||||
// dropping arguments.
|
||||
SmallVector<ValuePtr, 8> operands(launchOp.getKernelOperandValues().begin(),
|
||||
launchOp.getKernelOperandValues().end());
|
||||
SmallVector<ValuePtr, 8> kernelArgs(launchOp.getKernelArguments().begin(),
|
||||
launchOp.getKernelArguments().end());
|
||||
auto operands = launchOp.getKernelOperandValues();
|
||||
auto kernelArgs = launchOp.getKernelArguments();
|
||||
bool found = false;
|
||||
for (unsigned i = operands.size(); i > 0; --i) {
|
||||
unsigned index = i - 1;
|
||||
ValuePtr operand = operands[index];
|
||||
if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
|
||||
Value operand = operands[index];
|
||||
if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp()))
|
||||
continue;
|
||||
}
|
||||
|
||||
found = true;
|
||||
ValuePtr internalConstant =
|
||||
Value internalConstant =
|
||||
rewriter.clone(*operand->getDefiningOp())->getResult(0);
|
||||
ValuePtr kernelArg = kernelArgs[index];
|
||||
Value kernelArg = *std::next(kernelArgs.begin(), index);
|
||||
kernelArg->replaceAllUsesWith(internalConstant);
|
||||
launchOp.eraseKernelArgument(index);
|
||||
}
|
||||
rewriter.restoreInsertionPoint(origInsertionPoint);
|
||||
|
||||
if (!found)
|
||||
if (!found) {
|
||||
rewriter.cancelRootUpdate(launchOp);
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
rewriter.updatedRootInPlace(launchOp);
|
||||
rewriter.finalizeRootUpdate(launchOp);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -197,13 +197,11 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands,
|
|||
}
|
||||
|
||||
// Creates a new function with the update signature.
|
||||
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
newFuncOp.setType(rewriter.getFunctionType(
|
||||
signatureConverter.getConvertedTypes(), llvm::None));
|
||||
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
|
||||
rewriter.eraseOp(funcOp.getOperation());
|
||||
rewriter.updateRootInPlace(funcOp, [&] {
|
||||
funcOp.setType(rewriter.getFunctionType(
|
||||
signatureConverter.getConvertedTypes(), llvm::None));
|
||||
rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
|
||||
});
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
|
|
|
@ -267,3 +267,8 @@ SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) {
|
|||
if ((count = term->getNumSuccessors()))
|
||||
base = term->getBlockOperands().data();
|
||||
}
|
||||
|
||||
SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) {
|
||||
if ((count = term->getNumSuccessors()))
|
||||
base = term->getBlockOperands().data();
|
||||
}
|
||||
|
|
|
@ -170,23 +170,6 @@ void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) {
|
|||
cloneRegionBefore(region, *before->getParent(), before->getIterator());
|
||||
}
|
||||
|
||||
/// This method is used as the final notification hook for patterns that end
|
||||
/// up modifying the pattern root in place, by changing its operands. This is
|
||||
/// a minor efficiency win (it avoids creating a new operation and removing
|
||||
/// the old one) but also often allows simpler code in the client.
|
||||
///
|
||||
/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
|
||||
/// should remove if they are dead at this point.
|
||||
///
|
||||
void PatternRewriter::updatedRootInPlace(Operation *op,
|
||||
ValueRange valuesToRemoveIfDead) {
|
||||
// Notify the rewriter subclass that we're about to replace this root.
|
||||
notifyRootUpdated(op);
|
||||
|
||||
// TODO: Process the valuesToRemoveIfDead list, removing things and calling
|
||||
// the notifyOperationRemoved hook in the process.
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternMatcher implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -406,14 +406,16 @@ namespace {
|
|||
/// This class contains a snapshot of the current conversion rewriter state.
|
||||
/// This is useful when saving and undoing a set of rewrites.
|
||||
struct RewriterState {
|
||||
RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
|
||||
unsigned numBlockActions, unsigned numIgnoredOperations)
|
||||
: numCreatedOperations(numCreatedOperations),
|
||||
numReplacements(numReplacements), numBlockActions(numBlockActions),
|
||||
numIgnoredOperations(numIgnoredOperations) {}
|
||||
RewriterState(unsigned numCreatedOps, unsigned numReplacements,
|
||||
unsigned numBlockActions, unsigned numIgnoredOperations,
|
||||
unsigned numRootUpdates)
|
||||
: numCreatedOps(numCreatedOps), numReplacements(numReplacements),
|
||||
numBlockActions(numBlockActions),
|
||||
numIgnoredOperations(numIgnoredOperations),
|
||||
numRootUpdates(numRootUpdates) {}
|
||||
|
||||
/// The current number of created operations.
|
||||
unsigned numCreatedOperations;
|
||||
unsigned numCreatedOps;
|
||||
|
||||
/// The current number of replacements queued.
|
||||
unsigned numReplacements;
|
||||
|
@ -423,6 +425,41 @@ struct RewriterState {
|
|||
|
||||
/// The current number of ignored operations.
|
||||
unsigned numIgnoredOperations;
|
||||
|
||||
/// The current number of operations that were updated in place.
|
||||
unsigned numRootUpdates;
|
||||
};
|
||||
|
||||
/// The state of an operation that was updated by a pattern in-place. This
|
||||
/// contains all of the necessary information to reconstruct an operation that
|
||||
/// was updated in place.
|
||||
class OperationTransactionState {
|
||||
public:
|
||||
OperationTransactionState() = default;
|
||||
OperationTransactionState(Operation *op)
|
||||
: op(op), loc(op->getLoc()), attrs(op->getAttrList()),
|
||||
operands(op->operand_begin(), op->operand_end()),
|
||||
successors(op->successor_begin(), op->successor_end()) {}
|
||||
|
||||
/// Discard the transaction state and reset the state of the original
|
||||
/// operation.
|
||||
void resetOperation() const {
|
||||
op->setLoc(loc);
|
||||
op->setAttrs(attrs);
|
||||
op->setOperands(operands);
|
||||
for (auto it : llvm::enumerate(successors))
|
||||
op->setSuccessor(it.value(), it.index());
|
||||
}
|
||||
|
||||
/// Return the original operation of this state.
|
||||
Operation *getOperation() const { return op; }
|
||||
|
||||
private:
|
||||
Operation *op;
|
||||
LocationAttr loc;
|
||||
NamedAttributeList attrs;
|
||||
SmallVector<Value, 8> operands;
|
||||
SmallVector<Block *, 2> successors;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -567,16 +604,32 @@ struct ConversionPatternRewriterImpl {
|
|||
/// the others. This simplifies the amount of memory needed as we can query if
|
||||
/// the parent operation was ignored.
|
||||
llvm::SetVector<Operation *> ignoredOps;
|
||||
|
||||
/// A transaction state for each of operations that were updated in-place.
|
||||
SmallVector<OperationTransactionState, 4> rootUpdates;
|
||||
|
||||
#ifndef NDEBUG
|
||||
/// A set of operations that have pending updates. This tracking isn't
|
||||
/// strictly necessary, and is thus only active during debug builds for extra
|
||||
/// verification.
|
||||
SmallPtrSet<Operation *, 1> pendingRootUpdates;
|
||||
#endif
|
||||
};
|
||||
} // end namespace detail
|
||||
} // end namespace mlir
|
||||
|
||||
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
|
||||
return RewriterState(createdOps.size(), replacements.size(),
|
||||
blockActions.size(), ignoredOps.size());
|
||||
blockActions.size(), ignoredOps.size(),
|
||||
rootUpdates.size());
|
||||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
|
||||
// Reset any operations that were updated in place.
|
||||
for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
|
||||
rootUpdates[i].resetOperation();
|
||||
rootUpdates.resize(state.numRootUpdates);
|
||||
|
||||
// Undo any block actions.
|
||||
undoBlockActions(state.numBlockActions);
|
||||
|
||||
|
@ -587,7 +640,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
|
|||
replacements.resize(state.numReplacements);
|
||||
|
||||
// Pop all of the newly created operations.
|
||||
while (createdOps.size() != state.numCreatedOperations) {
|
||||
while (createdOps.size() != state.numCreatedOps) {
|
||||
createdOps.back()->erase();
|
||||
createdOps.pop_back();
|
||||
}
|
||||
|
@ -640,6 +693,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
|
|||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::discardRewrites() {
|
||||
// Reset any operations that were updated in place.
|
||||
for (auto &state : rootUpdates)
|
||||
state.resetOperation();
|
||||
|
||||
undoBlockActions();
|
||||
|
||||
// Remove any newly created ops.
|
||||
|
@ -867,11 +924,34 @@ Operation *ConversionPatternRewriter::insert(Operation *op) {
|
|||
}
|
||||
|
||||
/// PatternRewriter hook for updating the root operation in-place.
|
||||
void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
|
||||
// The rewriter caches changes to the IR to allow for operating in-place and
|
||||
// backtracking. The rewriter is currently not capable of backtracking
|
||||
// in-place modifications.
|
||||
llvm_unreachable("in-place operation updates are not supported");
|
||||
void ConversionPatternRewriter::startRootUpdate(Operation *op) {
|
||||
#ifndef NDEBUG
|
||||
impl->pendingRootUpdates.insert(op);
|
||||
#endif
|
||||
impl->rootUpdates.emplace_back(op);
|
||||
}
|
||||
|
||||
/// PatternRewriter hook for updating the root operation in-place.
|
||||
void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
|
||||
// There is nothing to do here, we only need to track the operation at the
|
||||
// start of the update.
|
||||
#ifndef NDEBUG
|
||||
assert(impl->pendingRootUpdates.erase(op) &&
|
||||
"operation did not have a pending in-place update");
|
||||
#endif
|
||||
}
|
||||
|
||||
/// PatternRewriter hook for updating the root operation in-place.
|
||||
void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
|
||||
#ifndef NDEBUG
|
||||
assert(impl->pendingRootUpdates.erase(op) &&
|
||||
"operation did not have a pending in-place update");
|
||||
#endif
|
||||
// Erase the last update for this operation.
|
||||
auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
|
||||
auto &rootUpdates = impl->rootUpdates;
|
||||
auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
|
||||
rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
|
||||
}
|
||||
|
||||
/// Return a reference to the internal implementation.
|
||||
|
@ -1059,8 +1139,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
|
|||
rewriter.replaceOp(op, replacementValues);
|
||||
|
||||
// Recursively legalize any new constant operations.
|
||||
for (unsigned i = curState.numCreatedOperations,
|
||||
e = rewriterImpl.createdOps.size();
|
||||
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
|
||||
i != e; ++i) {
|
||||
Operation *cstOp = rewriterImpl.createdOps[i];
|
||||
if (failed(legalize(cstOp, rewriter))) {
|
||||
|
@ -1102,7 +1181,12 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
|
|||
|
||||
// Try to rewrite with the given pattern.
|
||||
rewriter.setInsertionPoint(op);
|
||||
if (!pattern->matchAndRewrite(op, rewriter)) {
|
||||
auto matchedPattern = pattern->matchAndRewrite(op, rewriter);
|
||||
#ifndef NDEBUG
|
||||
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
|
||||
#endif
|
||||
|
||||
if (!matchedPattern) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
|
||||
return cleanupFailure();
|
||||
}
|
||||
|
@ -1139,12 +1223,32 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
|
|||
else
|
||||
rewriterImpl.ignoredOps.insert(replacedOp);
|
||||
}
|
||||
assert(replacedRoot && "expected pattern to replace the root operation");
|
||||
|
||||
// Check that the root was either updated or replace.
|
||||
auto updatedRootInPlace = [&] {
|
||||
return llvm::any_of(
|
||||
llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates),
|
||||
[op](auto &state) { return state.getOperation() == op; });
|
||||
};
|
||||
(void)replacedRoot;
|
||||
(void)updatedRootInPlace;
|
||||
assert((replacedRoot || updatedRootInPlace()) &&
|
||||
"expected pattern to replace the root operation");
|
||||
|
||||
// Recursively legalize each of the operations updated in place.
|
||||
for (unsigned i = curState.numRootUpdates,
|
||||
e = rewriterImpl.rootUpdates.size();
|
||||
i != e; ++i) {
|
||||
auto &state = rewriterImpl.rootUpdates[i];
|
||||
if (failed(legalize(state.getOperation(), rewriter))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '"
|
||||
<< op->getName() << "' was illegal.\n");
|
||||
return cleanupFailure();
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively legalize each of the new operations.
|
||||
for (unsigned i = curState.numCreatedOperations,
|
||||
e = rewriterImpl.createdOps.size();
|
||||
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
|
||||
i != e; ++i) {
|
||||
Operation *op = rewriterImpl.createdOps[i];
|
||||
if (failed(legalize(op, rewriter))) {
|
||||
|
@ -1534,16 +1638,12 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
|||
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
|
||||
return matchFailure();
|
||||
|
||||
// Create a new function with an updated signature.
|
||||
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
|
||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||
newFuncOp.end());
|
||||
newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
|
||||
convertedResults, funcOp.getContext()));
|
||||
|
||||
// Tell the rewriter to convert the region signature.
|
||||
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
|
||||
rewriter.eraseOp(funcOp);
|
||||
// Update the function signature in-place.
|
||||
rewriter.updateRootInPlace(funcOp, [&] {
|
||||
funcOp.setType(FunctionType::get(result.getConvertedTypes(),
|
||||
convertedResults, funcOp.getContext()));
|
||||
rewriter.applySignatureConversion(&funcOp.getBody(), result);
|
||||
});
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue