diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 6316b566373a..10e6dfbae5e1 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -40,6 +40,16 @@ bool areValuesDefinedAbove(Range values, Region &limit) { void replaceAllUsesInRegionWith(Value *orig, Value *replacement, Region ®ion); +/// Calls `callback` for each use of a value within `region` or its descendants +/// that was defined at the ancestors of the `limit`. +void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, + function_ref callback); + +/// Calls `callback` for each use of a value within any of the regions provided +/// that was defined in one of the ancestors. +void visitUsedValuesDefinedAbove(llvm::MutableArrayRef regions, + function_ref callback); + /// Fill `values` with a list of values defined at the ancestors of the `limit` /// region and used within `region` or its descendants. void getUsedValuesDefinedAbove(Region ®ion, Region &limit, diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 9974e47c2c1a..24c38c4f2b96 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -32,8 +32,9 @@ void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement, } } -void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, - llvm::SetVector &values) { +void mlir::visitUsedValuesDefinedAbove( + Region ®ion, Region &limit, + llvm::function_ref callback) { assert(limit.isAncestor(®ion) && "expected isolation limit to be an ancestor of the given region"); @@ -45,12 +46,25 @@ void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, properAncestors.insert(reg); } - region.walk([&values, &properAncestors](Operation *op) { - for (Value *operand : op->getOperands()) - // Collect values that are used by an operation and defined in a proper - // ancestor of region. - if (properAncestors.count(operand->getParentRegion())) - values.insert(operand); + region.walk([callback, &properAncestors](Operation *op) { + for (OpOperand &operand : op->getOpOperands()) + // Callback on values defined in a proper ancestor of region. + if (properAncestors.count(operand.get()->getParentRegion())) + callback(&operand); + }); +} + +void mlir::visitUsedValuesDefinedAbove( + llvm::MutableArrayRef regions, + llvm::function_ref callback) { + for (Region ®ion : regions) + visitUsedValuesDefinedAbove(region, region, callback); +} + +void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit, + llvm::SetVector &values) { + visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) { + values.insert(operand->get()); }); }