[mlir][IR] Add a Region::getOps method that returns a range of immediately nested operations

This allows for walking the operations nested directly within a region, without traversing nested regions.

Differential Revision: https://reviews.llvm.org/D79056
This commit is contained in:
River Riddle 2020-05-04 17:46:06 -07:00
parent 6bce7d8d67
commit 1e4faf23ff
10 changed files with 234 additions and 119 deletions

View File

@ -156,54 +156,23 @@ public:
/// Recomputes the ordering of child operations within the block. /// Recomputes the ordering of child operations within the block.
void recomputeOpOrder(); void recomputeOpOrder();
private:
/// A utility iterator that filters out operations that are not 'OpT'.
template <typename OpT>
class op_filter_iterator
: public llvm::filter_iterator<Block::iterator, bool (*)(Operation &)> {
static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
public:
op_filter_iterator(Block::iterator it, Block::iterator end)
: llvm::filter_iterator<Block::iterator, bool (*)(Operation &)>(
it, end, &filter) {}
/// Allow implicit conversion to the underlying block iterator.
operator Block::iterator() const { return this->wrapped(); }
};
public:
/// This class provides iteration over the held operations of a block for a /// This class provides iteration over the held operations of a block for a
/// specific operation type. /// specific operation type.
template <typename OpT> template <typename OpT>
class op_iterator : public llvm::mapped_iterator<op_filter_iterator<OpT>, using op_iterator = detail::op_iterator<OpT, iterator>;
OpT (*)(Operation &)> {
static OpT unwrap(Operation &op) { return cast<OpT>(op); }
public:
using reference = OpT;
/// Initializes the iterator to the specified filter iterator.
op_iterator(op_filter_iterator<OpT> it)
: llvm::mapped_iterator<op_filter_iterator<OpT>, OpT (*)(Operation &)>(
it, &unwrap) {}
/// Allow implicit conversion to the underlying block iterator.
operator Block::iterator() const { return this->wrapped(); }
};
/// Return an iterator range over the operations within this block that are of /// Return an iterator range over the operations within this block that are of
/// 'OpT'. /// 'OpT'.
template <typename OpT> iterator_range<op_iterator<OpT>> getOps() { template <typename OpT> iterator_range<op_iterator<OpT>> getOps() {
auto endIt = end(); auto endIt = end();
return {op_filter_iterator<OpT>(begin(), endIt), return {detail::op_filter_iterator<OpT, iterator>(begin(), endIt),
op_filter_iterator<OpT>(endIt, endIt)}; detail::op_filter_iterator<OpT, iterator>(endIt, endIt)};
} }
template <typename OpT> op_iterator<OpT> op_begin() { template <typename OpT> op_iterator<OpT> op_begin() {
return op_filter_iterator<OpT>(begin(), end()); return detail::op_filter_iterator<OpT, iterator>(begin(), end());
} }
template <typename OpT> op_iterator<OpT> op_end() { template <typename OpT> op_iterator<OpT> op_end() {
return op_filter_iterator<OpT>(end(), end()); return detail::op_filter_iterator<OpT, iterator>(end(), end());
} }
/// Return an iterator range over the operation within this block excluding /// Return an iterator range over the operation within this block excluding

View File

