Update the symbol utility methods to handle the case of unknown operations.

This enhances the symbol table utility methods to handle the case where an unknown operation may define a symbol table. When walking symbols, we now collect all symbol uses before allowing the user to iterate. This prevents the user from assuming that all symbols are actually known before performing a transformation.

PiperOrigin-RevId: 273651963
This commit is contained in:
River Riddle 2019-10-08 18:38:05 -07:00 committed by A. Unique TensorFlower
parent 7446151236
commit b3a6ae8363
4 changed files with 150 additions and 64 deletions

View File

@ -88,26 +88,58 @@ public:
SymbolRefAttr symbolRef;
};
/// Walk all of the uses, for any symbol, that are nested within the given
/// This class implements a range of SymbolRef uses.
class UseRange {
public:
/// This class implements an iterator over the symbol use range.
class iterator final
: public indexed_accessor_iterator<iterator, const UseRange *,
const SymbolUse> {
public:
const SymbolUse *operator->() const { return &object->uses[index]; }
const SymbolUse &operator*() const { return object->uses[index]; }
private:
iterator(const UseRange *owner, ptrdiff_t it)
: indexed_accessor_iterator<iterator, const UseRange *,
const SymbolUse>(owner, it) {}
/// Allow access to the constructor.
friend class UseRange;
};
/// Contruct a UseRange from a given set of uses.
UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {}
iterator begin() const { return iterator(this, /*it=*/0); }
iterator end() const { return iterator(this, /*it=*/uses.size()); }
private:
std::vector<SymbolUse> uses;
};
/// Get an iterator range for all of the uses, for any symbol, that are nested
/// within the given operation 'from'. This does not traverse into any nested
/// symbol tables, and will also only return uses on 'from' if it does not
/// also define a symbol table. This function returns None if there are any
/// unknown operations that may potentially be symbol tables.
static Optional<UseRange> getSymbolUses(Operation *from);
/// Get all of the uses of the given symbol that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
/// 'from' if it does not also define a symbol table.
static WalkResult
walkSymbolUses(Operation *from, function_ref<WalkResult(SymbolUse)> callback);
/// 'from' if it does not also define a symbol table. This function returns
/// None if there are any unknown operations that may potentially be symbol
/// tables.
static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
/// Walk all of the uses of the given symbol that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
/// 'from' if it does not also define a symbol table.
static WalkResult
walkSymbolUses(StringRef symbol, Operation *from,
function_ref<WalkResult(SymbolUse)> callback);
/// Return if the given symbol has no uses that are nested within the given
/// operation 'from'. This does not traverse into any nested symbol tables,
/// and will also only count uses on 'from' if it does not also define a
/// symbol table.
static bool symbol_use_empty(StringRef symbol, Operation *from);
/// Return if the given symbol is known to have no uses that are nested within
/// the given operation 'from'. This does not traverse into any nested symbol
/// tables, and will also only count uses on 'from' if it does not also define
/// a symbol table. This function will also return false if there are any
/// unknown operations that may potentially be symbol tables. This doesn't
/// necessarily mean that there are no uses, we just can't convervatively
/// prove it.
static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
private:
MLIRContext *context;

View File

