Add utility 'replaceAllUsesWith' methods to Operation.

These methods will allow replacing the uses of results with an existing operation, with the same number of results, or a range of values. This removes a number of hand-rolled result replacement loops and simplifies replacement for operations with multiple results.

PiperOrigin-RevId: 262206600
This commit is contained in:
River Riddle 2019-08-07 13:48:19 -07:00 committed by A. Unique TensorFlower
parent a477fbaf40
commit 8089f93746
6 changed files with 33 additions and 13 deletions

View File

@ -529,9 +529,10 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
/// Return the result at index 'i'.
Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
/// Set the result at index 'i' to 'value'.
void setResult(unsigned i, Value *value) {
this->getOperation()->setResult(i, value);
/// Replace all uses of results of this operation with the provided 'values'.
/// 'values' may correspond to an existing operation, or a range of 'Value'.
template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
}
/// Return the type of the `i`-th result.
@ -572,6 +573,11 @@ public:
getResult()->replaceAllUsesWith(newValue);
}
/// Replace all uses of 'this' value with the result of 'op'.
void replaceAllUsesWith(Operation *op) {
this->getOperation()->replaceAllUsesWith(op);
}
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyOneResult(op);
}

View File

@ -137,6 +137,25 @@ public:
/// Replace any uses of 'from' with 'to' within this operation.
void replaceUsesOfWith(Value *from, Value *to);
/// Replace all uses of results of this operation with the provided 'values'.
template <typename ValuesT,
typename = decltype(std::declval<ValuesT>().begin())>
void replaceAllUsesWith(ValuesT &&values) {
assert(std::distance(values.begin(), values.end()) == getNumResults() &&
"expected 'values' to correspond 1-1 with the number of results");
auto valueIt = values.begin();
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
getResult(i)->replaceAllUsesWith(*(valueIt++));
}
/// Replace all uses of results of this operation with results of 'op'.
void replaceAllUsesWith(Operation *op) {
assert(getNumResults() == op->getNumResults());
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
getResult(i)->replaceAllUsesWith(op->getResult(i));
}
/// Destroys this operation and its subclass data.
void destroy();

View File

@ -91,8 +91,7 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
for (unsigned i = 0, e = newValues.size(); i != e; ++i)
op->getResult(i)->replaceAllUsesWith(newValues[i]);
op->replaceAllUsesWith(newValues);
notifyOperationRemoved(op);
op->erase();

View File

@ -150,8 +150,7 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
if (auto *existing = knownValues.lookup(op)) {
// If we find one then replace all uses of the current operation with the
// existing one and mark it for deletion.
for (unsigned i = 0, e = existing->getNumResults(); i != e; ++i)
op->getResult(i)->replaceAllUsesWith(existing->getResult(i));
op->replaceAllUsesWith(existing);
opsToErase.push_back(op);
// If the existing operation has an unknown location and the current

View File

@ -204,7 +204,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
// Perform the actual store to load forwarding.
Value *storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
loadOp.getResult()->replaceAllUsesWith(storeVal);
loadOp.replaceAllUsesWith(storeVal);
// Record the memref for a later sweep to optimize away.
memrefsToErase.insert(loadOp.getMemRef());
// Record this to erase later.

View File

@ -242,11 +242,8 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
// Create the new operation.
auto *repOp = builder.createOperation(state);
// Replace old memref's deferencing op's uses.
unsigned r = 0;
for (auto *res : opInst->getResults()) {
res->replaceAllUsesWith(repOp->getResult(r++));
}
opInst->replaceAllUsesWith(repOp);
// Collect and erase at the end since one of these op's could be
// domInstFilter or postDomInstFilter as well!
opsToErase.push_back(opInst);