@ -75,6 +75,46 @@ private:
friend RangeBaseT; friend RangeBaseT;
}; };
//===----------------------------------------------------------------------===//
// Operation Iterators
//===----------------------------------------------------------------------===//
namespace detail {
/// A utility iterator that filters out operations that are not 'OpT'.
template <typename OpT, typename IteratorT>
class op_filter_iterator
: public llvm::filter_iterator<IteratorT, bool (*)(Operation &)> {
static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
public:
op_filter_iterator(IteratorT it, IteratorT end)
: llvm::filter_iterator<IteratorT, bool (*)(Operation &)>(it, end,
&filter) {}
/// Allow implicit conversion to the underlying iterator.
operator IteratorT() const { return this->wrapped(); }
};
/// This class provides iteration over the held operations of a block for a
/// specific operation type.
template <typename OpT, typename IteratorT>
class op_iterator
: public llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
OpT (*)(Operation &)> {
static OpT unwrap(Operation &op) { return cast<OpT>(op); }
public:
using reference = OpT;
/// Initializes the iterator to the specified filter iterator.
op_iterator(op_filter_iterator<OpT, IteratorT> it)
: llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
OpT (*)(Operation &)>(it, &unwrap) {}
/// Allow implicit conversion to the underlying block iterator.
operator IteratorT() const { return this->wrapped(); }
};
} // end namespace detail
} // end namespace mlir } // end namespace mlir
namespace llvm { namespace llvm {

View File

@ -32,9 +32,10 @@ namespace mlir {
/// symbols referenced by name via a string attribute). /// symbols referenced by name via a string attribute).
class FuncOp class FuncOp
: public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult, : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
OpTrait::IsIsolatedFromAbove, OpTrait::FunctionLike, OpTrait::OneRegion, OpTrait::IsIsolatedFromAbove,
OpTrait::AutomaticAllocationScope, OpTrait::PolyhedralScope, OpTrait::FunctionLike, OpTrait::AutomaticAllocationScope,
CallableOpInterface::Trait, SymbolOpInterface::Trait> { OpTrait::PolyhedralScope, CallableOpInterface::Trait,
SymbolOpInterface::Trait> {
public: public:
using Op::Op; using Op::Op;
using Op::print; using Op::print;

View File

@ -583,6 +583,13 @@ class OneRegion : public TraitBase<ConcreteType, OneRegion> {
public: public:
Region &getRegion() { return this->getOperation()->getRegion(0); } Region &getRegion() { return this->getOperation()->getRegion(0); }
/// Returns a range of operations within the region of this operation.
auto getOps() { return getRegion().getOps(); }
template <typename OpT>
auto getOps() {
return getRegion().template getOps<OpT>();
}
static LogicalResult verifyTrait(Operation *op) { static LogicalResult verifyTrait(Operation *op) {
return impl::verifyOneRegion(op); return impl::verifyOneRegion(op);
} }

View File

@ -34,6 +34,10 @@ public:
/// parent container. The region must have a valid parent container. /// parent container. The region must have a valid parent container.
Location getLoc(); Location getLoc();
//===--------------------------------------------------------------------===//
// Block list management
//===--------------------------------------------------------------------===//
using BlockListType = llvm::iplist<Block>; using BlockListType = llvm::iplist<Block>;
BlockListType &getBlocks() { return blocks; } BlockListType &getBlocks() { return blocks; }
@ -58,6 +62,72 @@ public:
return &Region::blocks; return &Region::blocks;
} }
//===--------------------------------------------------------------------===//
// Operation list utilities
//===--------------------------------------------------------------------===//
/// This class provides iteration over the held operations of blocks directly
/// within a region.
class OpIterator final
: public llvm::iterator_facade_base<OpIterator, std::forward_iterator_tag,
Operation> {
public:
/// Initialize OpIterator for a region, specify `end` to return the iterator
/// to last operation.
explicit OpIterator(Region *region, bool end = false);
using llvm::iterator_facade_base<OpIterator, std::forward_iterator_tag,
Operation>::operator++;
OpIterator &operator++();
Operation *operator->() const { return &*operation; }
Operation &operator*() const { return *operation; }
/// Compare this iterator with another.
bool operator==(const OpIterator &rhs) const {
return operation == rhs.operation;
}
bool operator!=(const OpIterator &rhs) const { return !(*this == rhs); }
private:
void skipOverBlocksWithNoOps();
/// The region whose operations are being iterated over.
Region *region;
/// The block of 'region' whose operations are being iterated over.
Region::iterator block;
/// The current operation within 'block'.
Block::iterator operation;
};
/// This class provides iteration over the held operations of a region for a
/// specific operation type.
template <typename OpT>
using op_iterator = detail::op_iterator<OpT, OpIterator>;
/// Return iterators that walk the operations nested directly within this
/// region.
OpIterator op_begin() { return OpIterator(this); }
OpIterator op_end() { return OpIterator(this, /*end=*/true); }
iterator_range<OpIterator> getOps() { return {op_begin(), op_end()}; }
/// Return iterators that walk operations of type 'T' nested directly within
/// this region.
template <typename OpT> op_iterator<OpT> op_begin() {
return detail::op_filter_iterator<OpT, OpIterator>(op_begin(), op_end());
}
template <typename OpT> op_iterator<OpT> op_end() {
return detail::op_filter_iterator<OpT, OpIterator>(op_end(), op_end());
}
template <typename OpT> iterator_range<op_iterator<OpT>> getOps() {
auto endIt = op_end();
return {detail::op_filter_iterator<OpT, OpIterator>(op_begin(), endIt),
detail::op_filter_iterator<OpT, OpIterator>(endIt, endIt)};
}
//===--------------------------------------------------------------------===//
// Misc. utilities
//===--------------------------------------------------------------------===//
/// Return the region containing this region or nullptr if the region is /// Return the region containing this region or nullptr if the region is
/// attached to a top-level operation. /// attached to a top-level operation.
Region *getParentRegion(); Region *getParentRegion();
@ -120,6 +190,10 @@ public:
/// they are to be deleted. /// they are to be deleted.
void dropAllReferences(); void dropAllReferences();
//===--------------------------------------------------------------------===//
// Operation Walkers
//===--------------------------------------------------------------------===//
/// Walk the operations in this region in postorder, calling the callback for /// Walk the operations in this region in postorder, calling the callback for
/// each operation. This method is invoked for void-returning callbacks. /// each operation. This method is invoked for void-returning callbacks.
/// See Operation::walk for more details. /// See Operation::walk for more details.
@ -142,6 +216,10 @@ public:
return WalkResult::advance(); return WalkResult::advance();
} }
//===--------------------------------------------------------------------===//
// CFG view utilities
//===--------------------------------------------------------------------===//
/// Displays the CFG in a window. This is for use from the debugger and /// Displays the CFG in a window. This is for use from the debugger and
/// depends on Graphviz to generate the graph. /// depends on Graphviz to generate the graph.
/// This function is defined in ViewRegionGraph and only works with that /// This function is defined in ViewRegionGraph and only works with that

View File

@ -87,8 +87,7 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
} }
for (Region &region : op->getRegions()) for (Region &region : op->getRegions())
for (Block &block : region) for (Operation &nested : region.getOps())
for (Operation &nested : block)
computeCallGraph(&nested, cg, parentNode, resolveCalls); computeCallGraph(&nested, cg, parentNode, resolveCalls);
} }

