Refactor getUsedValuesDefinedAbove to expose a variant taking a callback (NFC)

This will allow clients to implement a different collection strategy on these
values, including collecting each uses within the region for example.

PiperOrigin-RevId: 267803978
This commit is contained in:
Mehdi Amini 2019-09-07 17:02:07 -07:00 committed by A. Unique TensorFlower
parent 713ab0dde7
commit 6443583bfd
2 changed files with 32 additions and 8 deletions

View File

@ -40,6 +40,16 @@ bool areValuesDefinedAbove(Range values, Region &limit) {
void replaceAllUsesInRegionWith(Value *orig, Value *replacement,
Region &region);
/// Calls `callback` for each use of a value within `region` or its descendants
/// that was defined at the ancestors of the `limit`.
void visitUsedValuesDefinedAbove(Region &region, Region &limit,
function_ref<void(OpOperand *)> callback);
/// Calls `callback` for each use of a value within any of the regions provided
/// that was defined in one of the ancestors.
void visitUsedValuesDefinedAbove(llvm::MutableArrayRef<Region> regions,
function_ref<void(OpOperand *)> callback);
/// Fill `values` with a list of values defined at the ancestors of the `limit`
/// region and used within `region` or its descendants.
void getUsedValuesDefinedAbove(Region &region, Region &limit,

View File

@ -32,8 +32,9 @@ void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement,
}
}
void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
llvm::SetVector<Value *> &values) {
void mlir::visitUsedValuesDefinedAbove(
Region &region, Region &limit,
llvm::function_ref<void(OpOperand *)> callback) {
assert(limit.isAncestor(&region) &&
"expected isolation limit to be an ancestor of the given region");
@ -45,12 +46,25 @@ void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
properAncestors.insert(reg);
}
region.walk([&values, &properAncestors](Operation *op) {
for (Value *operand : op->getOperands())
// Collect values that are used by an operation and defined in a proper
// ancestor of region.
if (properAncestors.count(operand->getParentRegion()))
values.insert(operand);
region.walk([callback, &properAncestors](Operation *op) {
for (OpOperand &operand : op->getOpOperands())
// Callback on values defined in a proper ancestor of region.
if (properAncestors.count(operand.get()->getParentRegion()))
callback(&operand);
});
}
void mlir::visitUsedValuesDefinedAbove(
llvm::MutableArrayRef<Region> regions,
llvm::function_ref<void(OpOperand *)> callback) {
for (Region &region : regions)
visitUsedValuesDefinedAbove(region, region, callback);
}
void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
llvm::SetVector<Value *> &values) {
visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
values.insert(operand->get());
});
}