NFC: Cleanup the implementation of walkSymbolUses.

Refactor the implementation to be much cleaner by adding a `make_second_range` utility to walk the `second` value of a range of pairs.

PiperOrigin-RevId: 275598985
This commit is contained in:
River Riddle 2019-10-18 21:28:47 -07:00 committed by A. Unique TensorFlower
parent d9db842e68
commit 5f6bdd144a
2 changed files with 45 additions and 68 deletions

View File

@ -24,8 +24,7 @@
#define MLIR_SUPPORT_STLEXTRAS_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/iterator.h"
#include <tuple>
#include "llvm/ADT/STLExtras.h"
namespace mlir {
@ -185,6 +184,15 @@ protected:
ptrdiff_t index;
};
/// Given a container of pairs, return a range over the second elements.
template <typename ContainerTy> auto make_second_range(ContainerTy &&c) {
return llvm::map_range(
std::forward<ContainerTy>(c),
[](decltype((*std::begin(c))) elt) -> decltype((elt.second)) {
return elt.second;
});
}
} // end namespace mlir
// Allow tuples to be usable as DenseMap keys.

View File

@ -163,44 +163,6 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//
/// A utility result for walking a nested attribute for symbol uses.
enum HandlerResult {
/// The walk of the containter can continue.
Continue = 0,
/// The walk should recurse into the given attribute, as it is a container.
RecurseNestedAttribute,
/// The walk should end immediately, as an interrupt has been signaled.
Interrupt
};
/// Utility function used to handle a nested attribute during a walk of symbol
/// uses. It returns the above HandlerResult signaling the next action for the
/// walk.
HandlerResult handleAttrDuringSymbolWalk(
Operation *op, Attribute attr,
SmallVectorImpl<std::pair<Attribute, unsigned>> &worklist,
function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
switch (attr.getKind()) {
/// Check for a nested container attribute, these will also need to be
/// walked.
case StandardAttributes::Array:
case StandardAttributes::Dictionary: {
worklist.push_back({attr, /*index*/ 0});
return HandlerResult::RecurseNestedAttribute;
}
// Invoke the provided callback if we find a symbol use and check for a
// requested interrupt.
case StandardAttributes::SymbolRef: {
SymbolTable::SymbolUse use(op, attr.cast<SymbolRefAttr>());
return callback(use).wasInterrupted() ? HandlerResult::Interrupt
: HandlerResult::Continue;
}
default:
return HandlerResult::Continue;
}
}
/// Walk all of the symbol references within the given operation, invoking the
/// provided callback for each found use.
static WalkResult
@ -215,37 +177,44 @@ walkSymbolRefs(Operation *op,
// attribute list.
SmallVector<std::pair<Attribute, unsigned>, 1> worklist;
worklist.push_back({attrDict, /*index*/ 0});
while (!worklist.empty()) {
// Process the symbol references within the given nested attribute range.
auto processAttrs = [&](unsigned &index, auto attrRange) -> WalkResult {
for (Attribute attr : llvm::drop_begin(attrRange, index)) {
// Make sure to keep the index counter in sync.
++index;
/// Check for a nested container attribute, these will also need to be
/// walked.
if (attr.isa<ArrayAttr>() || attr.isa<DictionaryAttr>()) {
worklist.push_back({attr, /*index*/ 0});
return WalkResult::advance();
}
// Invoke the provided callback if we find a symbol use and check for a
// requested interrupt.
if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>())
if (callback(SymbolTable::SymbolUse(op, symbolRef)).wasInterrupted())
return WalkResult::interrupt();
}
// Pop this container attribute from the worklist.
worklist.pop_back();
return WalkResult::advance();
};
WalkResult result = WalkResult::advance();
do {
Attribute attr = worklist.back().first;
unsigned &index = worklist.back().second;
// Iterate over the given attribute, which is guaranteed to be a container.
HandlerResult handlerResult = HandlerResult::Continue;
if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
ArrayRef<Attribute> attrs = arrayAttr.getValue();
unsigned attrSize = attrs.size();
while (index != attrSize)
if ((handlerResult = handleAttrDuringSymbolWalk(op, attrs[index++],
worklist, callback)))
break;
} else {
auto dictAttr = attr.cast<DictionaryAttr>();
ArrayRef<NamedAttribute> attrs = dictAttr.getValue();
unsigned attrSize = attrs.size();
while (index != attrSize)
if ((handlerResult = handleAttrDuringSymbolWalk(
op, attrs[index++].second, worklist, callback)))
break;
}
if (handlerResult == HandlerResult::Interrupt)
return WalkResult::interrupt();
// If we didn't encounter a nested attribute, pop the last item from the
// worklist.
if (handlerResult != HandlerResult::RecurseNestedAttribute)
worklist.pop_back();
}
return WalkResult::advance();
// Process the given attribute, which is guaranteed to be a container.
if (auto dict = attr.dyn_cast<DictionaryAttr>())
result = processAttrs(index, make_second_range(dict.getValue()));
else
result = processAttrs(index, attr.cast<ArrayAttr>().getValue());
} while (!worklist.empty() && !result.wasInterrupted());
return result;
}
/// Walk all of the uses, for any symbol, that are nested within the given