View File

@ -36,15 +36,14 @@ struct ForLoopMapper : public ConvertSimpleLoopsToGPUBase<ForLoopMapper> {
} }
void runOnFunction() override { void runOnFunction() override {
for (Block &block : getFunction()) for (Operation &op : llvm::make_early_inc_range(getFunction().getOps())) {
for (Operation &op : llvm::make_early_inc_range(block)) {
if (auto forOp = dyn_cast<AffineForOp>(&op)) { if (auto forOp = dyn_cast<AffineForOp>(&op)) {
if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims, if (failed(convertAffineLoopNestToGPULaunch(forOp, numBlockDims,
numThreadDims))) numThreadDims)))
signalPassFailure(); signalPassFailure();
} else if (auto forOp = dyn_cast<ForOp>(&op)) { } else if (auto forOp = dyn_cast<ForOp>(&op)) {
if (failed(convertLoopNestToGPULaunch(forOp, numBlockDims, if (failed(
numThreadDims))) convertLoopNestToGPULaunch(forOp, numBlockDims, numThreadDims)))
signalPassFailure(); signalPassFailure();
} }
} }
@ -81,17 +80,13 @@ struct ImperfectlyNestedForLoopMapper
funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val)); funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
workGroupSizeVal.push_back(constOp); workGroupSizeVal.push_back(constOp);
} }
for (Block &block : getFunction()) { for (ForOp forOp : llvm::make_early_inc_range(funcOp.getOps<ForOp>())) {
for (Operation &op : llvm::make_early_inc_range(block)) {
if (auto forOp = dyn_cast<ForOp>(&op)) {
if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal, if (failed(convertLoopToGPULaunch(forOp, numWorkGroupsVal,
workGroupSizeVal))) { workGroupSizeVal))) {
return signalPassFailure(); return signalPassFailure();
} }
} }
} }
}
}
}; };
struct ParallelLoopToGpuPass struct ParallelLoopToGpuPass

