forked from OSchip/llvm-project
Move ModuleManager functionality into mlir::SymbolTable.
Note for broken code, the following transformations occurred: ModuleManager::insert(Block::iterator, Operation*) - > SymbolTable::insert(Operation*, Block::iterator) ModuleManager::lookupSymbol -> SymbolTable::lookup ModuleManager::getModule() -> SymbolTable::getOp() ModuleManager::getContext() -> SymbolTable::getOp()->getContext() ModuleManager::* -> SymbolTable::* PiperOrigin-RevId: 283944635
This commit is contained in:
parent
b60799b71b
commit
b8cd0c1486
|
@ -153,7 +153,7 @@ struct PythonMLIRModule {
|
|||
PythonMLIRModule()
|
||||
: mlirContext(),
|
||||
module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&mlirContext))),
|
||||
moduleManager(*module) {}
|
||||
symbolTable(*module) {}
|
||||
|
||||
PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) {
|
||||
return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
|
||||
|
@ -270,7 +270,7 @@ struct PythonMLIRModule {
|
|||
}
|
||||
|
||||
PythonFunction getNamedFunction(const std::string &name) {
|
||||
return moduleManager.lookupSymbol<FuncOp>(name);
|
||||
return symbolTable.lookup<FuncOp>(name);
|
||||
}
|
||||
|
||||
PythonFunctionContext
|
||||
|
@ -282,7 +282,7 @@ private:
|
|||
mlir::MLIRContext mlirContext;
|
||||
// One single module in a python-exposed MLIRContext for now.
|
||||
mlir::OwningModuleRef module;
|
||||
mlir::ModuleManager moduleManager;
|
||||
mlir::SymbolTable symbolTable;
|
||||
|
||||
// An execution engine and an associated target machine. The latter must
|
||||
// outlive the former since it may be used by the transformation layers.
|
||||
|
@ -692,7 +692,7 @@ PythonMLIRModule::declareFunction(const std::string &name,
|
|||
UnknownLoc::get(&mlirContext), name,
|
||||
mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
|
||||
inputAttrs);
|
||||
moduleManager.insert(func);
|
||||
symbolTable.insert(func);
|
||||
return func;
|
||||
}
|
||||
|
||||
|
|
|
@ -118,56 +118,6 @@ public:
|
|||
static void build(Builder *, OperationState &) {}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module Manager.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// A class used to manage the symbols held by a module. This class handles
|
||||
/// ensures that symbols inserted into a module have a unique name, and provides
|
||||
/// efficient named lookup to held symbols.
|
||||
class ModuleManager {
|
||||
public:
|
||||
ModuleManager(ModuleOp module) : module(module), symbolTable(module) {}
|
||||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names must never include the @ on them.
|
||||
template <typename T, typename NameTy> T lookupSymbol(NameTy &&name) const {
|
||||
return symbolTable.lookup<T>(name);
|
||||
}
|
||||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names must never include the @ on them.
|
||||
template <typename NameTy> Operation *lookupSymbol(NameTy &&name) const {
|
||||
return symbolTable.lookup(name);
|
||||
}
|
||||
|
||||
/// Insert a new symbol into the module, auto-renaming it as necessary.
|
||||
void insert(Operation *op) {
|
||||
symbolTable.insert(op);
|
||||
module.push_back(op);
|
||||
}
|
||||
void insert(Block::iterator insertPt, Operation *op) {
|
||||
symbolTable.insert(op);
|
||||
module.insert(insertPt, op);
|
||||
}
|
||||
|
||||
/// Remove the given symbol from the module symbol table and then erase it.
|
||||
void erase(Operation *op) {
|
||||
symbolTable.erase(op);
|
||||
op->erase();
|
||||
}
|
||||
|
||||
/// Return the internally held module.
|
||||
ModuleOp getModule() const { return module; }
|
||||
|
||||
/// Return the context of the internal module.
|
||||
MLIRContext *getContext() { return module.getContext(); }
|
||||
|
||||
private:
|
||||
ModuleOp module;
|
||||
SymbolTable symbolTable;
|
||||
};
|
||||
|
||||
/// This class acts as an owning reference to a module, and will automatically
|
||||
/// destroy the held module if valid.
|
||||
class OwningModuleRef {
|
||||
|
|
|
@ -23,15 +23,16 @@
|
|||
|
||||
namespace mlir {
|
||||
class Identifier;
|
||||
class MLIRContext;
|
||||
class Operation;
|
||||
|
||||
/// This class allows for representing and managing the symbol table used by
|
||||
/// operations with the 'SymbolTable' trait.
|
||||
/// operations with the 'SymbolTable' trait. Inserting into and erasing from
|
||||
/// this SymbolTable will also insert and erase from the Operation given to it
|
||||
/// at construction.
|
||||
class SymbolTable {
|
||||
public:
|
||||
/// Build a symbol table with the symbols within the given operation.
|
||||
SymbolTable(Operation *op);
|
||||
SymbolTable(Operation *symbolTableOp);
|
||||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names never include the @ on them.
|
||||
|
@ -44,15 +45,16 @@ public:
|
|||
void erase(Operation *symbol);
|
||||
|
||||
/// Insert a new symbol into the table, and rename it as necessary to avoid
|
||||
/// collisions.
|
||||
void insert(Operation *symbol);
|
||||
|
||||
/// Returns the context held by this symbol table.
|
||||
MLIRContext *getContext() const { return context; }
|
||||
/// collisions. Also insert at the specified location in the body of the
|
||||
/// associated operation.
|
||||
void insert(Operation *symbol, Block::iterator insertPt = {});
|
||||
|
||||
/// Return the name of the attribute used for symbol names.
|
||||
static StringRef getSymbolAttrName() { return "sym_name"; }
|
||||
|
||||
/// Returns the associated operation.
|
||||
Operation *getOp() const { return symbolTableOp; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Symbol Utilities
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -60,7 +62,7 @@ public:
|
|||
/// Returns the operation registered with the given symbol name with the
|
||||
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
|
||||
/// with the 'OpTrait::SymbolTable' trait.
|
||||
static Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol);
|
||||
static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
|
||||
|
||||
/// Returns the operation registered with the given symbol name within the
|
||||
/// closest parent operation of, or including, 'from' with the
|
||||
|
@ -118,11 +120,11 @@ public:
|
|||
/// are any unknown operations that may potentially be symbol tables.
|
||||
static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
|
||||
|
||||
/// Return if the given symbol is known to have no uses that are nested within
|
||||
/// the given operation 'from'. This does not traverse into any nested symbol
|
||||
/// tables, and will also only count uses on 'from' if it does not also define
|
||||
/// a symbol table. This is because we treat the region as the boundary of
|
||||
/// the symbol table, and not the op itself. This function will also return
|
||||
/// Return if the given symbol is known to have no uses that are nested
|
||||
/// within the given operation 'from'. This does not traverse into any nested
|
||||
/// symbol tables, and will also only count uses on 'from' if it does not also
|
||||
/// define a symbol table. This is because we treat the region as the boundary
|
||||
/// of the symbol table, and not the op itself. This function will also return
|
||||
/// false if there are any unknown operations that may potentially be symbol
|
||||
/// tables. This doesn't necessarily mean that there are no uses, we just
|
||||
/// can't convervatively prove it.
|
||||
|
@ -141,7 +143,7 @@ public:
|
|||
Operation *from);
|
||||
|
||||
private:
|
||||
MLIRContext *context;
|
||||
Operation *symbolTableOp;
|
||||
|
||||
/// This is a mapping from a name to the symbol with that name.
|
||||
llvm::StringMap<Operation *> symbolTable;
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -155,7 +156,7 @@ namespace {
|
|||
class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
|
||||
public:
|
||||
void runOnModule() override {
|
||||
ModuleManager moduleManager(getModule());
|
||||
SymbolTable symbolTable(getModule());
|
||||
bool modified = false;
|
||||
for (auto func : getModule().getOps<FuncOp>()) {
|
||||
// Insert just after the function.
|
||||
|
@ -166,8 +167,8 @@ public:
|
|||
// Create nested module and insert outlinedFunc. The module will
|
||||
// originally get the same name as the function, but may be renamed on
|
||||
// insertion into the parent module.
|
||||
auto kernelModule = createKernelModule(outlinedFunc, moduleManager);
|
||||
moduleManager.insert(insertPt, kernelModule);
|
||||
auto kernelModule = createKernelModule(outlinedFunc, symbolTable);
|
||||
symbolTable.insert(kernelModule, insertPt);
|
||||
|
||||
// Potentially changes signature, pulling in constants.
|
||||
convertToLaunchFuncOp(op, outlinedFunc);
|
||||
|
@ -185,16 +186,15 @@ public:
|
|||
private:
|
||||
// Returns a module containing kernelFunc and all callees (recursive).
|
||||
ModuleOp createKernelModule(FuncOp kernelFunc,
|
||||
const ModuleManager &parentModuleManager) {
|
||||
const SymbolTable &parentSymbolTable) {
|
||||
auto context = getModule().getContext();
|
||||
Builder builder(context);
|
||||
auto kernelModule =
|
||||
ModuleOp::create(builder.getUnknownLoc(), kernelFunc.getName());
|
||||
kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
|
||||
builder.getUnitAttr());
|
||||
ModuleManager moduleManager(kernelModule);
|
||||
|
||||
moduleManager.insert(kernelFunc);
|
||||
SymbolTable symbolTable(kernelModule);
|
||||
symbolTable.insert(kernelFunc);
|
||||
|
||||
llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
|
||||
while (!symbolDefWorklist.empty()) {
|
||||
|
@ -203,13 +203,13 @@ private:
|
|||
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
|
||||
StringRef symbolName =
|
||||
symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
|
||||
if (moduleManager.lookupSymbol(symbolName))
|
||||
if (symbolTable.lookup(symbolName))
|
||||
continue;
|
||||
|
||||
Operation *symbolDefClone =
|
||||
parentModuleManager.lookupSymbol(symbolName)->clone();
|
||||
parentSymbolTable.lookup(symbolName)->clone();
|
||||
symbolDefWorklist.push_back(symbolDefClone);
|
||||
moduleManager.insert(symbolDefClone);
|
||||
symbolTable.insert(symbolDefClone);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,23 +31,24 @@ static bool isPotentiallyUnknownSymbolTable(Operation *op) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Build a symbol table with the symbols within the given operation.
|
||||
SymbolTable::SymbolTable(Operation *op) : context(op->getContext()) {
|
||||
assert(op->hasTrait<OpTrait::SymbolTable>() &&
|
||||
SymbolTable::SymbolTable(Operation *symbolTableOp)
|
||||
: symbolTableOp(symbolTableOp) {
|
||||
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
|
||||
"expected operation to have SymbolTable trait");
|
||||
assert(op->getNumRegions() == 1 &&
|
||||
assert(symbolTableOp->getNumRegions() == 1 &&
|
||||
"expected operation to have a single region");
|
||||
assert(has_single_element(symbolTableOp->getRegion(0)) &&
|
||||
"expected operation to have a single block");
|
||||
|
||||
for (auto &block : op->getRegion(0)) {
|
||||
for (auto &op : block) {
|
||||
auto nameAttr = op.getAttrOfType<StringAttr>(getSymbolAttrName());
|
||||
if (!nameAttr)
|
||||
continue;
|
||||
for (auto &op : symbolTableOp->getRegion(0).front()) {
|
||||
auto nameAttr = op.getAttrOfType<StringAttr>(getSymbolAttrName());
|
||||
if (!nameAttr)
|
||||
continue;
|
||||
|
||||
auto inserted = symbolTable.insert({nameAttr.getValue(), &op});
|
||||
(void)inserted;
|
||||
assert(inserted.second &&
|
||||
"expected region to contain uniquely named symbol operations");
|
||||
}
|
||||
auto inserted = symbolTable.insert({nameAttr.getValue(), &op});
|
||||
(void)inserted;
|
||||
assert(inserted.second &&
|
||||
"expected region to contain uniquely named symbol operations");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,18 +62,32 @@ Operation *SymbolTable::lookup(StringRef name) const {
|
|||
void SymbolTable::erase(Operation *symbol) {
|
||||
auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
|
||||
assert(nameAttr && "expected valid 'name' attribute");
|
||||
assert(symbol->getParentOp() == symbolTableOp &&
|
||||
"expected this operation to be inside of the operation with this "
|
||||
"SymbolTable");
|
||||
|
||||
auto it = symbolTable.find(nameAttr.getValue());
|
||||
if (it != symbolTable.end() && it->second == symbol)
|
||||
if (it != symbolTable.end() && it->second == symbol) {
|
||||
symbolTable.erase(it);
|
||||
symbol->erase();
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert a new symbol into the table, and rename it as necessary to avoid
|
||||
/// collisions.
|
||||
void SymbolTable::insert(Operation *symbol) {
|
||||
/// Insert a new symbol into the table and associated operation, and rename it
|
||||
/// as necessary to avoid collisions.
|
||||
void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
|
||||
auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
|
||||
assert(nameAttr && "expected valid 'name' attribute");
|
||||
|
||||
auto &body = symbolTableOp->getRegion(0).front();
|
||||
if (insertPt == Block::iterator() || insertPt == body.end())
|
||||
insertPt = Block::iterator(body.getTerminator());
|
||||
|
||||
assert(insertPt->getParentOp() == symbolTableOp &&
|
||||
"expected insertPt to be in the associated module operation");
|
||||
|
||||
body.getOperations().insert(insertPt, symbol);
|
||||
|
||||
// Add this symbol to the symbol table, uniquing the name if a conflict is
|
||||
// detected.
|
||||
if (symbolTable.insert({nameAttr.getValue(), symbol}).second)
|
||||
|
@ -89,7 +104,8 @@ void SymbolTable::insert(Operation *symbol) {
|
|||
nameBuffer += '_';
|
||||
nameBuffer += std::to_string(uniquingCounter++);
|
||||
} while (!symbolTable.insert({nameBuffer, symbol}).second);
|
||||
symbol->setAttr(getSymbolAttrName(), StringAttr::get(nameBuffer, context));
|
||||
symbol->setAttr(getSymbolAttrName(),
|
||||
StringAttr::get(nameBuffer, symbolTableOp->getContext()));
|
||||
}
|
||||
|
||||
/// Returns the operation registered with the given symbol name with the
|
||||
|
@ -136,6 +152,9 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
|
|||
if (op->getNumRegions() != 1)
|
||||
return op->emitOpError()
|
||||
<< "Operations with a 'SymbolTable' must have exactly one region";
|
||||
if (!has_single_element(op->getRegion(0)))
|
||||
return op->emitOpError()
|
||||
<< "Operations with a 'SymbolTable' must have exactly one block";
|
||||
|
||||
// Check that all symbols are uniquely named within child regions.
|
||||
llvm::StringMap<Location> nameToOrigLoc;
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
// -----
|
||||
|
||||
func @module_op() {
|
||||
// expected-error@+1 {{expects region #0 to have 0 or 1 blocks}}
|
||||
// expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
|
||||
module {
|
||||
^bb1:
|
||||
"module_terminator"() : () -> ()
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
// Symbol references to the module itself don't affect uses of symbols within
|
||||
// its table.
|
||||
// expected-remark@below {{symbol_removable function successfully erased}}
|
||||
module attributes {sym.outside_use = @symbol_foo } {
|
||||
// expected-remark@+1 {{function has 2 uses}}
|
||||
func @symbol_foo()
|
||||
|
@ -18,6 +19,9 @@ module attributes {sym.outside_use = @symbol_foo } {
|
|||
} : () -> ()
|
||||
}
|
||||
|
||||
// expected-remark@below {{function has no uses}}
|
||||
func @symbol_removable()
|
||||
|
||||
// expected-remark@+1 {{function has 1 use}}
|
||||
func @symbol_baz()
|
||||
|
||||
|
|
|
@ -218,6 +218,17 @@ func @foo() {
|
|||
|
||||
// -----
|
||||
|
||||
// Test that operation with the SymbolTable Trait fails with too many blocks.
|
||||
// expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
|
||||
"test.symbol_scope"() ({
|
||||
^entry:
|
||||
"test.finish" () : () -> ()
|
||||
^other:
|
||||
"test.finish" () : () -> ()
|
||||
}) : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
func @failedMissingOperandSizeAttr(%arg: i32) {
|
||||
// expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}}
|
||||
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> ()
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "TestDialect.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
|
@ -22,10 +23,11 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
/// This is a symbol test pass that tests the symbol uselist functionality
|
||||
/// provided by the symbol table.
|
||||
/// provided by the symbol table along with erasing from the symbol table.
|
||||
struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
std::vector<FuncOp> ops_to_delete;
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
// Test computing uses on a non symboltable op.
|
||||
|
@ -45,6 +47,8 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
|||
// Test the functionality of symbolKnownUseEmpty.
|
||||
if (func.symbolKnownUseEmpty(module)) {
|
||||
func.emitRemark() << "function has no uses";
|
||||
if (func.getBody().empty())
|
||||
ops_to_delete.push_back(func);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -58,6 +62,18 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
|||
func.emitRemark() << "function has " << llvm::size(*symbolUses)
|
||||
<< " uses";
|
||||
}
|
||||
|
||||
for (FuncOp func : ops_to_delete) {
|
||||
// In order to test the SymbolTable::erase method, also erase completely
|
||||
// useless functions.
|
||||
SymbolTable table(module);
|
||||
auto func_name = func.getName();
|
||||
assert(table.lookup(func_name) && "expected no unknown operations");
|
||||
table.erase(func);
|
||||
assert(!table.lookup(func_name) &&
|
||||
"expected erased operation to be unknown now");
|
||||
module.emitRemark() << func_name << " function successfully erased";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -92,6 +92,12 @@ def SymbolScopeOp : TEST_Op<"symbol_scope",
|
|||
let regions = (region SizedRegion<1>:$region);
|
||||
}
|
||||
|
||||
def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> {
|
||||
let summary = "operation which defines a new symbol table without a "
|
||||
"restriction on a terminator";
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Operands
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue