[MLIR] Support walks over regions and blocks

Add specializations for `walk` to allow traversal of regions and blocks.

Differential Revision: https://reviews.llvm.org/D90379
This commit is contained in:
Frederik Gossen 2020-10-29 13:48:07 +00:00
parent 8c058dd2d7
commit dbae3d50f1
6 changed files with 117 additions and 67 deletions

View File

@ -86,7 +86,7 @@ public:
private: private:
/// Initializes the internal mappings. /// Initializes the internal mappings.
void build(MutableArrayRef<Region> regions); void build();
private: private:
/// The operation this analysis was constructed from. /// The operation this analysis was constructed from.

View File

@ -254,7 +254,7 @@ public:
typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
walk(Block::iterator begin, Block::iterator end, FnT &&callback) { walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
detail::walkOperations(&op, callback); detail::walk(&op, callback);
} }
/// Walk the operations in the specified [begin, end) range of this block in /// Walk the operations in the specified [begin, end) range of this block in
@ -265,7 +265,7 @@ public:
typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
walk(Block::iterator begin, Block::iterator end, FnT &&callback) { walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
if (detail::walkOperations(&op, callback).wasInterrupted()) if (detail::walk(&op, callback).wasInterrupted())
return WalkResult::interrupt(); return WalkResult::interrupt();
return WalkResult::advance(); return WalkResult::advance();
} }

View File

@ -520,7 +520,7 @@ public:
/// }); /// });
template <typename FnT, typename RetT = detail::walkResultType<FnT>> template <typename FnT, typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) { RetT walk(FnT &&callback) {
return detail::walkOperations(this, std::forward<FnT>(callback)); return detail::walk(this, std::forward<FnT>(callback));
} }
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//

View File

@ -21,6 +21,8 @@ namespace mlir {
class Diagnostic; class Diagnostic;
class InFlightDiagnostic; class InFlightDiagnostic;
class Operation; class Operation;
class Block;
class Region;
/// A utility result that is used to signal if a walk method should be /// A utility result that is used to signal if a walk method should be
/// interrupted or advance. /// interrupted or advance.
@ -61,31 +63,41 @@ decltype(first_argument_type(&F::operator())) first_argument_type(F);
template <typename T> template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>())); using first_argument = decltype(first_argument_type(std::declval<T>()));
/// Walk all of the operations nested under and including the given operation. /// Walk all of the regions, blocks, or operations nested under (and including)
void walkOperations(Operation *op, function_ref<void(Operation *op)> callback); /// the given operation.
void walk(Operation *op, function_ref<void(Region *)> callback);
void walk(Operation *op, function_ref<void(Block *)> callback);
void walk(Operation *op, function_ref<void(Operation *)> callback);
/// Walk all of the operations nested under and including the given operation. /// Walk all of the regions, blocks, or operations nested under (and including)
/// This methods walks operations until an interrupt result is returned by the /// the given operation. These functions walk until an interrupt result is
/// callback. /// returned by the callback.
WalkResult walkOperations(Operation *op, WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
function_ref<WalkResult(Operation *op)> callback); WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
// Below are a set of functions to walk nested operations. Users should favor // Below are a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes(Operation/Block/etc) over these // the direct `walk` methods on the IR classes(Operation/Block/etc) over these
// methods. They are also templated to allow for statically dispatching based // methods. They are also templated to allow for statically dispatching based
// upon the type of the callback function. // upon the type of the callback function.
/// Walk all of the operations nested under and including the given operation. /// Walk all of the regions, blocks, or operations nested under (and including)
/// This method is selected for callbacks that operate on Operation*. /// the given operation. This method is selected for callbacks that operate on
/// Region*, Block*, and Operation*.
/// ///
/// Example: /// Example:
/// op->walk([](Region *r) { ... });
/// op->walk([](Block *b) { ... });
/// op->walk([](Operation *op) { ... }); /// op->walk([](Operation *op) { ... });
template < template <
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>, typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))> typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type typename std::enable_if<std::is_same<ArgT, Operation *>::value ||
walkOperations(Operation *op, FuncTy &&callback) { std::is_same<ArgT, Region *>::value ||
return detail::walkOperations(op, function_ref<RetT(ArgT)>(callback)); std::is_same<ArgT, Block *>::value,
RetT>::type
walk(Operation *op, FuncTy &&callback) {
return walk(op, function_ref<RetT(ArgT)>(callback));
} }
/// Walk all of the operations of type 'ArgT' nested under and including the /// Walk all of the operations of type 'ArgT' nested under and including the
@ -98,14 +110,16 @@ template <
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>, typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))> typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
typename std::enable_if<!std::is_same<ArgT, Operation *>::value && typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
!std::is_same<ArgT, Region *>::value &&
!std::is_same<ArgT, Block *>::value &&
std::is_same<RetT, void>::value, std::is_same<RetT, void>::value,
RetT>::type RetT>::type
walkOperations(Operation *op, FuncTy &&callback) { walk(Operation *op, FuncTy &&callback) {
auto wrapperFn = [&](Operation *op) { auto wrapperFn = [&](Operation *op) {
if (auto derivedOp = dyn_cast<ArgT>(op)) if (auto derivedOp = dyn_cast<ArgT>(op))
callback(derivedOp); callback(derivedOp);
}; };
return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn)); return walk(op, function_ref<RetT(Operation *)>(wrapperFn));
} }
/// Walk all of the operations of type 'ArgT' nested under and including the /// Walk all of the operations of type 'ArgT' nested under and including the
@ -122,20 +136,22 @@ template <
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>, typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))> typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
typename std::enable_if<!std::is_same<ArgT, Operation *>::value && typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
!std::is_same<ArgT, Region *>::value &&
!std::is_same<ArgT, Block *>::value &&
std::is_same<RetT, WalkResult>::value, std::is_same<RetT, WalkResult>::value,
RetT>::type RetT>::type
walkOperations(Operation *op, FuncTy &&callback) { walk(Operation *op, FuncTy &&callback) {
auto wrapperFn = [&](Operation *op) { auto wrapperFn = [&](Operation *op) {
if (auto derivedOp = dyn_cast<ArgT>(op)) if (auto derivedOp = dyn_cast<ArgT>(op))
return callback(derivedOp); return callback(derivedOp);
return WalkResult::advance(); return WalkResult::advance();
}; };
return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn)); return walk(op, function_ref<RetT(Operation *)>(wrapperFn));
} }
/// Utility to provide the return type of a templated walk method. /// Utility to provide the return type of a templated walk method.
template <typename FnT> template <typename FnT>
using walkResultType = decltype(walkOperations(nullptr, std::declval<FnT>())); using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
} // end namespace detail } // end namespace detail
} // namespace mlir } // namespace mlir

