Move ModuleManager functionality into mlir::SymbolTable.

Note for broken code, the following transformations occurred:
ModuleManager::insert(Block::iterator, Operation*) - > SymbolTable::insert(Operation*, Block::iterator)
ModuleManager::lookupSymbol -> SymbolTable::lookup
ModuleManager::getModule() -> SymbolTable::getOp()
ModuleManager::getContext() -> SymbolTable::getOp()->getContext()
ModuleManager::* -> SymbolTable::*
PiperOrigin-RevId: 283944635
This commit is contained in:
Tres Popp 2019-12-05 03:56:18 -08:00 committed by A. Unique TensorFlower
parent b60799b71b
commit b8cd0c1486
10 changed files with 107 additions and 99 deletions

View File

@ -153,7 +153,7 @@ struct PythonMLIRModule {
PythonMLIRModule()
: mlirContext(),
module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&mlirContext))),
moduleManager(*module) {}
symbolTable(*module) {}
PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) {
return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
@ -270,7 +270,7 @@ struct PythonMLIRModule {
}
PythonFunction getNamedFunction(const std::string &name) {
return moduleManager.lookupSymbol<FuncOp>(name);
return symbolTable.lookup<FuncOp>(name);
}
PythonFunctionContext
@ -282,7 +282,7 @@ private:
mlir::MLIRContext mlirContext;
// One single module in a python-exposed MLIRContext for now.
mlir::OwningModuleRef module;
mlir::ModuleManager moduleManager;
mlir::SymbolTable symbolTable;
// An execution engine and an associated target machine. The latter must
// outlive the former since it may be used by the transformation layers.
@ -692,7 +692,7 @@ PythonMLIRModule::declareFunction(const std::string &name,
UnknownLoc::get(&mlirContext), name,
mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
inputAttrs);
moduleManager.insert(func);
symbolTable.insert(func);
return func;
}

View File

