diff --git a/mlir/include/mlir/Analysis/Liveness.h b/mlir/include/mlir/Analysis/Liveness.h index be9cb7166b8f..3bd298a0fbe7 100644 --- a/mlir/include/mlir/Analysis/Liveness.h +++ b/mlir/include/mlir/Analysis/Liveness.h @@ -86,7 +86,7 @@ public: private: /// Initializes the internal mappings. - void build(MutableArrayRef regions); + void build(); private: /// The operation this analysis was constructed from. diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index ca2523050b24..f65967ec6284 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -254,7 +254,7 @@ public: typename std::enable_if::value, RetT>::type walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - detail::walkOperations(&op, callback); + detail::walk(&op, callback); } /// Walk the operations in the specified [begin, end) range of this block in @@ -265,7 +265,7 @@ public: typename std::enable_if::value, RetT>::type walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - if (detail::walkOperations(&op, callback).wasInterrupted()) + if (detail::walk(&op, callback).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index d3dce868ca64..fa54cb608cf5 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -520,7 +520,7 @@ public: /// }); template > RetT walk(FnT &&callback) { - return detail::walkOperations(this, std::forward(callback)); + return detail::walk(this, std::forward(callback)); } //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h index 490ba92662a2..cf7b3fa5db1c 100644 --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -21,6 +21,8 @@ namespace mlir { class Diagnostic; class InFlightDiagnostic; class Operation; +class Block; +class Region; /// A utility result that is used to signal if a walk method should be /// interrupted or advance. @@ -61,31 +63,41 @@ decltype(first_argument_type(&F::operator())) first_argument_type(F); template using first_argument = decltype(first_argument_type(std::declval())); -/// Walk all of the operations nested under and including the given operation. -void walkOperations(Operation *op, function_ref callback); +/// Walk all of the regions, blocks, or operations nested under (and including) +/// the given operation. +void walk(Operation *op, function_ref callback); +void walk(Operation *op, function_ref callback); +void walk(Operation *op, function_ref callback); -/// Walk all of the operations nested under and including the given operation. -/// This methods walks operations until an interrupt result is returned by the -/// callback. -WalkResult walkOperations(Operation *op, - function_ref callback); +/// Walk all of the regions, blocks, or operations nested under (and including) +/// the given operation. These functions walk until an interrupt result is +/// returned by the callback. +WalkResult walk(Operation *op, function_ref callback); +WalkResult walk(Operation *op, function_ref callback); +WalkResult walk(Operation *op, function_ref callback); // Below are a set of functions to walk nested operations. Users should favor // the direct `walk` methods on the IR classes(Operation/Block/etc) over these // methods. They are also templated to allow for statically dispatching based // upon the type of the callback function. -/// Walk all of the operations nested under and including the given operation. -/// This method is selected for callbacks that operate on Operation*. +/// Walk all of the regions, blocks, or operations nested under (and including) +/// the given operation. This method is selected for callbacks that operate on +/// Region*, Block*, and Operation*. /// /// Example: +/// op->walk([](Region *r) { ... }); +/// op->walk([](Block *b) { ... }); /// op->walk([](Operation *op) { ... }); template < typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> -typename std::enable_if::value, RetT>::type -walkOperations(Operation *op, FuncTy &&callback) { - return detail::walkOperations(op, function_ref(callback)); +typename std::enable_if::value || + std::is_same::value || + std::is_same::value, + RetT>::type +walk(Operation *op, FuncTy &&callback) { + return walk(op, function_ref(callback)); } /// Walk all of the operations of type 'ArgT' nested under and including the @@ -98,14 +110,16 @@ template < typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value && std::is_same::value, RetT>::type -walkOperations(Operation *op, FuncTy &&callback) { +walk(Operation *op, FuncTy &&callback) { auto wrapperFn = [&](Operation *op) { if (auto derivedOp = dyn_cast(op)) callback(derivedOp); }; - return detail::walkOperations(op, function_ref(wrapperFn)); + return walk(op, function_ref(wrapperFn)); } /// Walk all of the operations of type 'ArgT' nested under and including the @@ -122,20 +136,22 @@ template < typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value && std::is_same::value, RetT>::type -walkOperations(Operation *op, FuncTy &&callback) { +walk(Operation *op, FuncTy &&callback) { auto wrapperFn = [&](Operation *op) { if (auto derivedOp = dyn_cast(op)) return callback(derivedOp); return WalkResult::advance(); }; - return detail::walkOperations(op, function_ref(wrapperFn)); + return walk(op, function_ref(wrapperFn)); } /// Utility to provide the return type of a templated walk method. template -using walkResultType = decltype(walkOperations(nullptr, std::declval())); +using walkResultType = decltype(walk(nullptr, std::declval())); } // end namespace detail } // namespace mlir diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp index 38fb386f8000..4dae386e94b2 100644 --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -125,31 +125,17 @@ struct BlockInfoBuilder { }; } // namespace -/// Walks all regions (including nested regions recursively) and invokes the -/// given function for every block. -template -static void walkRegions(MutableArrayRef regions, const FuncT &func) { - for (Region ®ion : regions) - for (Block &block : region) { - func(block); - - // Traverse all nested regions. - for (Operation &operation : block) - walkRegions(operation.getRegions(), func); - } -} - /// Builds the internal liveness block mapping. -static void buildBlockMapping(MutableArrayRef regions, +static void buildBlockMapping(Operation *operation, DenseMap &builders) { llvm::SetVector toProcess; - walkRegions(regions, [&](Block &block) { + operation->walk([&](Block *block) { BlockInfoBuilder &builder = - builders.try_emplace(&block, &block).first->second; + builders.try_emplace(block, block).first->second; if (builder.updateLiveIn()) - toProcess.insert(block.pred_begin(), block.pred_end()); + toProcess.insert(block->pred_begin(), block->pred_end()); }); // Propagate the in and out-value sets (fixpoint iteration) @@ -172,14 +158,14 @@ static void buildBlockMapping(MutableArrayRef regions, /// Creates a new Liveness analysis that computes liveness information for all /// associated regions. -Liveness::Liveness(Operation *op) : operation(op) { build(op->getRegions()); } +Liveness::Liveness(Operation *op) : operation(op) { build(); } /// Initializes the internal mappings. -void Liveness::build(MutableArrayRef regions) { +void Liveness::build() { // Build internal block mapping. DenseMap builders; - buildBlockMapping(regions, builders); + buildBlockMapping(operation, builders); // Store internal block data. for (auto &entry : builders) { @@ -284,11 +270,11 @@ void Liveness::print(raw_ostream &os) const { DenseMap blockIds; DenseMap operationIds; DenseMap valueIds; - walkRegions(operation->getRegions(), [&](Block &block) { - blockIds.insert({&block, blockIds.size()}); - for (BlockArgument argument : block.getArguments()) + operation->walk([&](Block *block) { + blockIds.insert({block, blockIds.size()}); + for (BlockArgument argument : block->getArguments()) valueIds.insert({argument, valueIds.size()}); - for (Operation &operation : block) { + for (Operation &operation : *block) { operationIds.insert({&operation, operationIds.size()}); for (Value result : operation.getResults()) valueIds.insert({result, valueIds.size()}); @@ -318,9 +304,9 @@ void Liveness::print(raw_ostream &os) const { }; // Dump information about in and out values. - walkRegions(operation->getRegions(), [&](Block &block) { - os << "// - Block: " << blockIds[&block] << "\n"; - auto liveness = getLiveness(&block); + operation->walk([&](Block *block) { + os << "// - Block: " << blockIds[block] << "\n"; + const auto *liveness = getLiveness(block); os << "// --- LiveIn: "; printValueRefs(liveness->inValues); os << "\n// --- LiveOut: "; @@ -329,7 +315,7 @@ void Liveness::print(raw_ostream &os) const { // Print liveness intervals. os << "// --- BeginLiveness"; - for (Operation &op : block) { + for (Operation &op : *block) { if (op.getNumResults() < 1) continue; os << "\n"; diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp index bbccdcbf7592..d03bdb508d37 100644 --- a/mlir/lib/IR/Visitors.cpp +++ b/mlir/lib/IR/Visitors.cpp @@ -11,31 +11,79 @@ using namespace mlir; -/// Walk all of the operations nested under and including the given operations. -void detail::walkOperations(Operation *op, - function_ref callback) { - // TODO: This walk should be iterative over the operations. - for (auto ®ion : op->getRegions()) - for (auto &block : region) - // Early increment here in the case where the operation is erased. - for (auto &nestedOp : llvm::make_early_inc_range(block)) - walkOperations(&nestedOp, callback); - - callback(op); +/// Walk all of the regions/blocks/operations nested under and including the +/// given operation. +void detail::walk(Operation *op, function_ref callback) { + for (auto ®ion : op->getRegions()) { + callback(®ion); + for (auto &block : region) { + for (auto &nestedOp : block) + walk(&nestedOp, callback); + } + } } -/// Walk all of the operations nested under and including the given operations. -/// This methods walks operations until an interrupt signal is received. -WalkResult -detail::walkOperations(Operation *op, - function_ref callback) { +void detail::walk(Operation *op, function_ref callback) { + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + callback(&block); + for (auto &nestedOp : block) + walk(&nestedOp, callback); + } + } +} + +void detail::walk(Operation *op, function_ref callback) { // TODO: This walk should be iterative over the operations. for (auto ®ion : op->getRegions()) { for (auto &block : region) { // Early increment here in the case where the operation is erased. for (auto &nestedOp : llvm::make_early_inc_range(block)) - if (walkOperations(&nestedOp, callback).wasInterrupted()) + walk(&nestedOp, callback); + } + } + callback(op); +} + +/// Walk all of the regions/blocks/operations nested under and including the +/// given operation. These functions walk operations until an interrupt result +/// is returned by the callback. +WalkResult detail::walk(Operation *op, + function_ref callback) { + for (auto ®ion : op->getRegions()) { + if (callback(®ion).wasInterrupted()) + return WalkResult::interrupt(); + for (auto &block : region) { + for (auto &nestedOp : block) + walk(&nestedOp, callback); + } + } + return WalkResult::advance(); +} + +WalkResult detail::walk(Operation *op, + function_ref callback) { + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + if (callback(&block).wasInterrupted()) + return WalkResult::interrupt(); + for (auto &nestedOp : block) + walk(&nestedOp, callback); + } + } + return WalkResult::advance(); +} + +WalkResult detail::walk(Operation *op, + function_ref callback) { + // TODO: This walk should be iterative over the operations. + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + // Early increment here in the case where the operation is erased. + for (auto &nestedOp : llvm::make_early_inc_range(block)) { + if (walk(&nestedOp, callback).wasInterrupted()) return WalkResult::interrupt(); + } } } return callback(op);