Add support for walking the uses of a symbol.

MLIR uses symbol references to model references to many global entities, such as functions/variables/etc. Before this change, there is no way to actually reason about the uses of such entities. This change provides a walker for symbol references(via SymbolTable::walkSymbolUses), as well as 'use_empty' support(via SymbolTable::symbol_use_empty). It also resolves some deficiencies in the LangRef definition of SymbolRefAttr, namely the restrictions on where a SymbolRefAttr can be stored, ArrayAttr and DictionaryAttr, and the relationship with operations containing the SymbolTable trait.

PiperOrigin-RevId: 273549331
This commit is contained in:
River Riddle 2019-10-08 10:21:26 -07:00 committed by A. Unique TensorFlower
parent 0dd404e4e1
commit ac91e67375
8 changed files with 321 additions and 3 deletions

View File

@ -804,7 +804,7 @@ memref<16x32xf32, #identity, memspace0>
// The memref index space is of size %M x %N, while %B1 and %B2 bind to the
// symbols s0, s1 respectively of the layout map #tiled_dynamic. Data tiles of
// size %B1 x %B2 in the logical space will be stored contiguously in memory.
// The allocation size will be (%M ceildiv %B1) * %B1 * (%N ceildiv %B2) * %B2
// The allocation size will be (%M ceildiv %B1) * %B1 * (%N ceildiv %B2) * %B2
// f32 elements.
%T = alloc(%M, %N) [%B1, %B2] : memref<?x?xf32, #tiled_dynamic>
@ -860,7 +860,6 @@ integral. In addition, an index map must specify the size of each of its range
dimensions onto which it maps. Index map symbols must be listed in order with
symbols for dynamic dimension sizes first, followed by other required symbols.
##### Layout Map
A layout map is a [semi-affine map](Dialects/Affine.md#semi-affine-maps) which
@ -1360,7 +1359,22 @@ symbol-ref-attribute ::= symbol-ref-id
```
A symbol reference attribute is a literal attribute that represents a named
reference to a given operation.
reference to an operation that is nested within an operation with the
`OpTrait::SymbolTable` trait. As such, this reference is given meaning by the
nearest parent operation containing the `OpTrait::SymbolTable` trait.
This attribute can only be held internally by
[array attributes](#array-attribute) and
[dictionary attributes](#dictionary-attribute)(including the top-level operation
attribute dictionary), i.e. no other attribute kinds such as Locations or
extended attribute kinds. If a reference to a symbol is necessary from outside
of the symbol table that the symbol is defined in, a
[string attribute](string-attribute) can be used to refer to the symbol name.
**Rationale:** Given that MLIR models global accesses with symbol references, to
enable efficient multi-threading, it becomes difficult to effectively reason
about their uses. By restricting the places that can legally hold a symbol
reference, we can always opaquely reason about a symbols usage characteristics.
#### Type Attribute

View File

@ -53,6 +53,10 @@ public:
/// Return the name of the attribute used for symbol names.
static StringRef getSymbolAttrName() { return "sym_name"; }
//===--------------------------------------------------------------------===//
// Symbol Utilities
//===--------------------------------------------------------------------===//
/// Returns the operation registered with the given symbol name with the
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
/// with the 'OpTrait::SymbolTable' trait.
@ -64,6 +68,47 @@ public:
/// found.
static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
/// This class represents a specific symbol use.
class SymbolUse {
public:
SymbolUse(Operation *op, SymbolRefAttr symbolRef)
: owner(op), symbolRef(symbolRef) {}
/// Return the operation user of this symbol reference.
Operation *getUser() const { return owner; }
/// Return the symbol reference that this use represents.
SymbolRefAttr getSymbolRef() const { return symbolRef; }
private:
/// The operation that this access is held by.
Operation *owner;
/// The symbol reference that this use represents.
SymbolRefAttr symbolRef;
};
/// Walk all of the uses, for any symbol, that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
/// 'from' if it does not also define a symbol table.
static WalkResult
walkSymbolUses(Operation *from, function_ref<WalkResult(SymbolUse)> callback);
/// Walk all of the uses of the given symbol that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
/// 'from' if it does not also define a symbol table.
static WalkResult
walkSymbolUses(StringRef symbol, Operation *from,
function_ref<WalkResult(SymbolUse)> callback);
/// Return if the given symbol has no uses that are nested within the given
/// operation 'from'. This does not traverse into any nested symbol tables,
/// and will also only count uses on 'from' if it does not also define a
/// symbol table.
static bool symbol_use_empty(StringRef symbol, Operation *from);
private:
MLIRContext *context;

View File

@ -142,3 +142,160 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
}
return success();
}
//===----------------------------------------------------------------------===//
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//
/// A utility result for walking a nested attribute for symbol uses.
enum HandlerResult {
/// The walk of the containter can continue.
Continue = 0,
/// The walk should recurse into the given attribute, as it is a container.
RecurseNestedAttribute,
/// The walk should end immediately, as an interrupt has been signaled.
Interrupt
};
/// Utility function used to handle a nested attribute during a walk of symbol
/// uses. It returns the above HandlerResult signaling the next action for the
/// walk.
HandlerResult handleAttrDuringSymbolWalk(
Operation *op, Attribute attr,
SmallVectorImpl<std::pair<Attribute, unsigned>> &worklist,
function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
switch (attr.getKind()) {
/// Check for a nested container attribute, these will also need to be
/// walked.
case StandardAttributes::Array:
case StandardAttributes::Dictionary: {
worklist.push_back({attr, /*index*/ 0});
return HandlerResult::RecurseNestedAttribute;
}
// Invoke the provided callback if we find a symbol use and check for a
// requested interrupt.
case StandardAttributes::SymbolRef: {
SymbolTable::SymbolUse use(op, attr.cast<SymbolRefAttr>());
return callback(use).wasInterrupted() ? HandlerResult::Interrupt
: HandlerResult::Continue;
}
default:
return HandlerResult::Continue;
}
}
/// Walk all of the symbol references within the given operation, invoking the
/// provided callback for each found use.
static WalkResult
walkSymbolRefs(Operation *op,
function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
// Check to see if the operation has any attributes.
DictionaryAttr attrDict = op->getAttrList().getDictionary();
if (!attrDict)
return WalkResult::advance();
// A worklist of a container attribute and the current index into the held
// attribute list.
SmallVector<std::pair<Attribute, unsigned>, 1> worklist;
worklist.push_back({attrDict, /*index*/ 0});
while (!worklist.empty()) {
Attribute attr = worklist.back().first;
unsigned &index = worklist.back().second;
// Iterate over the given attribute, which is guaranteed to be a container.
HandlerResult handlerResult = HandlerResult::Continue;
if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
ArrayRef<Attribute> attrs = arrayAttr.getValue();
unsigned attrSize = attrs.size();
while (index != attrSize)
if ((handlerResult = handleAttrDuringSymbolWalk(op, attrs[index++],
worklist, callback)))
break;
} else {
auto dictAttr = attr.cast<DictionaryAttr>();
ArrayRef<NamedAttribute> attrs = dictAttr.getValue();
unsigned attrSize = attrs.size();
while (index != attrSize)
if ((handlerResult = handleAttrDuringSymbolWalk(
op, attrs[index++].second, worklist, callback)))
break;
}
if (handlerResult == HandlerResult::Interrupt)
return WalkResult::interrupt();
// If we didn't encounter a nested attribute, pop the last item from the
// worklist.
if (handlerResult != HandlerResult::RecurseNestedAttribute)
worklist.pop_back();
}
return WalkResult::advance();
}
/// Walk all of the uses, for any symbol, that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
/// 'from' if it does not also define a symbol table.
WalkResult
SymbolTable::walkSymbolUses(Operation *from,
function_ref<WalkResult(SymbolUse)> callback) {
// If from is not a symbol table, check for uses. A symbol table defines a new
// scope, so we can't walk the attributes from the symbol table op.
if (!from->hasTrait<OpTrait::SymbolTable>()) {
if (walkSymbolRefs(from, callback).wasInterrupted())
return WalkResult::interrupt();
}
SmallVector<Region *, 1> worklist;
worklist.reserve(from->getNumRegions());
for (Region &region : from->getRegions())
worklist.push_back(&region);
while (!worklist.empty()) {
Region *region = worklist.pop_back_val();
for (Block &block : *region) {
for (Operation &op : block) {
if (walkSymbolRefs(&op, callback).wasInterrupted())
return WalkResult::interrupt();
// If this op defines a new symbol table scope, we can't traverse. Any
// symbol references nested within 'op' are different semantically.
if (!op.hasTrait<OpTrait::SymbolTable>()) {
for (Region &region : op.getRegions())
worklist.push_back(&region);
}
}
}
}
return WalkResult::advance();
}
/// Walk all of the uses, for any symbol, that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
/// 'from' if it does not also define a symbol table.
WalkResult
SymbolTable::walkSymbolUses(StringRef symbol, Operation *from,
function_ref<WalkResult(SymbolUse)> callback) {
SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
return walkSymbolUses(from, [&](SymbolUse symbolUse) {
if (symbolUse.getSymbolRef() != symbolRefAttr)
return WalkResult::advance();
return callback(std::move(symbolUse));
});
}
/// Return if the given symbol has no uses that are nested within the given
/// operation 'from'. This does not traverse into any nested symbol tables,
/// and will also only count uses on 'from' if it does not also define a
/// symbol table.
bool SymbolTable::symbol_use_empty(StringRef symbol, Operation *from) {
SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
// Walk all of the symbol uses looking for a reference to 'symbol'.
auto walkResult = walkSymbolUses(from, [&](SymbolUse symbolUse) {
return symbolUse.getSymbolRef() == symbolRefAttr ? WalkResult::interrupt()
: WalkResult::advance();
});
return !walkResult.wasInterrupted();
}

View File

@ -0,0 +1,29 @@
// RUN: mlir-opt %s -test-symbol-uses -verify-diagnostics
// Symbol references to the module itself don't affect uses of symbols within
// its table.
module attributes {sym.outside_use = @symbol_foo } {
// expected-remark@+1 {{function has 2 uses}}
func @symbol_foo()
// expected-remark@+3 {{function has no uses}}
// expected-remark@+2 {{found use of function : @symbol_foo}}
// expected-remark@+1 {{function contains 2 nested references}}
func @symbol_bar() attributes {sym.use = @symbol_foo} {
// expected-remark@+1 {{found use of function : @symbol_foo}}
"foo.op"() {
non_symbol_attr,
use = [{ nested_symbol = [@symbol_foo]}],
z_other_non_symbol_attr
} : () -> ()
}
// expected-remark@+1 {{function has 1 use}}
func @symbol_baz()
// expected-remark@+1 {{found use of function : @symbol_baz}}
module attributes {test.reference = @symbol_baz} {
"foo.op"() {test.nested_reference = @symbol_baz} : () -> ()
}
}

View File

@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(TestDialect)
add_subdirectory(Transforms)

View File

@ -0,0 +1,8 @@
add_llvm_library(MLIRTestIR
TestSymbolUses.cpp
ADDITIONAL_HEADER_DIRS
)
target_link_libraries(MLIRTestIR
MLIRPass
)

View File

@ -0,0 +1,63 @@
//===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// This is a symbol test pass that tests the symbol uselist functionality
/// provided by the symbol table.
struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
void runOnModule() override {
auto module = getModule();
for (FuncOp func : module.getOps<FuncOp>()) {
// Test computing uses on a non symboltable op.
unsigned numUses = 0;
SymbolTable::walkSymbolUses(func, [&](SymbolTable::SymbolUse) {
++numUses;
return WalkResult::advance();
});
if (numUses != 0)
func.emitRemark() << "function contains " << numUses
<< " nested references";
// Test the functionality of symbol_use_empty.
if (SymbolTable::symbol_use_empty(func.getName(), module)) {
func.emitRemark() << "function has no uses";
continue;
}
// Test the functionality of walkSymbolUses.
numUses = 0;
SymbolTable::walkSymbolUses(
func.getName(), module, [&](SymbolTable::SymbolUse symbolUse) {
symbolUse.getUser()->emitRemark()
<< "found use of function : " << symbolUse.getSymbolRef();
++numUses;
return WalkResult::advance();
});
func.emitRemark() << "function has " << numUses << " uses";
}
}
};
} // end anonymous namespace
static PassRegistration<SymbolUsesPass> pass("test-symbol-uses",
"Test detection of symbol uses");

View File

@ -42,6 +42,7 @@ set(LIBS
MLIRStandardToLLVM
MLIRTransforms
MLIRTestDialect
MLIRTestIR
MLIRTestPass
MLIRTestTransforms
MLIRSupport