@ -118,56 +118,6 @@ public:
static void build(Builder *, OperationState &) {}
};
//===----------------------------------------------------------------------===//
// Module Manager.
//===----------------------------------------------------------------------===//
/// A class used to manage the symbols held by a module. This class handles
/// ensures that symbols inserted into a module have a unique name, and provides
/// efficient named lookup to held symbols.
class ModuleManager {
public:
ModuleManager(ModuleOp module) : module(module), symbolTable(module) {}
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names must never include the @ on them.
template <typename T, typename NameTy> T lookupSymbol(NameTy &&name) const {
return symbolTable.lookup<T>(name);
}
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names must never include the @ on them.
template <typename NameTy> Operation *lookupSymbol(NameTy &&name) const {
return symbolTable.lookup(name);
}
/// Insert a new symbol into the module, auto-renaming it as necessary.
void insert(Operation *op) {
symbolTable.insert(op);
module.push_back(op);
}
void insert(Block::iterator insertPt, Operation *op) {
symbolTable.insert(op);
module.insert(insertPt, op);
}
/// Remove the given symbol from the module symbol table and then erase it.
void erase(Operation *op) {
symbolTable.erase(op);
op->erase();
}
/// Return the internally held module.
ModuleOp getModule() const { return module; }
/// Return the context of the internal module.
MLIRContext *getContext() { return module.getContext(); }
private:
ModuleOp module;
SymbolTable symbolTable;
};
/// This class acts as an owning reference to a module, and will automatically
/// destroy the held module if valid.
class OwningModuleRef {

View File

@ -23,15 +23,16 @@
namespace mlir {
class Identifier;
class MLIRContext;
class Operation;
/// This class allows for representing and managing the symbol table used by
/// operations with the 'SymbolTable' trait.
/// operations with the 'SymbolTable' trait. Inserting into and erasing from
/// this SymbolTable will also insert and erase from the Operation given to it
/// at construction.
class SymbolTable {
public:
/// Build a symbol table with the symbols within the given operation.
SymbolTable(Operation *op);
SymbolTable(Operation *symbolTableOp);
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
@ -44,15 +45,16 @@ public:
void erase(Operation *symbol);
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
void insert(Operation *symbol);
/// Returns the context held by this symbol table.
MLIRContext *getContext() const { return context; }
/// collisions. Also insert at the specified location in the body of the
/// associated operation.
void insert(Operation *symbol, Block::iterator insertPt = {});
/// Return the name of the attribute used for symbol names.
static StringRef getSymbolAttrName() { return "sym_name"; }
/// Returns the associated operation.
Operation *getOp() const { return symbolTableOp; }
//===--------------------------------------------------------------------===//
// Symbol Utilities
//===--------------------------------------------------------------------===//
@ -60,7 +62,7 @@ public:
/// Returns the operation registered with the given symbol name with the
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
/// with the 'OpTrait::SymbolTable' trait.
static Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol);
static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
/// Returns the operation registered with the given symbol name within the
/// closest parent operation of, or including, 'from' with the
@ -118,11 +120,11 @@ public:
/// are any unknown operations that may potentially be symbol tables.
static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
/// Return if the given symbol is known to have no uses that are nested within
/// the given operation 'from'. This does not traverse into any nested symbol
/// tables, and will also only count uses on 'from' if it does not also define
/// a symbol table. This is because we treat the region as the boundary of
/// the symbol table, and not the op itself. This function will also return
/// Return if the given symbol is known to have no uses that are nested
/// within the given operation 'from'. This does not traverse into any nested
/// symbol tables, and will also only count uses on 'from' if it does not also
/// define a symbol table. This is because we treat the region as the boundary
/// of the symbol table, and not the op itself. This function will also return
/// false if there are any unknown operations that may potentially be symbol
/// tables. This doesn't necessarily mean that there are no uses, we just
/// can't convervatively prove it.
@ -141,7 +143,7 @@ public:
Operation *from);
private:
MLIRContext *context;
Operation *symbolTableOp;
/// This is a mapping from a name to the symbol with that name.
llvm::StringMap<Operation *> symbolTable;

View File

@ -24,6 +24,7 @@
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
@ -155,7 +156,7 @@ namespace {
class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
public:
void runOnModule() override {
ModuleManager moduleManager(getModule());
SymbolTable symbolTable(getModule());
bool modified = false;
for (auto func : getModule().getOps<FuncOp>()) {
// Insert just after the function.
@ -166,8 +167,8 @@ public:
// Create nested module and insert outlinedFunc. The module will
// originally get the same name as the function, but may be renamed on
// insertion into the parent module.
auto kernelModule = createKernelModule(outlinedFunc, moduleManager);
moduleManager.insert(insertPt, kernelModule);
auto kernelModule = createKernelModule(outlinedFunc, symbolTable);
symbolTable.insert(kernelModule, insertPt);
// Potentially changes signature, pulling in constants.
convertToLaunchFuncOp(op, outlinedFunc);
@ -185,16 +186,15 @@ public:
private:
// Returns a module containing kernelFunc and all callees (recursive).
ModuleOp createKernelModule(FuncOp kernelFunc,
const ModuleManager &parentModuleManager) {
const SymbolTable &parentSymbolTable) {
auto context = getModule().getContext();
Builder builder(context);
auto kernelModule =
ModuleOp::create(builder.getUnknownLoc(), kernelFunc.getName());
kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
builder.getUnitAttr());
ModuleManager moduleManager(kernelModule);
moduleManager.insert(kernelFunc);
SymbolTable symbolTable(kernelModule);
symbolTable.insert(kernelFunc);
llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
while (!symbolDefWorklist.empty()) {
@ -203,13 +203,13 @@ private:
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
StringRef symbolName =
symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
if (moduleManager.lookupSymbol(symbolName))
if (symbolTable.lookup(symbolName))
continue;
Operation *symbolDefClone =
parentModuleManager.lookupSymbol(symbolName)->clone();
parentSymbolTable.lookup(symbolName)->clone();
symbolDefWorklist.push_back(symbolDefClone);
moduleManager.insert(symbolDefClone);
symbolTable.insert(symbolDefClone);
}
}
}

View File

@ -31,23 +31,24 @@ static bool isPotentiallyUnknownSymbolTable(Operation *op) {
//===----------------------------------------------------------------------===//
/// Build a symbol table with the symbols within the given operation.
SymbolTable::SymbolTable(Operation *op) : context(op->getContext()) {
assert(op->hasTrait<OpTrait::SymbolTable>() &&
SymbolTable::SymbolTable(Operation *symbolTableOp)
: symbolTableOp(symbolTableOp) {
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
"expected operation to have SymbolTable trait");
assert(op->getNumRegions() == 1 &&
assert(symbolTableOp->getNumRegions() == 1 &&
"expected operation to have a single region");
assert(has_single_element(symbolTableOp->getRegion(0)) &&
"expected operation to have a single block");
for (auto &block : op->getRegion(0)) {
for (auto &op : block) {
auto nameAttr = op.getAttrOfType<StringAttr>(getSymbolAttrName());
if (!nameAttr)
continue;
for (auto &op : symbolTableOp->getRegion(0).front()) {
auto nameAttr = op.getAttrOfType<StringAttr>(getSymbolAttrName());
if (!nameAttr)
continue;
auto inserted = symbolTable.insert({nameAttr.getValue(), &op});
(void)inserted;
assert(inserted.second &&
"expected region to contain uniquely named symbol operations");
}
auto inserted = symbolTable.insert({nameAttr.getValue(), &op});
(void)inserted;
assert(inserted.second &&
"expected region to contain uniquely named symbol operations");
}
}
@ -61,18 +62,32 @@ Operation *SymbolTable::lookup(StringRef name) const {
void SymbolTable::erase(Operation *symbol) {
auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
assert(nameAttr && "expected valid 'name' attribute");
assert(symbol->getParentOp() == symbolTableOp &&
"expected this operation to be inside of the operation with this "
"SymbolTable");
auto it = symbolTable.find(nameAttr.getValue());
if (it != symbolTable.end() && it->second == symbol)
if (it != symbolTable.end() && it->second == symbol) {
symbolTable.erase(it);
symbol->erase();
}
}
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
void SymbolTable::insert(Operation *symbol) {
/// Insert a new symbol into the table and associated operation, and rename it
/// as necessary to avoid collisions.
void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
assert(nameAttr && "expected valid 'name' attribute");
auto &body = symbolTableOp->getRegion(0).front();
if (insertPt == Block::iterator() || insertPt == body.end())
insertPt = Block::iterator(body.getTerminator());
assert(insertPt->getParentOp() == symbolTableOp &&
"expected insertPt to be in the associated module operation");
body.getOperations().insert(insertPt, symbol);
// Add this symbol to the symbol table, uniquing the name if a conflict is
// detected.
if (symbolTable.insert({nameAttr.getValue(), symbol}).second)
@ -89,7 +104,8 @@ void SymbolTable::insert(Operation *symbol) {
nameBuffer += '_';
nameBuffer += std::to_string(uniquingCounter++);
} while (!symbolTable.insert({nameBuffer, symbol}).second);
symbol->setAttr(getSymbolAttrName(), StringAttr::get(nameBuffer, context));
symbol->setAttr(getSymbolAttrName(),
StringAttr::get(nameBuffer, symbolTableOp->getContext()));
}
/// Returns the operation registered with the given symbol name with the
@ -136,6 +152,9 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
if (op->getNumRegions() != 1)
return op->emitOpError()
<< "Operations with a 'SymbolTable' must have exactly one region";
if (!has_single_element(op->getRegion(0)))
return op->emitOpError()
<< "Operations with a 'SymbolTable' must have exactly one block";
// Check that all symbols are uniquely named within child regions.
llvm::StringMap<Location> nameToOrigLoc;

View File

@ -3,7 +3,7 @@
// -----
func @module_op() {
// expected-error@+1 {{expects region #0 to have 0 or 1 blocks}}
// expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
module {
^bb1:
"module_terminator"() : () -> ()

View File

@ -2,6 +2,7 @@
// Symbol references to the module itself don't affect uses of symbols within
// its table.
// expected-remark@below {{symbol_removable function successfully erased}}
module attributes {sym.outside_use = @symbol_foo } {
// expected-remark@+1 {{function has 2 uses}}
func @symbol_foo()
@ -18,6 +19,9 @@ module attributes {sym.outside_use = @symbol_foo } {
} : () -> ()
}
// expected-remark@below {{function has no uses}}
func @symbol_removable()
// expected-remark@+1 {{function has 1 use}}
func @symbol_baz()

View File

@ -218,6 +218,17 @@ func @foo() {
// -----
// Test that operation with the SymbolTable Trait fails with too many blocks.
// expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
"test.symbol_scope"() ({
^entry:
"test.finish" () : () -> ()
^other:
"test.finish" () : () -> ()
}) : () -> ()
// -----
func @failedMissingOperandSizeAttr(%arg: i32) {
// expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> ()

View File

@ -15,6 +15,7 @@
// limitations under the License.
// =============================================================================
#include "TestDialect.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
@ -22,10 +23,11 @@ using namespace mlir;
namespace {
/// This is a symbol test pass that tests the symbol uselist functionality
/// provided by the symbol table.
/// provided by the symbol table along with erasing from the symbol table.
struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
void runOnModule() override {
auto module = getModule();
std::vector<FuncOp> ops_to_delete;
for (FuncOp func : module.getOps<FuncOp>()) {
// Test computing uses on a non symboltable op.
@ -45,6 +47,8 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
// Test the functionality of symbolKnownUseEmpty.
if (func.symbolKnownUseEmpty(module)) {
func.emitRemark() << "function has no uses";
if (func.getBody().empty())
ops_to_delete.push_back(func);
continue;
}
@ -58,6 +62,18 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
func.emitRemark() << "function has " << llvm::size(*symbolUses)
<< " uses";
}
for (FuncOp func : ops_to_delete) {
// In order to test the SymbolTable::erase method, also erase completely
// useless functions.
SymbolTable table(module);
auto func_name = func.getName();
assert(table.lookup(func_name) && "expected no unknown operations");
table.erase(func);
assert(!table.lookup(func_name) &&
"expected erased operation to be unknown now");
module.emitRemark() << func_name << " function successfully erased";
}
}
};

View File

@ -92,6 +92,12 @@ def SymbolScopeOp : TEST_Op<"symbol_scope",
let regions = (region SizedRegion<1>:$region);
}
def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> {
let summary = "operation which defines a new symbol table without a "
"restriction on a terminator";
let regions = (region SizedRegion<1>:$region);
}
//===----------------------------------------------------------------------===//
// Test Operands
//===----------------------------------------------------------------------===//