View File

@ -146,8 +146,7 @@ static bool isIsolatedAbove(Region &region, Region &limit,
// Traverse all operations in the region. // Traverse all operations in the region.
while (!pendingRegions.empty()) { while (!pendingRegions.empty()) {
for (Block &block : *pendingRegions.pop_back_val()) { for (Operation &op : pendingRegions.pop_back_val()->getOps()) {
for (Operation &op : block) {
for (Value operand : op.getOperands()) { for (Value operand : op.getOperands()) {
// operand should be non-null here if the IR is well-formed. But // operand should be non-null here if the IR is well-formed. But
// we don't assert here as this function is called from the verifier // we don't assert here as this function is called from the verifier
@ -175,7 +174,6 @@ static bool isIsolatedAbove(Region &region, Region &limit,
pendingRegions.push_back(&subRegion); pendingRegions.push_back(&subRegion);
} }
} }
}
return true; return true;
} }
@ -219,6 +217,40 @@ void llvm::ilist_traits<::mlir::Block>::transferNodesFromList(
first->parentValidOpOrderPair.setPointer(curParent); first->parentValidOpOrderPair.setPointer(curParent);
} }
//===----------------------------------------------------------------------===//
// Region::OpIterator
//===----------------------------------------------------------------------===//
Region::OpIterator::OpIterator(Region *region, bool end)
: region(region), block(end ? region->end() : region->begin()) {
if (!region->empty())
skipOverBlocksWithNoOps();
}
Region::OpIterator &Region::OpIterator::operator++() {
// We increment over operations, if we reach the last use then move to next
// block.
if (operation != block->end())
++operation;
if (operation == block->end()) {
++block;
skipOverBlocksWithNoOps();
}
return *this;
}
void Region::OpIterator::skipOverBlocksWithNoOps() {
while (block != region->end() && block->empty())
++block;
// If we are at the last block, then set the operation to first operation of
// next block (sentinel value used for end).
if (block == region->end())
operation = {};
else
operation = block->begin();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// RegionRange // RegionRange
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -245,11 +245,9 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
// Look for a symbol with the given name. // Look for a symbol with the given name.
for (auto &block : symbolTableOp->getRegion(0)) { for (auto &op : symbolTableOp->getRegion(0).front().without_terminator())
for (auto &op : block)
if (getNameIfSymbol(&op) == symbol) if (getNameIfSymbol(&op) == symbol)
return &op; return &op;
}
return nullptr; return nullptr;
} }
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
@ -444,8 +442,7 @@ static Optional<WalkResult> walkSymbolUses(
function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) { function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions)); SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
while (!worklist.empty()) { while (!worklist.empty()) {
for (Block &block : *worklist.pop_back_val()) { for (Operation &op : worklist.pop_back_val()->getOps()) {
for (Operation &op : block) {
if (walkSymbolRefs(&op, callback).wasInterrupted()) if (walkSymbolRefs(&op, callback).wasInterrupted())
return WalkResult::interrupt(); return WalkResult::interrupt();
@ -461,7 +458,6 @@ static Optional<WalkResult> walkSymbolUses(
} }
} }
} }
}
return WalkResult::advance(); return WalkResult::advance();
} }
/// Walk all of the uses, for any symbol, that are nested within the given /// Walk all of the uses, for any symbol, that are nested within the given

View File

@ -122,8 +122,7 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
// Walk each of the symbol tables looking for discardable callgraph nodes. // Walk each of the symbol tables looking for discardable callgraph nodes.
auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) { auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
for (Block &block : symbolTableOp->getRegion(0)) { for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
for (Operation &op : block) {
// If this is a callgraph operation, check to see if it is discardable. // If this is a callgraph operation, check to see if it is discardable.
if (auto callable = dyn_cast<CallableOpInterface>(&op)) { if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
if (auto *node = cg.lookupNode(callable.getCallableRegion())) { if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
@ -139,7 +138,6 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes, walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
[](CallGraphNode *, Operation *) {}); [](CallGraphNode *, Operation *) {});
} }
}
}; };
SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
walkFn); walkFn);