forked from OSchip/llvm-project
Add support for walking the uses of a symbol.
MLIR uses symbol references to model references to many global entities, such as functions/variables/etc. Before this change, there is no way to actually reason about the uses of such entities. This change provides a walker for symbol references(via SymbolTable::walkSymbolUses), as well as 'use_empty' support(via SymbolTable::symbol_use_empty). It also resolves some deficiencies in the LangRef definition of SymbolRefAttr, namely the restrictions on where a SymbolRefAttr can be stored, ArrayAttr and DictionaryAttr, and the relationship with operations containing the SymbolTable trait. PiperOrigin-RevId: 273549331
This commit is contained in:
parent
0dd404e4e1
commit
ac91e67375
|
@ -804,7 +804,7 @@ memref<16x32xf32, #identity, memspace0>
|
|||
// The memref index space is of size %M x %N, while %B1 and %B2 bind to the
|
||||
// symbols s0, s1 respectively of the layout map #tiled_dynamic. Data tiles of
|
||||
// size %B1 x %B2 in the logical space will be stored contiguously in memory.
|
||||
// The allocation size will be (%M ceildiv %B1) * %B1 * (%N ceildiv %B2) * %B2
|
||||
// The allocation size will be (%M ceildiv %B1) * %B1 * (%N ceildiv %B2) * %B2
|
||||
// f32 elements.
|
||||
%T = alloc(%M, %N) [%B1, %B2] : memref<?x?xf32, #tiled_dynamic>
|
||||
|
||||
|
@ -860,7 +860,6 @@ integral. In addition, an index map must specify the size of each of its range
|
|||
dimensions onto which it maps. Index map symbols must be listed in order with
|
||||
symbols for dynamic dimension sizes first, followed by other required symbols.
|
||||
|
||||
|
||||
##### Layout Map
|
||||
|
||||
A layout map is a [semi-affine map](Dialects/Affine.md#semi-affine-maps) which
|
||||
|
@ -1360,7 +1359,22 @@ symbol-ref-attribute ::= symbol-ref-id
|
|||
```
|
||||
|
||||
A symbol reference attribute is a literal attribute that represents a named
|
||||
reference to a given operation.
|
||||
reference to an operation that is nested within an operation with the
|
||||
`OpTrait::SymbolTable` trait. As such, this reference is given meaning by the
|
||||
nearest parent operation containing the `OpTrait::SymbolTable` trait.
|
||||
|
||||
This attribute can only be held internally by
|
||||
[array attributes](#array-attribute) and
|
||||
[dictionary attributes](#dictionary-attribute)(including the top-level operation
|
||||
attribute dictionary), i.e. no other attribute kinds such as Locations or
|
||||
extended attribute kinds. If a reference to a symbol is necessary from outside
|
||||
of the symbol table that the symbol is defined in, a
|
||||
[string attribute](string-attribute) can be used to refer to the symbol name.
|
||||
|
||||
**Rationale:** Given that MLIR models global accesses with symbol references, to
|
||||
enable efficient multi-threading, it becomes difficult to effectively reason
|
||||
about their uses. By restricting the places that can legally hold a symbol
|
||||
reference, we can always opaquely reason about a symbols usage characteristics.
|
||||
|
||||
#### Type Attribute
|
||||
|
||||
|
|
|
@ -53,6 +53,10 @@ public:
|
|||
/// Return the name of the attribute used for symbol names.
|
||||
static StringRef getSymbolAttrName() { return "sym_name"; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Symbol Utilities
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the operation registered with the given symbol name with the
|
||||
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
|
||||
/// with the 'OpTrait::SymbolTable' trait.
|
||||
|
@ -64,6 +68,47 @@ public:
|
|||
/// found.
|
||||
static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
|
||||
|
||||
/// This class represents a specific symbol use.
|
||||
class SymbolUse {
|
||||
public:
|
||||
SymbolUse(Operation *op, SymbolRefAttr symbolRef)
|
||||
: owner(op), symbolRef(symbolRef) {}
|
||||
|
||||
/// Return the operation user of this symbol reference.
|
||||
Operation *getUser() const { return owner; }
|
||||
|
||||
/// Return the symbol reference that this use represents.
|
||||
SymbolRefAttr getSymbolRef() const { return symbolRef; }
|
||||
|
||||
private:
|
||||
/// The operation that this access is held by.
|
||||
Operation *owner;
|
||||
|
||||
/// The symbol reference that this use represents.
|
||||
SymbolRefAttr symbolRef;
|
||||
};
|
||||
|
||||
/// Walk all of the uses, for any symbol, that are nested within the given
|
||||
/// operation 'from', invoking the provided callback for each. This does not
|
||||
/// traverse into any nested symbol tables, and will also only return uses on
|
||||
/// 'from' if it does not also define a symbol table.
|
||||
static WalkResult
|
||||
walkSymbolUses(Operation *from, function_ref<WalkResult(SymbolUse)> callback);
|
||||
|
||||
/// Walk all of the uses of the given symbol that are nested within the given
|
||||
/// operation 'from', invoking the provided callback for each. This does not
|
||||
/// traverse into any nested symbol tables, and will also only return uses on
|
||||
/// 'from' if it does not also define a symbol table.
|
||||
static WalkResult
|
||||
walkSymbolUses(StringRef symbol, Operation *from,
|
||||
function_ref<WalkResult(SymbolUse)> callback);
|
||||
|
||||
/// Return if the given symbol has no uses that are nested within the given
|
||||
/// operation 'from'. This does not traverse into any nested symbol tables,
|
||||
/// and will also only count uses on 'from' if it does not also define a
|
||||
/// symbol table.
|
||||
static bool symbol_use_empty(StringRef symbol, Operation *from);
|
||||
|
||||
private:
|
||||
MLIRContext *context;
|
||||
|
||||
|
|
|
@ -142,3 +142,160 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
|
|||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolTable Trait Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// A utility result for walking a nested attribute for symbol uses.
|
||||
enum HandlerResult {
|
||||
/// The walk of the containter can continue.
|
||||
Continue = 0,
|
||||
/// The walk should recurse into the given attribute, as it is a container.
|
||||
RecurseNestedAttribute,
|
||||
/// The walk should end immediately, as an interrupt has been signaled.
|
||||
Interrupt
|
||||
};
|
||||
|
||||
/// Utility function used to handle a nested attribute during a walk of symbol
|
||||
/// uses. It returns the above HandlerResult signaling the next action for the
|
||||
/// walk.
|
||||
HandlerResult handleAttrDuringSymbolWalk(
|
||||
Operation *op, Attribute attr,
|
||||
SmallVectorImpl<std::pair<Attribute, unsigned>> &worklist,
|
||||
function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
|
||||
switch (attr.getKind()) {
|
||||
/// Check for a nested container attribute, these will also need to be
|
||||
/// walked.
|
||||
case StandardAttributes::Array:
|
||||
case StandardAttributes::Dictionary: {
|
||||
worklist.push_back({attr, /*index*/ 0});
|
||||
return HandlerResult::RecurseNestedAttribute;
|
||||
}
|
||||
|
||||
// Invoke the provided callback if we find a symbol use and check for a
|
||||
// requested interrupt.
|
||||
case StandardAttributes::SymbolRef: {
|
||||
SymbolTable::SymbolUse use(op, attr.cast<SymbolRefAttr>());
|
||||
return callback(use).wasInterrupted() ? HandlerResult::Interrupt
|
||||
: HandlerResult::Continue;
|
||||
}
|
||||
default:
|
||||
return HandlerResult::Continue;
|
||||
}
|
||||
}
|
||||
|
||||
/// Walk all of the symbol references within the given operation, invoking the
|
||||
/// provided callback for each found use.
|
||||
static WalkResult
|
||||
walkSymbolRefs(Operation *op,
|
||||
function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
|
||||
// Check to see if the operation has any attributes.
|
||||
DictionaryAttr attrDict = op->getAttrList().getDictionary();
|
||||
if (!attrDict)
|
||||
return WalkResult::advance();
|
||||
|
||||
// A worklist of a container attribute and the current index into the held
|
||||
// attribute list.
|
||||
SmallVector<std::pair<Attribute, unsigned>, 1> worklist;
|
||||
worklist.push_back({attrDict, /*index*/ 0});
|
||||
while (!worklist.empty()) {
|
||||
Attribute attr = worklist.back().first;
|
||||
unsigned &index = worklist.back().second;
|
||||
|
||||
// Iterate over the given attribute, which is guaranteed to be a container.
|
||||
HandlerResult handlerResult = HandlerResult::Continue;
|
||||
if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
|
||||
ArrayRef<Attribute> attrs = arrayAttr.getValue();
|
||||
unsigned attrSize = attrs.size();
|
||||
while (index != attrSize)
|
||||
if ((handlerResult = handleAttrDuringSymbolWalk(op, attrs[index++],
|
||||
worklist, callback)))
|
||||
break;
|
||||
} else {
|
||||
auto dictAttr = attr.cast<DictionaryAttr>();
|
||||
ArrayRef<NamedAttribute> attrs = dictAttr.getValue();
|
||||
unsigned attrSize = attrs.size();
|
||||
while (index != attrSize)
|
||||
if ((handlerResult = handleAttrDuringSymbolWalk(
|
||||
op, attrs[index++].second, worklist, callback)))
|
||||
break;
|
||||
}
|
||||
if (handlerResult == HandlerResult::Interrupt)
|
||||
return WalkResult::interrupt();
|
||||
|
||||
// If we didn't encounter a nested attribute, pop the last item from the
|
||||
// worklist.
|
||||
if (handlerResult != HandlerResult::RecurseNestedAttribute)
|
||||
worklist.pop_back();
|
||||
}
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
/// Walk all of the uses, for any symbol, that are nested within the given
|
||||
/// operation 'from', invoking the provided callback for each. This does not
|
||||
/// traverse into any nested symbol tables, and will also only return uses on
|
||||
/// 'from' if it does not also define a symbol table.
|
||||
WalkResult
|
||||
SymbolTable::walkSymbolUses(Operation *from,
|
||||
function_ref<WalkResult(SymbolUse)> callback) {
|
||||
// If from is not a symbol table, check for uses. A symbol table defines a new
|
||||
// scope, so we can't walk the attributes from the symbol table op.
|
||||
if (!from->hasTrait<OpTrait::SymbolTable>()) {
|
||||
if (walkSymbolRefs(from, callback).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
SmallVector<Region *, 1> worklist;
|
||||
worklist.reserve(from->getNumRegions());
|
||||
for (Region ®ion : from->getRegions())
|
||||
worklist.push_back(®ion);
|
||||
|
||||
while (!worklist.empty()) {
|
||||
Region *region = worklist.pop_back_val();
|
||||
for (Block &block : *region) {
|
||||
for (Operation &op : block) {
|
||||
if (walkSymbolRefs(&op, callback).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
// If this op defines a new symbol table scope, we can't traverse. Any
|
||||
// symbol references nested within 'op' are different semantically.
|
||||
if (!op.hasTrait<OpTrait::SymbolTable>()) {
|
||||
for (Region ®ion : op.getRegions())
|
||||
worklist.push_back(®ion);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
/// Walk all of the uses, for any symbol, that are nested within the given
|
||||
/// operation 'from', invoking the provided callback for each. This does not
|
||||
/// traverse into any nested symbol tables, and will also only return uses on
|
||||
/// 'from' if it does not also define a symbol table.
|
||||
WalkResult
|
||||
SymbolTable::walkSymbolUses(StringRef symbol, Operation *from,
|
||||
function_ref<WalkResult(SymbolUse)> callback) {
|
||||
SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
|
||||
return walkSymbolUses(from, [&](SymbolUse symbolUse) {
|
||||
if (symbolUse.getSymbolRef() != symbolRefAttr)
|
||||
return WalkResult::advance();
|
||||
return callback(std::move(symbolUse));
|
||||
});
|
||||
}
|
||||
|
||||
/// Return if the given symbol has no uses that are nested within the given
|
||||
/// operation 'from'. This does not traverse into any nested symbol tables,
|
||||
/// and will also only count uses on 'from' if it does not also define a
|
||||
/// symbol table.
|
||||
bool SymbolTable::symbol_use_empty(StringRef symbol, Operation *from) {
|
||||
SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
|
||||
|
||||
// Walk all of the symbol uses looking for a reference to 'symbol'.
|
||||
auto walkResult = walkSymbolUses(from, [&](SymbolUse symbolUse) {
|
||||
return symbolUse.getSymbolRef() == symbolRefAttr ? WalkResult::interrupt()
|
||||
: WalkResult::advance();
|
||||
});
|
||||
return !walkResult.wasInterrupted();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
// RUN: mlir-opt %s -test-symbol-uses -verify-diagnostics
|
||||
|
||||
|
||||
// Symbol references to the module itself don't affect uses of symbols within
|
||||
// its table.
|
||||
module attributes {sym.outside_use = @symbol_foo } {
|
||||
// expected-remark@+1 {{function has 2 uses}}
|
||||
func @symbol_foo()
|
||||
|
||||
// expected-remark@+3 {{function has no uses}}
|
||||
// expected-remark@+2 {{found use of function : @symbol_foo}}
|
||||
// expected-remark@+1 {{function contains 2 nested references}}
|
||||
func @symbol_bar() attributes {sym.use = @symbol_foo} {
|
||||
// expected-remark@+1 {{found use of function : @symbol_foo}}
|
||||
"foo.op"() {
|
||||
non_symbol_attr,
|
||||
use = [{ nested_symbol = [@symbol_foo]}],
|
||||
z_other_non_symbol_attr
|
||||
} : () -> ()
|
||||
}
|
||||
|
||||
// expected-remark@+1 {{function has 1 use}}
|
||||
func @symbol_baz()
|
||||
|
||||
// expected-remark@+1 {{found use of function : @symbol_baz}}
|
||||
module attributes {test.reference = @symbol_baz} {
|
||||
"foo.op"() {test.nested_reference = @symbol_baz} : () -> ()
|
||||
}
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Pass)
|
||||
add_subdirectory(TestDialect)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
add_llvm_library(MLIRTestIR
|
||||
TestSymbolUses.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
)
|
||||
target_link_libraries(MLIRTestIR
|
||||
MLIRPass
|
||||
)
|
|
@ -0,0 +1,63 @@
|
|||
//===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// This is a symbol test pass that tests the symbol uselist functionality
|
||||
/// provided by the symbol table.
|
||||
struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
// Test computing uses on a non symboltable op.
|
||||
unsigned numUses = 0;
|
||||
SymbolTable::walkSymbolUses(func, [&](SymbolTable::SymbolUse) {
|
||||
++numUses;
|
||||
return WalkResult::advance();
|
||||
});
|
||||
if (numUses != 0)
|
||||
func.emitRemark() << "function contains " << numUses
|
||||
<< " nested references";
|
||||
|
||||
// Test the functionality of symbol_use_empty.
|
||||
if (SymbolTable::symbol_use_empty(func.getName(), module)) {
|
||||
func.emitRemark() << "function has no uses";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Test the functionality of walkSymbolUses.
|
||||
numUses = 0;
|
||||
SymbolTable::walkSymbolUses(
|
||||
func.getName(), module, [&](SymbolTable::SymbolUse symbolUse) {
|
||||
symbolUse.getUser()->emitRemark()
|
||||
<< "found use of function : " << symbolUse.getSymbolRef();
|
||||
++numUses;
|
||||
return WalkResult::advance();
|
||||
});
|
||||
func.emitRemark() << "function has " << numUses << " uses";
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
static PassRegistration<SymbolUsesPass> pass("test-symbol-uses",
|
||||
"Test detection of symbol uses");
|
|
@ -42,6 +42,7 @@ set(LIBS
|
|||
MLIRStandardToLLVM
|
||||
MLIRTransforms
|
||||
MLIRTestDialect
|
||||
MLIRTestIR
|
||||
MLIRTestPass
|
||||
MLIRTestTransforms
|
||||
MLIRSupport
|
||||
|
|
Loading…
Reference in New Issue