View File

@ -125,31 +125,17 @@ struct BlockInfoBuilder {
}; };
} // namespace } // namespace
/// Walks all regions (including nested regions recursively) and invokes the
/// given function for every block.
template <typename FuncT>
static void walkRegions(MutableArrayRef<Region> regions, const FuncT &func) {
for (Region &region : regions)
for (Block &block : region) {
func(block);
// Traverse all nested regions.
for (Operation &operation : block)
walkRegions(operation.getRegions(), func);
}
}
/// Builds the internal liveness block mapping. /// Builds the internal liveness block mapping.
static void buildBlockMapping(MutableArrayRef<Region> regions, static void buildBlockMapping(Operation *operation,
DenseMap<Block *, BlockInfoBuilder> &builders) { DenseMap<Block *, BlockInfoBuilder> &builders) {
llvm::SetVector<Block *> toProcess; llvm::SetVector<Block *> toProcess;
walkRegions(regions, [&](Block &block) { operation->walk([&](Block *block) {
BlockInfoBuilder &builder = BlockInfoBuilder &builder =
builders.try_emplace(&block, &block).first->second; builders.try_emplace(block, block).first->second;
if (builder.updateLiveIn()) if (builder.updateLiveIn())
toProcess.insert(block.pred_begin(), block.pred_end()); toProcess.insert(block->pred_begin(), block->pred_end());
}); });
// Propagate the in and out-value sets (fixpoint iteration) // Propagate the in and out-value sets (fixpoint iteration)
@ -172,14 +158,14 @@ static void buildBlockMapping(MutableArrayRef<Region> regions,
/// Creates a new Liveness analysis that computes liveness information for all /// Creates a new Liveness analysis that computes liveness information for all
/// associated regions. /// associated regions.
Liveness::Liveness(Operation *op) : operation(op) { build(op->getRegions()); } Liveness::Liveness(Operation *op) : operation(op) { build(); }
/// Initializes the internal mappings. /// Initializes the internal mappings.
void Liveness::build(MutableArrayRef<Region> regions) { void Liveness::build() {
// Build internal block mapping. // Build internal block mapping.
DenseMap<Block *, BlockInfoBuilder> builders; DenseMap<Block *, BlockInfoBuilder> builders;
buildBlockMapping(regions, builders); buildBlockMapping(operation, builders);
// Store internal block data. // Store internal block data.
for (auto &entry : builders) { for (auto &entry : builders) {
@ -284,11 +270,11 @@ void Liveness::print(raw_ostream &os) const {
DenseMap<Block *, size_t> blockIds; DenseMap<Block *, size_t> blockIds;
DenseMap<Operation *, size_t> operationIds; DenseMap<Operation *, size_t> operationIds;
DenseMap<Value, size_t> valueIds; DenseMap<Value, size_t> valueIds;
walkRegions(operation->getRegions(), [&](Block &block) { operation->walk([&](Block *block) {
blockIds.insert({&block, blockIds.size()}); blockIds.insert({block, blockIds.size()});
for (BlockArgument argument : block.getArguments()) for (BlockArgument argument : block->getArguments())
valueIds.insert({argument, valueIds.size()}); valueIds.insert({argument, valueIds.size()});
for (Operation &operation : block) { for (Operation &operation : *block) {
operationIds.insert({&operation, operationIds.size()}); operationIds.insert({&operation, operationIds.size()});
for (Value result : operation.getResults()) for (Value result : operation.getResults())
valueIds.insert({result, valueIds.size()}); valueIds.insert({result, valueIds.size()});
@ -318,9 +304,9 @@ void Liveness::print(raw_ostream &os) const {
}; };
// Dump information about in and out values. // Dump information about in and out values.
walkRegions(operation->getRegions(), [&](Block &block) { operation->walk([&](Block *block) {
os << "// - Block: " << blockIds[&block] << "\n"; os << "// - Block: " << blockIds[block] << "\n";
auto liveness = getLiveness(&block); const auto *liveness = getLiveness(block);
os << "// --- LiveIn: "; os << "// --- LiveIn: ";
printValueRefs(liveness->inValues); printValueRefs(liveness->inValues);
os << "\n// --- LiveOut: "; os << "\n// --- LiveOut: ";
@ -329,7 +315,7 @@ void Liveness::print(raw_ostream &os) const {
// Print liveness intervals. // Print liveness intervals.
os << "// --- BeginLiveness"; os << "// --- BeginLiveness";
for (Operation &op : block) { for (Operation &op : *block) {
if (op.getNumResults() < 1) if (op.getNumResults() < 1)
continue; continue;
os << "\n"; os << "\n";

View File

@ -11,31 +11,79 @@
using namespace mlir; using namespace mlir;
/// Walk all of the operations nested under and including the given operations. /// Walk all of the regions/blocks/operations nested under and including the
void detail::walkOperations(Operation *op, /// given operation.
function_ref<void(Operation *op)> callback) { void detail::walk(Operation *op, function_ref<void(Region *)> callback) {
// TODO: This walk should be iterative over the operations. for (auto &region : op->getRegions()) {
for (auto &region : op->getRegions()) callback(&region);
for (auto &block : region) for (auto &block : region) {
// Early increment here in the case where the operation is erased. for (auto &nestedOp : block)
for (auto &nestedOp : llvm::make_early_inc_range(block)) walk(&nestedOp, callback);
walkOperations(&nestedOp, callback); }
}
callback(op);
} }
/// Walk all of the operations nested under and including the given operations. void detail::walk(Operation *op, function_ref<void(Block *)> callback) {
/// This methods walks operations until an interrupt signal is received. for (auto &region : op->getRegions()) {
WalkResult for (auto &block : region) {
detail::walkOperations(Operation *op, callback(&block);
function_ref<WalkResult(Operation *op)> callback) { for (auto &nestedOp : block)
walk(&nestedOp, callback);
}
}
}
void detail::walk(Operation *op, function_ref<void(Operation *op)> callback) {
// TODO: This walk should be iterative over the operations. // TODO: This walk should be iterative over the operations.
for (auto &region : op->getRegions()) { for (auto &region : op->getRegions()) {
for (auto &block : region) { for (auto &block : region) {
// Early increment here in the case where the operation is erased. // Early increment here in the case where the operation is erased.
for (auto &nestedOp : llvm::make_early_inc_range(block)) for (auto &nestedOp : llvm::make_early_inc_range(block))
if (walkOperations(&nestedOp, callback).wasInterrupted()) walk(&nestedOp, callback);
}
}
callback(op);
}
/// Walk all of the regions/blocks/operations nested under and including the
/// given operation. These functions walk operations until an interrupt result
/// is returned by the callback.
WalkResult detail::walk(Operation *op,
function_ref<WalkResult(Region *op)> callback) {
for (auto &region : op->getRegions()) {
if (callback(&region).wasInterrupted())
return WalkResult::interrupt();
for (auto &block : region) {
for (auto &nestedOp : block)
walk(&nestedOp, callback);
}
}
return WalkResult::advance();
}
WalkResult detail::walk(Operation *op,
function_ref<WalkResult(Block *op)> callback) {
for (auto &region : op->getRegions()) {
for (auto &block : region) {
if (callback(&block).wasInterrupted())
return WalkResult::interrupt();
for (auto &nestedOp : block)
walk(&nestedOp, callback);
}
}
return WalkResult::advance();
}
WalkResult detail::walk(Operation *op,
function_ref<WalkResult(Operation *op)> callback) {
// TODO: This walk should be iterative over the operations.
for (auto &region : op->getRegions()) {
for (auto &block : region) {
// Early increment here in the case where the operation is erased.
for (auto &nestedOp : llvm::make_early_inc_range(block)) {
if (walk(&nestedOp, callback).wasInterrupted())
return WalkResult::interrupt(); return WalkResult::interrupt();
}
} }
} }
return callback(op); return callback(op);