@ -20,6 +20,16 @@
using namespace mlir;
/// Return true if the given operation is unknown and may potentially define a
/// symbol table.
static bool isPotentiallyUnknownSymbolTable(Operation *op) {
return !op->getDialect() && op->getNumRegions() == 1;
}
//===----------------------------------------------------------------------===//
// SymbolTable
//===----------------------------------------------------------------------===//
/// Build a symbol table with the symbols within the given operation.
SymbolTable::SymbolTable(Operation *op) : context(op->getContext()) {
assert(op->hasTrait<OpTrait::SymbolTable>() &&
@ -107,9 +117,15 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
/// nullptr if no valid symbol was found.
Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
StringRef symbol) {
while (from && !from->hasTrait<OpTrait::SymbolTable>())
assert(from && "expected valid operation");
while (!from->hasTrait<OpTrait::SymbolTable>()) {
from = from->getParentOp();
return from ? lookupSymbolIn(from, symbol) : nullptr;
// Check that this is a valid op and isn't an unknown symbol table.
if (!from || isPotentiallyUnknownSymbolTable(from))
return nullptr;
}
return lookupSymbolIn(from, symbol);
}
//===----------------------------------------------------------------------===//
@ -236,9 +252,9 @@ walkSymbolRefs(Operation *op,
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
/// 'from' if it does not also define a symbol table.
WalkResult
SymbolTable::walkSymbolUses(Operation *from,
function_ref<WalkResult(SymbolUse)> callback) {
static Optional<WalkResult>
walkSymbolUses(Operation *from,
function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
// If from is not a symbol table, check for uses. A symbol table defines a new
// scope, so we can't walk the attributes from the symbol table op.
if (!from->hasTrait<OpTrait::SymbolTable>()) {
@ -258,6 +274,13 @@ SymbolTable::walkSymbolUses(Operation *from,
if (walkSymbolRefs(&op, callback).wasInterrupted())
return WalkResult::interrupt();
// If this operation has regions, and it as well as its dialect arent't
// registered then conservatively fail. The operation may define a
// symbol table, so we can't opaquely know if we should traverse to find
// nested uses.
if (isPotentiallyUnknownSymbolTable(&op))
return llvm::None;
// If this op defines a new symbol table scope, we can't traverse. Any
// symbol references nested within 'op' are different semantically.
if (!op.hasTrait<OpTrait::SymbolTable>()) {
@ -270,32 +293,53 @@ SymbolTable::walkSymbolUses(Operation *from,
return WalkResult::advance();
}
/// Walk all of the uses, for any symbol, that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
/// 'from' if it does not also define a symbol table.
WalkResult
SymbolTable::walkSymbolUses(StringRef symbol, Operation *from,
function_ref<WalkResult(SymbolUse)> callback) {
SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
return walkSymbolUses(from, [&](SymbolUse symbolUse) {
if (symbolUse.getSymbolRef() != symbolRefAttr)
return WalkResult::advance();
return callback(std::move(symbolUse));
/// Get an iterator range for all of the uses, for any symbol, that are nested
/// within the given operation 'from'. This does not traverse into any nested
/// symbol tables, and will also only return uses on 'from' if it does not
/// also define a symbol table. This function returns None if there are any
/// unknown operations that may potentially be symbol tables.
auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> {
std::vector<SymbolUse> uses;
Optional<WalkResult> result = walkSymbolUses(from, [&](SymbolUse symbolUse) {
uses.push_back(symbolUse);
return WalkResult::advance();
});
return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>();
}
/// Return if the given symbol has no uses that are nested within the given
/// operation 'from'. This does not traverse into any nested symbol tables,
/// and will also only count uses on 'from' if it does not also define a
/// symbol table.
bool SymbolTable::symbol_use_empty(StringRef symbol, Operation *from) {
/// Get all of the uses of the given symbol that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
/// 'from' if it does not also define a symbol table. This function returns
/// None if there are any unknown operations that may potentially be symbol
/// tables.
auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from)
-> Optional<UseRange> {
SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
std::vector<SymbolUse> uses;
Optional<WalkResult> result = walkSymbolUses(from, [&](SymbolUse symbolUse) {
if (symbolRefAttr == symbolUse.getSymbolRef())
uses.push_back(symbolUse);
return WalkResult::advance();
});
return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>();
}
/// Return if the given symbol is known to have no uses that are nested within
/// the given operation 'from'. This does not traverse into any nested symbol
/// tables, and will also only count uses on 'from' if it does not also define
/// a symbol table. This function will also return false if there are any
/// unknown operations that may potentially be symbol tables.
bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) {
SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
// Walk all of the symbol uses looking for a reference to 'symbol'.
auto walkResult = walkSymbolUses(from, [&](SymbolUse symbolUse) {
return symbolUse.getSymbolRef() == symbolRefAttr ? WalkResult::interrupt()
: WalkResult::advance();
});
return !walkResult.wasInterrupted();
Optional<WalkResult> walkResult =
walkSymbolUses(from, [&](SymbolUse symbolUse) {
return symbolUse.getSymbolRef() == symbolRefAttr
? WalkResult::interrupt()
: WalkResult::advance();
});
return !walkResult || !walkResult->wasInterrupted();
}

View File

@ -1,5 +1,4 @@
// RUN: mlir-opt %s -test-symbol-uses -verify-diagnostics
// RUN: mlir-opt %s -test-symbol-uses -split-input-file -verify-diagnostics
// Symbol references to the module itself don't affect uses of symbols within
// its table.
@ -27,3 +26,11 @@ module attributes {sym.outside_use = @symbol_foo } {
"foo.op"() {test.nested_reference = @symbol_baz} : () -> ()
}
}
// -----
// expected-remark@+1 {{contains an unknown nested operation that 'may' define a new symbol table}}
func @symbol_bar() {
"foo.possibly_unknown_symbol_table"() ({
}) : () -> ()
}

View File

@ -29,31 +29,34 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
for (FuncOp func : module.getOps<FuncOp>()) {
// Test computing uses on a non symboltable op.
unsigned numUses = 0;
SymbolTable::walkSymbolUses(func, [&](SymbolTable::SymbolUse) {
++numUses;
return WalkResult::advance();
});
if (numUses != 0)
Optional<SymbolTable::UseRange> symbolUses =
SymbolTable::getSymbolUses(func);
// Test the conservative failure case.
if (!symbolUses) {
func.emitRemark() << "function contains an unknown nested operation "
"that 'may' define a new symbol table";
return;
}
if (unsigned numUses = llvm::size(*symbolUses))
func.emitRemark() << "function contains " << numUses
<< " nested references";
// Test the functionality of symbol_use_empty.
if (SymbolTable::symbol_use_empty(func.getName(), module)) {
// Test the functionality of symbolKnownUseEmpty.
if (SymbolTable::symbolKnownUseEmpty(func.getName(), module)) {
func.emitRemark() << "function has no uses";
continue;
}
// Test the functionality of walkSymbolUses.
numUses = 0;
SymbolTable::walkSymbolUses(
func.getName(), module, [&](SymbolTable::SymbolUse symbolUse) {
symbolUse.getUser()->emitRemark()
<< "found use of function : " << symbolUse.getSymbolRef();
++numUses;
return WalkResult::advance();
});
func.emitRemark() << "function has " << numUses << " uses";
// Test the functionality of getSymbolUses.
symbolUses = SymbolTable::getSymbolUses(func.getName(), module);
assert(symbolUses.hasValue() && "expected no unknown operations");
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
symbolUse.getUser()->emitRemark()
<< "found use of function : " << symbolUse.getSymbolRef();
}
func.emitRemark() << "function has " << llvm::size(*symbolUses)
<< " uses";
}
}
};