forked from OSchip/llvm-project
[MLIR] Support walks over regions and blocks
Add specializations for `walk` to allow traversal of regions and blocks. Differential Revision: https://reviews.llvm.org/D90379
This commit is contained in:
parent
8c058dd2d7
commit
dbae3d50f1
|
@ -86,7 +86,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Initializes the internal mappings.
|
/// Initializes the internal mappings.
|
||||||
void build(MutableArrayRef<Region> regions);
|
void build();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// The operation this analysis was constructed from.
|
/// The operation this analysis was constructed from.
|
||||||
|
|
|
@ -254,7 +254,7 @@ public:
|
||||||
typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
|
typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
|
||||||
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
|
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
|
||||||
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
|
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
|
||||||
detail::walkOperations(&op, callback);
|
detail::walk(&op, callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Walk the operations in the specified [begin, end) range of this block in
|
/// Walk the operations in the specified [begin, end) range of this block in
|
||||||
|
@ -265,7 +265,7 @@ public:
|
||||||
typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
|
typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
|
||||||
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
|
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
|
||||||
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
|
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
|
||||||
if (detail::walkOperations(&op, callback).wasInterrupted())
|
if (detail::walk(&op, callback).wasInterrupted())
|
||||||
return WalkResult::interrupt();
|
return WalkResult::interrupt();
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
}
|
}
|
||||||
|
|
|
@ -520,7 +520,7 @@ public:
|
||||||
/// });
|
/// });
|
||||||
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
|
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
|
||||||
RetT walk(FnT &&callback) {
|
RetT walk(FnT &&callback) {
|
||||||
return detail::walkOperations(this, std::forward<FnT>(callback));
|
return detail::walk(this, std::forward<FnT>(callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
|
@ -21,6 +21,8 @@ namespace mlir {
|
||||||
class Diagnostic;
|
class Diagnostic;
|
||||||
class InFlightDiagnostic;
|
class InFlightDiagnostic;
|
||||||
class Operation;
|
class Operation;
|
||||||
|
class Block;
|
||||||
|
class Region;
|
||||||
|
|
||||||
/// A utility result that is used to signal if a walk method should be
|
/// A utility result that is used to signal if a walk method should be
|
||||||
/// interrupted or advance.
|
/// interrupted or advance.
|
||||||
|
@ -61,31 +63,41 @@ decltype(first_argument_type(&F::operator())) first_argument_type(F);
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using first_argument = decltype(first_argument_type(std::declval<T>()));
|
using first_argument = decltype(first_argument_type(std::declval<T>()));
|
||||||
|
|
||||||
/// Walk all of the operations nested under and including the given operation.
|
/// Walk all of the regions, blocks, or operations nested under (and including)
|
||||||
void walkOperations(Operation *op, function_ref<void(Operation *op)> callback);
|
/// the given operation.
|
||||||
|
void walk(Operation *op, function_ref<void(Region *)> callback);
|
||||||
|
void walk(Operation *op, function_ref<void(Block *)> callback);
|
||||||
|
void walk(Operation *op, function_ref<void(Operation *)> callback);
|
||||||
|
|
||||||
/// Walk all of the operations nested under and including the given operation.
|
/// Walk all of the regions, blocks, or operations nested under (and including)
|
||||||
/// This methods walks operations until an interrupt result is returned by the
|
/// the given operation. These functions walk until an interrupt result is
|
||||||
/// callback.
|
/// returned by the callback.
|
||||||
WalkResult walkOperations(Operation *op,
|
WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
|
||||||
function_ref<WalkResult(Operation *op)> callback);
|
WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
|
||||||
|
WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
|
||||||
|
|
||||||
// Below are a set of functions to walk nested operations. Users should favor
|
// Below are a set of functions to walk nested operations. Users should favor
|
||||||
// the direct `walk` methods on the IR classes(Operation/Block/etc) over these
|
// the direct `walk` methods on the IR classes(Operation/Block/etc) over these
|
||||||
// methods. They are also templated to allow for statically dispatching based
|
// methods. They are also templated to allow for statically dispatching based
|
||||||
// upon the type of the callback function.
|
// upon the type of the callback function.
|
||||||
|
|
||||||
/// Walk all of the operations nested under and including the given operation.
|
/// Walk all of the regions, blocks, or operations nested under (and including)
|
||||||
/// This method is selected for callbacks that operate on Operation*.
|
/// the given operation. This method is selected for callbacks that operate on
|
||||||
|
/// Region*, Block*, and Operation*.
|
||||||
///
|
///
|
||||||
/// Example:
|
/// Example:
|
||||||
|
/// op->walk([](Region *r) { ... });
|
||||||
|
/// op->walk([](Block *b) { ... });
|
||||||
/// op->walk([](Operation *op) { ... });
|
/// op->walk([](Operation *op) { ... });
|
||||||
template <
|
template <
|
||||||
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||||
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
|
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
|
||||||
typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
|
typename std::enable_if<std::is_same<ArgT, Operation *>::value ||
|
||||||
walkOperations(Operation *op, FuncTy &&callback) {
|
std::is_same<ArgT, Region *>::value ||
|
||||||
return detail::walkOperations(op, function_ref<RetT(ArgT)>(callback));
|
std::is_same<ArgT, Block *>::value,
|
||||||
|
RetT>::type
|
||||||
|
walk(Operation *op, FuncTy &&callback) {
|
||||||
|
return walk(op, function_ref<RetT(ArgT)>(callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Walk all of the operations of type 'ArgT' nested under and including the
|
/// Walk all of the operations of type 'ArgT' nested under and including the
|
||||||
|
@ -98,14 +110,16 @@ template <
|
||||||
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||||
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
|
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
|
||||||
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
|
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
|
||||||
|
!std::is_same<ArgT, Region *>::value &&
|
||||||
|
!std::is_same<ArgT, Block *>::value &&
|
||||||
std::is_same<RetT, void>::value,
|
std::is_same<RetT, void>::value,
|
||||||
RetT>::type
|
RetT>::type
|
||||||
walkOperations(Operation *op, FuncTy &&callback) {
|
walk(Operation *op, FuncTy &&callback) {
|
||||||
auto wrapperFn = [&](Operation *op) {
|
auto wrapperFn = [&](Operation *op) {
|
||||||
if (auto derivedOp = dyn_cast<ArgT>(op))
|
if (auto derivedOp = dyn_cast<ArgT>(op))
|
||||||
callback(derivedOp);
|
callback(derivedOp);
|
||||||
};
|
};
|
||||||
return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
|
return walk(op, function_ref<RetT(Operation *)>(wrapperFn));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Walk all of the operations of type 'ArgT' nested under and including the
|
/// Walk all of the operations of type 'ArgT' nested under and including the
|
||||||
|
@ -122,20 +136,22 @@ template <
|
||||||
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
|
||||||
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
|
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
|
||||||
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
|
typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
|
||||||
|
!std::is_same<ArgT, Region *>::value &&
|
||||||
|
!std::is_same<ArgT, Block *>::value &&
|
||||||
std::is_same<RetT, WalkResult>::value,
|
std::is_same<RetT, WalkResult>::value,
|
||||||
RetT>::type
|
RetT>::type
|
||||||
walkOperations(Operation *op, FuncTy &&callback) {
|
walk(Operation *op, FuncTy &&callback) {
|
||||||
auto wrapperFn = [&](Operation *op) {
|
auto wrapperFn = [&](Operation *op) {
|
||||||
if (auto derivedOp = dyn_cast<ArgT>(op))
|
if (auto derivedOp = dyn_cast<ArgT>(op))
|
||||||
return callback(derivedOp);
|
return callback(derivedOp);
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
};
|
};
|
||||||
return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
|
return walk(op, function_ref<RetT(Operation *)>(wrapperFn));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Utility to provide the return type of a templated walk method.
|
/// Utility to provide the return type of a templated walk method.
|
||||||
template <typename FnT>
|
template <typename FnT>
|
||||||
using walkResultType = decltype(walkOperations(nullptr, std::declval<FnT>()));
|
using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
|
||||||
} // end namespace detail
|
} // end namespace detail
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -125,31 +125,17 @@ struct BlockInfoBuilder {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/// Walks all regions (including nested regions recursively) and invokes the
|
|
||||||
/// given function for every block.
|
|
||||||
template <typename FuncT>
|
|
||||||
static void walkRegions(MutableArrayRef<Region> regions, const FuncT &func) {
|
|
||||||
for (Region ®ion : regions)
|
|
||||||
for (Block &block : region) {
|
|
||||||
func(block);
|
|
||||||
|
|
||||||
// Traverse all nested regions.
|
|
||||||
for (Operation &operation : block)
|
|
||||||
walkRegions(operation.getRegions(), func);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Builds the internal liveness block mapping.
|
/// Builds the internal liveness block mapping.
|
||||||
static void buildBlockMapping(MutableArrayRef<Region> regions,
|
static void buildBlockMapping(Operation *operation,
|
||||||
DenseMap<Block *, BlockInfoBuilder> &builders) {
|
DenseMap<Block *, BlockInfoBuilder> &builders) {
|
||||||
llvm::SetVector<Block *> toProcess;
|
llvm::SetVector<Block *> toProcess;
|
||||||
|
|
||||||
walkRegions(regions, [&](Block &block) {
|
operation->walk([&](Block *block) {
|
||||||
BlockInfoBuilder &builder =
|
BlockInfoBuilder &builder =
|
||||||
builders.try_emplace(&block, &block).first->second;
|
builders.try_emplace(block, block).first->second;
|
||||||
|
|
||||||
if (builder.updateLiveIn())
|
if (builder.updateLiveIn())
|
||||||
toProcess.insert(block.pred_begin(), block.pred_end());
|
toProcess.insert(block->pred_begin(), block->pred_end());
|
||||||
});
|
});
|
||||||
|
|
||||||
// Propagate the in and out-value sets (fixpoint iteration)
|
// Propagate the in and out-value sets (fixpoint iteration)
|
||||||
|
@ -172,14 +158,14 @@ static void buildBlockMapping(MutableArrayRef<Region> regions,
|
||||||
|
|
||||||
/// Creates a new Liveness analysis that computes liveness information for all
|
/// Creates a new Liveness analysis that computes liveness information for all
|
||||||
/// associated regions.
|
/// associated regions.
|
||||||
Liveness::Liveness(Operation *op) : operation(op) { build(op->getRegions()); }
|
Liveness::Liveness(Operation *op) : operation(op) { build(); }
|
||||||
|
|
||||||
/// Initializes the internal mappings.
|
/// Initializes the internal mappings.
|
||||||
void Liveness::build(MutableArrayRef<Region> regions) {
|
void Liveness::build() {
|
||||||
|
|
||||||
// Build internal block mapping.
|
// Build internal block mapping.
|
||||||
DenseMap<Block *, BlockInfoBuilder> builders;
|
DenseMap<Block *, BlockInfoBuilder> builders;
|
||||||
buildBlockMapping(regions, builders);
|
buildBlockMapping(operation, builders);
|
||||||
|
|
||||||
// Store internal block data.
|
// Store internal block data.
|
||||||
for (auto &entry : builders) {
|
for (auto &entry : builders) {
|
||||||
|
@ -284,11 +270,11 @@ void Liveness::print(raw_ostream &os) const {
|
||||||
DenseMap<Block *, size_t> blockIds;
|
DenseMap<Block *, size_t> blockIds;
|
||||||
DenseMap<Operation *, size_t> operationIds;
|
DenseMap<Operation *, size_t> operationIds;
|
||||||
DenseMap<Value, size_t> valueIds;
|
DenseMap<Value, size_t> valueIds;
|
||||||
walkRegions(operation->getRegions(), [&](Block &block) {
|
operation->walk([&](Block *block) {
|
||||||
blockIds.insert({&block, blockIds.size()});
|
blockIds.insert({block, blockIds.size()});
|
||||||
for (BlockArgument argument : block.getArguments())
|
for (BlockArgument argument : block->getArguments())
|
||||||
valueIds.insert({argument, valueIds.size()});
|
valueIds.insert({argument, valueIds.size()});
|
||||||
for (Operation &operation : block) {
|
for (Operation &operation : *block) {
|
||||||
operationIds.insert({&operation, operationIds.size()});
|
operationIds.insert({&operation, operationIds.size()});
|
||||||
for (Value result : operation.getResults())
|
for (Value result : operation.getResults())
|
||||||
valueIds.insert({result, valueIds.size()});
|
valueIds.insert({result, valueIds.size()});
|
||||||
|
@ -318,9 +304,9 @@ void Liveness::print(raw_ostream &os) const {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Dump information about in and out values.
|
// Dump information about in and out values.
|
||||||
walkRegions(operation->getRegions(), [&](Block &block) {
|
operation->walk([&](Block *block) {
|
||||||
os << "// - Block: " << blockIds[&block] << "\n";
|
os << "// - Block: " << blockIds[block] << "\n";
|
||||||
auto liveness = getLiveness(&block);
|
const auto *liveness = getLiveness(block);
|
||||||
os << "// --- LiveIn: ";
|
os << "// --- LiveIn: ";
|
||||||
printValueRefs(liveness->inValues);
|
printValueRefs(liveness->inValues);
|
||||||
os << "\n// --- LiveOut: ";
|
os << "\n// --- LiveOut: ";
|
||||||
|
@ -329,7 +315,7 @@ void Liveness::print(raw_ostream &os) const {
|
||||||
|
|
||||||
// Print liveness intervals.
|
// Print liveness intervals.
|
||||||
os << "// --- BeginLiveness";
|
os << "// --- BeginLiveness";
|
||||||
for (Operation &op : block) {
|
for (Operation &op : *block) {
|
||||||
if (op.getNumResults() < 1)
|
if (op.getNumResults() < 1)
|
||||||
continue;
|
continue;
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
|
@ -11,31 +11,79 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
/// Walk all of the operations nested under and including the given operations.
|
/// Walk all of the regions/blocks/operations nested under and including the
|
||||||
void detail::walkOperations(Operation *op,
|
/// given operation.
|
||||||
function_ref<void(Operation *op)> callback) {
|
void detail::walk(Operation *op, function_ref<void(Region *)> callback) {
|
||||||
// TODO: This walk should be iterative over the operations.
|
for (auto ®ion : op->getRegions()) {
|
||||||
for (auto ®ion : op->getRegions())
|
callback(®ion);
|
||||||
for (auto &block : region)
|
for (auto &block : region) {
|
||||||
// Early increment here in the case where the operation is erased.
|
for (auto &nestedOp : block)
|
||||||
for (auto &nestedOp : llvm::make_early_inc_range(block))
|
walk(&nestedOp, callback);
|
||||||
walkOperations(&nestedOp, callback);
|
}
|
||||||
|
}
|
||||||
callback(op);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Walk all of the operations nested under and including the given operations.
|
void detail::walk(Operation *op, function_ref<void(Block *)> callback) {
|
||||||
/// This methods walks operations until an interrupt signal is received.
|
for (auto ®ion : op->getRegions()) {
|
||||||
WalkResult
|
for (auto &block : region) {
|
||||||
detail::walkOperations(Operation *op,
|
callback(&block);
|
||||||
function_ref<WalkResult(Operation *op)> callback) {
|
for (auto &nestedOp : block)
|
||||||
|
walk(&nestedOp, callback);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void detail::walk(Operation *op, function_ref<void(Operation *op)> callback) {
|
||||||
// TODO: This walk should be iterative over the operations.
|
// TODO: This walk should be iterative over the operations.
|
||||||
for (auto ®ion : op->getRegions()) {
|
for (auto ®ion : op->getRegions()) {
|
||||||
for (auto &block : region) {
|
for (auto &block : region) {
|
||||||
// Early increment here in the case where the operation is erased.
|
// Early increment here in the case where the operation is erased.
|
||||||
for (auto &nestedOp : llvm::make_early_inc_range(block))
|
for (auto &nestedOp : llvm::make_early_inc_range(block))
|
||||||
if (walkOperations(&nestedOp, callback).wasInterrupted())
|
walk(&nestedOp, callback);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
callback(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Walk all of the regions/blocks/operations nested under and including the
|
||||||
|
/// given operation. These functions walk operations until an interrupt result
|
||||||
|
/// is returned by the callback.
|
||||||
|
WalkResult detail::walk(Operation *op,
|
||||||
|
function_ref<WalkResult(Region *op)> callback) {
|
||||||
|
for (auto ®ion : op->getRegions()) {
|
||||||
|
if (callback(®ion).wasInterrupted())
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
for (auto &block : region) {
|
||||||
|
for (auto &nestedOp : block)
|
||||||
|
walk(&nestedOp, callback);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return WalkResult::advance();
|
||||||
|
}
|
||||||
|
|
||||||
|
WalkResult detail::walk(Operation *op,
|
||||||
|
function_ref<WalkResult(Block *op)> callback) {
|
||||||
|
for (auto ®ion : op->getRegions()) {
|
||||||
|
for (auto &block : region) {
|
||||||
|
if (callback(&block).wasInterrupted())
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
for (auto &nestedOp : block)
|
||||||
|
walk(&nestedOp, callback);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return WalkResult::advance();
|
||||||
|
}
|
||||||
|
|
||||||
|
WalkResult detail::walk(Operation *op,
|
||||||
|
function_ref<WalkResult(Operation *op)> callback) {
|
||||||
|
// TODO: This walk should be iterative over the operations.
|
||||||
|
for (auto ®ion : op->getRegions()) {
|
||||||
|
for (auto &block : region) {
|
||||||
|
// Early increment here in the case where the operation is erased.
|
||||||
|
for (auto &nestedOp : llvm::make_early_inc_range(block)) {
|
||||||
|
if (walk(&nestedOp, callback).wasInterrupted())
|
||||||
return WalkResult::interrupt();
|
return WalkResult::interrupt();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return callback(op);
|
return callback(op);
|
||||||
|
|
Loading…
Reference in New Issue