forked from OSchip/llvm-project
[mlir] Refactor the implementation of Symbol use lists.
Summary: This revision refactors the implementation of the symbol use-list functionality to be a bit cleaner, as well as easier to reason about. Aside from code cleanup, this revision updates the user contract to never recurse into operations if they define a symbol table. The current functionality, which does recurse, makes it difficult to examine the uses held by a symbol table itself. Moving forward users may provide a specific region to examine for uses instead. Differential Revision: https://reviews.llvm.org/D73427
This commit is contained in:
parent
aff4ed7326
commit
ab9e5598cd
|
@ -137,47 +137,47 @@ public:
|
|||
|
||||
/// Get an iterator range for all of the uses, for any symbol, that are nested
|
||||
/// within the given operation 'from'. This does not traverse into any nested
|
||||
/// symbol tables, and will also only return uses on 'from' if it does not
|
||||
/// also define a symbol table. This is because we treat the region as the
|
||||
/// boundary of the symbol table, and not the op itself. This function returns
|
||||
/// None if there are any unknown operations that may potentially be symbol
|
||||
/// tables.
|
||||
/// symbol tables. This function returns None if there are any unknown
|
||||
/// operations that may potentially be symbol tables.
|
||||
static Optional<UseRange> getSymbolUses(Operation *from);
|
||||
static Optional<UseRange> getSymbolUses(Region *from);
|
||||
|
||||
/// Get all of the uses of the given symbol that are nested within the given
|
||||
/// operation 'from'. This does not traverse into any nested symbol tables,
|
||||
/// and will also only return uses on 'from' if it does not also define a
|
||||
/// symbol table. This is because we treat the region as the boundary of the
|
||||
/// symbol table, and not the op itself. This function returns None if there
|
||||
/// are any unknown operations that may potentially be symbol tables.
|
||||
/// operation 'from'. This does not traverse into any nested symbol tables.
|
||||
/// This function returns None if there are any unknown operations that may
|
||||
/// potentially be symbol tables.
|
||||
static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
|
||||
static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
|
||||
static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from);
|
||||
static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
|
||||
|
||||
/// Return if the given symbol is known to have no uses that are nested
|
||||
/// within the given operation 'from'. This does not traverse into any nested
|
||||
/// symbol tables, and will also only count uses on 'from' if it does not also
|
||||
/// define a symbol table. This is because we treat the region as the boundary
|
||||
/// of the symbol table, and not the op itself. This function will also return
|
||||
/// false if there are any unknown operations that may potentially be symbol
|
||||
/// tables. This doesn't necessarily mean that there are no uses, we just
|
||||
/// can't conservatively prove it.
|
||||
/// symbol tables. This function will also return false if there are any
|
||||
/// unknown operations that may potentially be symbol tables. This doesn't
|
||||
/// necessarily mean that there are no uses, we just can't conservatively
|
||||
/// prove it.
|
||||
static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
|
||||
static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
|
||||
static bool symbolKnownUseEmpty(StringRef symbol, Region *from);
|
||||
static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
|
||||
|
||||
/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
|
||||
/// provided symbol 'newSymbol' that are nested within the given operation
|
||||
/// 'from'. This does not traverse into any nested symbol tables, and will
|
||||
/// also only replace uses on 'from' if it does not also define a symbol
|
||||
/// table. This is because we treat the region as the boundary of the symbol
|
||||
/// table, and not the op itself. If there are any unknown operations that may
|
||||
/// potentially be symbol tables, no uses are replaced and failure is
|
||||
/// returned.
|
||||
/// 'from'. This does not traverse into any nested symbol tables. If there are
|
||||
/// any unknown operations that may potentially be symbol tables, no uses are
|
||||
/// replaced and failure is returned.
|
||||
LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
|
||||
StringRef newSymbol,
|
||||
Operation *from);
|
||||
LLVM_NODISCARD static LogicalResult
|
||||
replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
|
||||
Operation *from);
|
||||
LLVM_NODISCARD static LogicalResult
|
||||
replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Region *from);
|
||||
LLVM_NODISCARD static LogicalResult
|
||||
replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
|
||||
Region *from);
|
||||
|
||||
private:
|
||||
Operation *symbolTableOp;
|
||||
|
|
|
@ -401,35 +401,19 @@ static WalkResult walkSymbolRefs(
|
|||
}
|
||||
|
||||
/// Walk all of the uses, for any symbol, that are nested within the given
|
||||
/// operation 'from', invoking the provided callback for each. This does not
|
||||
/// traverse into any nested symbol tables, and will also only return uses on
|
||||
/// 'from' if it does not also define a symbol table.
|
||||
/// regions, invoking the provided callback for each. This does not traverse
|
||||
/// into any nested symbol tables.
|
||||
static Optional<WalkResult> walkSymbolUses(
|
||||
Operation *from,
|
||||
MutableArrayRef<Region> regions,
|
||||
function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
|
||||
// If from is not a symbol table, check for uses. A symbol table defines a new
|
||||
// scope, so we can't walk the attributes from the symbol table op.
|
||||
if (!from->hasTrait<OpTrait::SymbolTable>()) {
|
||||
if (walkSymbolRefs(from, callback).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
SmallVector<Region *, 1> worklist;
|
||||
worklist.reserve(from->getNumRegions());
|
||||
for (Region ®ion : from->getRegions())
|
||||
worklist.push_back(®ion);
|
||||
|
||||
SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
|
||||
while (!worklist.empty()) {
|
||||
Region *region = worklist.pop_back_val();
|
||||
for (Block &block : *region) {
|
||||
for (Block &block : *worklist.pop_back_val()) {
|
||||
for (Operation &op : block) {
|
||||
if (walkSymbolRefs(&op, callback).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
// If this operation has regions, and it as well as its dialect aren't
|
||||
// registered then conservatively fail. The operation may define a
|
||||
// symbol table, so we can't opaquely know if we should traverse to find
|
||||
// nested uses.
|
||||
// Check that this isn't a potentially unknown symbol table.
|
||||
if (isPotentiallyUnknownSymbolTable(&op))
|
||||
return llvm::None;
|
||||
|
||||
|
@ -444,16 +428,74 @@ static Optional<WalkResult> walkSymbolUses(
|
|||
}
|
||||
return WalkResult::advance();
|
||||
}
|
||||
/// Walk all of the uses, for any symbol, that are nested within the given
|
||||
/// operaion 'from', invoking the provided callback for each. This does not
|
||||
/// traverse into any nested symbol tables.
|
||||
static Optional<WalkResult> walkSymbolUses(
|
||||
Operation *from,
|
||||
function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
|
||||
// If this operation has regions, and it, as well as its dialect, isn't
|
||||
// registered then conservatively fail. The operation may define a
|
||||
// symbol table, so we can't opaquely know if we should traverse to find
|
||||
// nested uses.
|
||||
if (isPotentiallyUnknownSymbolTable(from))
|
||||
return llvm::None;
|
||||
|
||||
/// Walks all of the symbol scopes from 'symbol' to (inclusive) 'limit' invoking
|
||||
/// the provided callback at each one with a properly scoped reference to
|
||||
/// 'symbol'. The callback takes as parameters the symbol reference at the
|
||||
/// current scope as well as the top-level operation representing the top of
|
||||
/// that scope.
|
||||
static Optional<WalkResult> walkSymbolScopes(
|
||||
Operation *symbol, Operation *limit,
|
||||
function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) {
|
||||
StringRef symbolName = SymbolTable::getSymbolName(symbol);
|
||||
// Walk the uses on this operation.
|
||||
if (walkSymbolRefs(from, callback).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
// Only recurse if this operation is not a symbol table. A symbol table
|
||||
// defines a new scope, so we can't walk the attributes from within the symbol
|
||||
// table op.
|
||||
if (!from->hasTrait<OpTrait::SymbolTable>())
|
||||
return walkSymbolUses(from->getRegions(), callback);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// This class represents a single symbol scope. A symbol scope represents the
|
||||
/// set of operations nested within a symbol table that may reference symbols
|
||||
/// within that table. A symbol scope does not contain the symbol table
|
||||
/// operation itself, just its contained operations. A scope ends at leaf
|
||||
/// operations or another symbol table operation.
|
||||
struct SymbolScope {
|
||||
/// Walk the symbol uses within this scope, invoking the given callback.
|
||||
/// This variant is used when the callback type matches that expected by
|
||||
/// 'walkSymbolUses'.
|
||||
template <typename CallbackT,
|
||||
typename std::enable_if_t<!std::is_same<
|
||||
typename FunctionTraits<CallbackT>::result_t, void>::value> * =
|
||||
nullptr>
|
||||
Optional<WalkResult> walk(CallbackT cback) {
|
||||
if (Region *region = limit.dyn_cast<Region *>())
|
||||
return walkSymbolUses(*region, cback);
|
||||
return walkSymbolUses(limit.get<Operation *>(), cback);
|
||||
}
|
||||
/// This variant is used when the callback type matches a stripped down type:
|
||||
/// void(SymbolTable::SymbolUse use)
|
||||
template <typename CallbackT,
|
||||
typename std::enable_if_t<std::is_same<
|
||||
typename FunctionTraits<CallbackT>::result_t, void>::value> * =
|
||||
nullptr>
|
||||
Optional<WalkResult> walk(CallbackT cback) {
|
||||
return walk([=](SymbolTable::SymbolUse use, ArrayRef<int>) {
|
||||
return cback(use), WalkResult::advance();
|
||||
});
|
||||
}
|
||||
|
||||
/// The representation of the symbol within this scope.
|
||||
SymbolRefAttr symbol;
|
||||
|
||||
/// The IR unit representing this scope.
|
||||
llvm::PointerUnion<Operation *, Region *> limit;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
|
||||
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
|
||||
Operation *limit) {
|
||||
StringRef symName = SymbolTable::getSymbolName(symbol);
|
||||
assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
|
||||
|
||||
// Compute the ancestors of 'limit'.
|
||||
|
@ -466,10 +508,10 @@ static Optional<WalkResult> walkSymbolScopes(
|
|||
if (limitAncestor == symbol) {
|
||||
// Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
|
||||
// doesn't support parent references.
|
||||
if (SymbolTable::getNearestSymbolTable(limit) != symbol->getParentOp())
|
||||
return WalkResult::advance();
|
||||
return callback(SymbolRefAttr::get(symbolName, symbol->getContext()),
|
||||
limit);
|
||||
if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
|
||||
symbol->getParentOp())
|
||||
return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}};
|
||||
return {};
|
||||
}
|
||||
|
||||
limitAncestors.insert(limitAncestor);
|
||||
|
@ -486,36 +528,45 @@ static Optional<WalkResult> walkSymbolScopes(
|
|||
// Compute the set of valid nested references for 'symbol' as far up to the
|
||||
// common ancestor as possible.
|
||||
SmallVector<SymbolRefAttr, 2> references;
|
||||
bool collectedAllReferences = succeeded(collectValidReferencesFor(
|
||||
symbol, symbolName, commonAncestor, references));
|
||||
bool collectedAllReferences = succeeded(
|
||||
collectValidReferencesFor(symbol, symName, commonAncestor, references));
|
||||
|
||||
// Handle the case where the common ancestor is 'limit'.
|
||||
if (commonAncestor == limit) {
|
||||
SmallVector<SymbolScope, 2> scopes;
|
||||
|
||||
// Walk each of the ancestors of 'symbol', calling the compute function for
|
||||
// each one.
|
||||
Operation *limitIt = symbol->getParentOp();
|
||||
for (size_t i = 0, e = references.size(); i != e;
|
||||
++i, limitIt = limitIt->getParentOp()) {
|
||||
Optional<WalkResult> callbackResult = callback(references[i], limitIt);
|
||||
if (callbackResult != WalkResult::advance())
|
||||
return callbackResult;
|
||||
assert(limitIt->hasTrait<OpTrait::SymbolTable>());
|
||||
scopes.push_back({references[i], &limitIt->getRegion(0)});
|
||||
}
|
||||
return WalkResult::advance();
|
||||
return scopes;
|
||||
}
|
||||
|
||||
// Otherwise, we just need the symbol reference for 'symbol' that will be
|
||||
// used within 'limit'. This is the last reference in the list we computed
|
||||
// above if we were able to collect all references.
|
||||
if (!collectedAllReferences)
|
||||
return WalkResult::advance();
|
||||
return callback(references.back(), limit);
|
||||
return {};
|
||||
return {{references.back(), limit}};
|
||||
}
|
||||
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
|
||||
Region *limit) {
|
||||
auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
|
||||
|
||||
/// Walk the symbol scopes defined by 'limit' invoking the provided callback.
|
||||
static Optional<WalkResult> walkSymbolScopes(
|
||||
StringRef symbol, Operation *limit,
|
||||
function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) {
|
||||
return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit);
|
||||
// If we collected some scopes to walk, make sure to constrain the one for
|
||||
// limit to the specific region requested.
|
||||
if (!scopes.empty())
|
||||
scopes.back().limit = limit;
|
||||
return scopes;
|
||||
}
|
||||
template <typename IRUnit>
|
||||
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
|
||||
IRUnit *limit) {
|
||||
return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}};
|
||||
}
|
||||
|
||||
/// Returns true if the given reference 'SubRef' is a sub reference of the
|
||||
|
@ -539,6 +590,18 @@ static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
// SymbolTable::getSymbolUses
|
||||
|
||||
/// The implementation of SymbolTable::getSymbolUses below.
|
||||
template <typename FromT>
|
||||
static Optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
|
||||
std::vector<SymbolTable::SymbolUse> uses;
|
||||
auto walkFn = [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
|
||||
uses.push_back(symbolUse);
|
||||
return WalkResult::advance();
|
||||
};
|
||||
auto result = walkSymbolUses(from, walkFn);
|
||||
return result ? Optional<SymbolTable::UseRange>(std::move(uses)) : llvm::None;
|
||||
}
|
||||
|
||||
/// Get an iterator range for all of the uses, for any symbol, that are nested
|
||||
/// within the given operation 'from'. This does not traverse into any nested
|
||||
/// symbol tables, and will also only return uses on 'from' if it does not
|
||||
|
@ -547,43 +610,34 @@ static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
|
|||
/// None if there are any unknown operations that may potentially be symbol
|
||||
/// tables.
|
||||
auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> {
|
||||
std::vector<SymbolUse> uses;
|
||||
auto walkFn = [&](SymbolUse symbolUse, ArrayRef<int>) {
|
||||
uses.push_back(symbolUse);
|
||||
return WalkResult::advance();
|
||||
};
|
||||
auto result = walkSymbolUses(from, walkFn);
|
||||
return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>();
|
||||
return getSymbolUsesImpl(from);
|
||||
}
|
||||
auto SymbolTable::getSymbolUses(Region *from) -> Optional<UseRange> {
|
||||
return getSymbolUsesImpl(MutableArrayRef<Region>(*from));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolTable::getSymbolUses
|
||||
|
||||
/// The implementation of SymbolTable::getSymbolUses below.
|
||||
template <typename SymbolT>
|
||||
template <typename SymbolT, typename IRUnitT>
|
||||
static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
|
||||
Operation *limit) {
|
||||
IRUnitT *limit) {
|
||||
std::vector<SymbolTable::SymbolUse> uses;
|
||||
auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) {
|
||||
return walkSymbolUses(
|
||||
from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
|
||||
if (isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef()))
|
||||
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
|
||||
if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
|
||||
if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
|
||||
uses.push_back(symbolUse);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
};
|
||||
if (walkSymbolScopes(symbol, limit, walkFn))
|
||||
return SymbolTable::UseRange(std::move(uses));
|
||||
return llvm::None;
|
||||
}))
|
||||
return llvm::None;
|
||||
}
|
||||
return SymbolTable::UseRange(std::move(uses));
|
||||
}
|
||||
|
||||
/// Get all of the uses of the given symbol that are nested within the given
|
||||
/// operation 'from', invoking the provided callback for each. This does not
|
||||
/// traverse into any nested symbol tables, and will also only return uses on
|
||||
/// 'from' if it does not also define a symbol table. This is because we treat
|
||||
/// the region as the boundary of the symbol table, and not the op itself. This
|
||||
/// function returns None if there are any unknown operations that may
|
||||
/// potentially be symbol tables.
|
||||
/// traverse into any nested symbol tables. This function returns None if there
|
||||
/// are any unknown operations that may potentially be symbol tables.
|
||||
auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from)
|
||||
-> Optional<UseRange> {
|
||||
return getSymbolUsesImpl(symbol, from);
|
||||
|
@ -592,37 +646,49 @@ auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
|
|||
-> Optional<UseRange> {
|
||||
return getSymbolUsesImpl(symbol, from);
|
||||
}
|
||||
auto SymbolTable::getSymbolUses(StringRef symbol, Region *from)
|
||||
-> Optional<UseRange> {
|
||||
return getSymbolUsesImpl(symbol, from);
|
||||
}
|
||||
auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
|
||||
-> Optional<UseRange> {
|
||||
return getSymbolUsesImpl(symbol, from);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolTable::symbolKnownUseEmpty
|
||||
|
||||
/// The implementation of SymbolTable::symbolKnownUseEmpty below.
|
||||
template <typename SymbolT>
|
||||
static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit) {
|
||||
// Walk all of the symbol uses looking for a reference to 'symbol'.
|
||||
auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) {
|
||||
return walkSymbolUses(
|
||||
from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
|
||||
return isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef())
|
||||
template <typename SymbolT, typename IRUnitT>
|
||||
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
|
||||
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
|
||||
// Walk all of the symbol uses looking for a reference to 'symbol'.
|
||||
if (scope.walk([&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
|
||||
return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
|
||||
? WalkResult::interrupt()
|
||||
: WalkResult::advance();
|
||||
});
|
||||
};
|
||||
return walkSymbolScopes(symbol, limit, walkFn) == WalkResult::advance();
|
||||
}) != WalkResult::advance())
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Return if the given symbol is known to have no uses that are nested within
|
||||
/// the given operation 'from'. This does not traverse into any nested symbol
|
||||
/// tables, and will also only count uses on 'from' if it does not also define
|
||||
/// a symbol table. This is because we treat the region as the boundary of the
|
||||
/// symbol table, and not the op itself. This function will also return false if
|
||||
/// there are any unknown operations that may potentially be symbol tables.
|
||||
/// tables. This function will also return false if there are any unknown
|
||||
/// operations that may potentially be symbol tables.
|
||||
bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) {
|
||||
return symbolKnownUseEmptyImpl(symbol, from);
|
||||
}
|
||||
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
|
||||
return symbolKnownUseEmptyImpl(symbol, from);
|
||||
}
|
||||
bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Region *from) {
|
||||
return symbolKnownUseEmptyImpl(symbol, from);
|
||||
}
|
||||
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
|
||||
return symbolKnownUseEmptyImpl(symbol, from);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolTable::replaceAllSymbolUses
|
||||
|
@ -685,10 +751,9 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
|
|||
}
|
||||
|
||||
/// The implementation of SymbolTable::replaceAllSymbolUses below.
|
||||
template <typename SymbolT>
|
||||
static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
|
||||
StringRef newSymbol,
|
||||
Operation *limit) {
|
||||
template <typename SymbolT, typename IRUnitT>
|
||||
static LogicalResult
|
||||
replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
|
||||
// A collection of operations along with their new attribute dictionary.
|
||||
std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts;
|
||||
|
||||
|
@ -710,26 +775,26 @@ static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
|
|||
// Generate a new attribute to replace the given attribute.
|
||||
MLIRContext *ctx = limit->getContext();
|
||||
FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
|
||||
auto scopeWalkFn = [&](SymbolRefAttr oldAttr,
|
||||
Operation *from) -> Optional<WalkResult> {
|
||||
SymbolRefAttr newAttr = generateNewRefAttr(oldAttr, newLeafAttr);
|
||||
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
|
||||
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
|
||||
auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
|
||||
ArrayRef<int> accessChain) {
|
||||
SymbolRefAttr useRef = symbolUse.getSymbolRef();
|
||||
if (!isReferencePrefixOf(oldAttr, useRef))
|
||||
if (!isReferencePrefixOf(scope.symbol, useRef))
|
||||
return WalkResult::advance();
|
||||
|
||||
// If we have a valid match, check to see if this is a proper
|
||||
// subreference. If it is, then we will need to generate a different new
|
||||
// attribute specifically for this use.
|
||||
SymbolRefAttr replacementRef = newAttr;
|
||||
if (useRef != oldAttr) {
|
||||
if (oldAttr.isa<FlatSymbolRefAttr>()) {
|
||||
if (useRef != scope.symbol) {
|
||||
if (scope.symbol.isa<FlatSymbolRefAttr>()) {
|
||||
replacementRef =
|
||||
SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
|
||||
} else {
|
||||
auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
|
||||
nestedRefs[oldAttr.getNestedReferences().size() - 1] = newLeafAttr;
|
||||
nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
|
||||
newLeafAttr;
|
||||
replacementRef =
|
||||
SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
|
||||
}
|
||||
|
@ -748,18 +813,15 @@ static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
|
|||
accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef});
|
||||
return WalkResult::advance();
|
||||
};
|
||||
if (!walkSymbolUses(from, walkFn))
|
||||
return llvm::None;
|
||||
if (!scope.walk(walkFn))
|
||||
return failure();
|
||||
|
||||
// Check to see if we have a dangling op that needs to be processed.
|
||||
if (curOp) {
|
||||
updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
|
||||
curOp = nullptr;
|
||||
}
|
||||
return WalkResult::advance();
|
||||
};
|
||||
if (!walkSymbolScopes(symbol, limit, scopeWalkFn))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Update the attribute dictionaries as necessary.
|
||||
for (auto &it : updatedAttrDicts)
|
||||
|
@ -769,11 +831,9 @@ static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
|
|||
|
||||
/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
|
||||
/// provided symbol 'newSymbol' that are nested within the given operation
|
||||
/// 'from'. This does not traverse into any nested symbol tables, and will
|
||||
/// also only replace uses on 'from' if it does not also define a symbol
|
||||
/// table. This is because we treat the region as the boundary of the symbol
|
||||
/// table, and not the op itself. If there are any unknown operations that may
|
||||
/// potentially be symbol tables, no uses are replaced and failure is returned.
|
||||
/// 'from'. This does not traverse into any nested symbol tables. If there are
|
||||
/// any unknown operations that may potentially be symbol tables, no uses are
|
||||
/// replaced and failure is returned.
|
||||
LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
|
||||
StringRef newSymbol,
|
||||
Operation *from) {
|
||||
|
@ -784,3 +844,13 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
|
|||
Operation *from) {
|
||||
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
|
||||
}
|
||||
LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
|
||||
StringRef newSymbol,
|
||||
Region *from) {
|
||||
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
|
||||
}
|
||||
LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
|
||||
StringRef newSymbol,
|
||||
Region *from) {
|
||||
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace {
|
|||
/// This is a symbol test pass that tests the symbol uselist functionality
|
||||
/// provided by the symbol table along with erasing from the symbol table.
|
||||
struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
||||
WalkResult operateOnSymbol(Operation *symbol, Operation *module,
|
||||
WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
|
||||
SmallVectorImpl<FuncOp> &deadFunctions) {
|
||||
// Test computing uses on a non symboltable op.
|
||||
Optional<SymbolTable::UseRange> symbolUses =
|
||||
|
@ -34,7 +34,7 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
|||
<< " nested references";
|
||||
|
||||
// Test the functionality of symbolKnownUseEmpty.
|
||||
if (SymbolTable::symbolKnownUseEmpty(symbol, module)) {
|
||||
if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) {
|
||||
FuncOp funcSymbol = dyn_cast<FuncOp>(symbol);
|
||||
if (funcSymbol && funcSymbol.isExternal())
|
||||
deadFunctions.push_back(funcSymbol);
|
||||
|
@ -44,7 +44,7 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
|||
}
|
||||
|
||||
// Test the functionality of getSymbolUses.
|
||||
symbolUses = SymbolTable::getSymbolUses(symbol, module);
|
||||
symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion());
|
||||
assert(symbolUses.hasValue() && "expected no unknown operations");
|
||||
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
|
||||
// Check that we can resolve back to our symbol.
|
||||
|
@ -70,10 +70,10 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
|||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
SymbolTable table(module);
|
||||
for (Operation *op : deadFunctions) {
|
||||
// In order to test the SymbolTable::erase method, also erase completely
|
||||
// useless functions.
|
||||
SymbolTable table(module);
|
||||
auto name = SymbolTable::getSymbolName(op);
|
||||
assert(table.lookup(name) && "expected no unknown operations");
|
||||
table.erase(op);
|
||||
|
@ -96,7 +96,7 @@ struct SymbolReplacementPass : public ModulePass<SymbolReplacementPass> {
|
|||
if (!newName)
|
||||
return;
|
||||
if (succeeded(SymbolTable::replaceAllSymbolUses(
|
||||
nestedOp, newName.getValue(), module)))
|
||||
nestedOp, newName.getValue(), &module.getBodyRegion())))
|
||||
SymbolTable::setSymbolName(nestedOp, newName.getValue());
|
||||
});
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue