diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index db23174d58ca..8f3b3b0df13a 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -50,11 +50,6 @@ public: /// Return the name of this function, without the @. Identifier getName() { return name; } - /// Swap the name of the given function with this one. The caller must ensure - /// that all existing references to the current name of each function have - /// been properly updated. - void takeName(Function &rhs); - /// Return the type of this function. FunctionType getType() { return type; } @@ -295,6 +290,9 @@ public: void cloneInto(Function *dest, BlockAndValueMapping &mapper); private: + /// Set the name of this function. + void setName(Identifier newName) { name = newName; } + /// The name of the function. Identifier name; @@ -318,6 +316,9 @@ private: void operator=(Function &) = delete; friend struct llvm::ilist_traits; + + // Allow access to 'setName'. + friend class SymbolTable; }; //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index b11e074cbab5..14f81d5fdf1c 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -23,18 +23,16 @@ #define MLIR_IR_MODULE_H #include "mlir/IR/Function.h" -#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/SymbolTable.h" #include "llvm/ADT/ilist.h" namespace mlir { -class AffineMap; - class Module { public: - explicit Module(MLIRContext *context); + explicit Module(MLIRContext *context) : symbolTable(context) {} - MLIRContext *getContext() { return context; } + MLIRContext *getContext() { return symbolTable.getContext(); } /// This is the list of functions in the module. using FunctionListType = llvm::iplist; @@ -53,11 +51,15 @@ public: /// Look up a function with the specified name, returning null if no such /// name exists. Function names never include the @ on them. - Function *getNamedFunction(StringRef name); + Function *getNamedFunction(StringRef name) { + return symbolTable.lookup(name); + } /// Look up a function with the specified name, returning null if no such /// name exists. Function names never include the @ on them. - Function *getNamedFunction(Identifier name); + Function *getNamedFunction(Identifier name) { + return symbolTable.lookup(name); + } /// Perform (potentially expensive) checks of invariants, used to detect /// compiler bugs. On error, this reports the error through the MLIRContext @@ -76,17 +78,12 @@ private: return &Module::functions; } - MLIRContext *context; - - /// This is a mapping from a name to the function with that name. - llvm::DenseMap symbolTable; - - /// This is used when name conflicts are detected. - unsigned uniquingCounter = 0; + /// The symbol table used for functions. + SymbolTable symbolTable; /// This is the actual list of functions the module contains. FunctionListType functions; }; } // end namespace mlir -#endif // MLIR_IR_FUNCTION_H +#endif // MLIR_IR_MODULE_H diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h new file mode 100644 index 000000000000..8c30dfe0af39 --- /dev/null +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -0,0 +1,64 @@ +//===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_IR_SYMBOLTABLE_H +#define MLIR_IR_SYMBOLTABLE_H + +#include "mlir/IR/Identifier.h" +#include "llvm/ADT/DenseMap.h" + +namespace mlir { +class Function; +class MLIRContext; + +/// This class represents the symbol table used by a module for function +/// symbols. +class SymbolTable { +public: + SymbolTable(MLIRContext *ctx) : context(ctx) {} + + /// Look up a symbol with the specified name, returning null if no such + /// name exists. Names never include the @ on them. + Function *lookup(StringRef name) const; + + /// Look up a symbol with the specified name, returning null if no such + /// name exists. Names never include the @ on them. + Function *lookup(Identifier name) const; + + /// Erase the given symbol from the table. + void erase(Function *symbol); + + /// Insert a new symbol into the table, and rename it as necessary to avoid + /// collisions. + void insert(Function *symbol); + + /// Returns the context held by this symbol table. + MLIRContext *getContext() const { return context; } + +private: + MLIRContext *context; + + /// This is a mapping from a name to the function with that name. + llvm::DenseMap symbolTable; + + /// This is used when name conflicts are detected. + unsigned uniquingCounter = 0; +}; + +} // end namespace mlir + +#endif // MLIR_IR_SYMBOLTABLE_H diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 799709d73c19..4e0579538ea1 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -41,14 +41,6 @@ Function::Function(Location location, StringRef name, FunctionType type, MLIRContext *Function::getContext() { return getType().getContext(); } -/// Swap the name of the given function with this one. -void Function::takeName(Function &rhs) { - auto *module = getModule(); - assert(module && module == rhs.getModule() && "expected same parent module"); - std::swap(module->symbolTable[name], module->symbolTable[rhs.getName()]); - std::swap(name, rhs.name); -} - Module *llvm::ilist_traits::getContainingModule() { size_t Offset( size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr)))); @@ -63,25 +55,8 @@ void llvm::ilist_traits::addNodeToList(Function *function) { auto *module = getContainingModule(); function->module = module; - // Add this function to the symbol table of the module, uniquing the name if - // a conflict is detected. - if (!module->symbolTable.insert({function->getName(), function}).second) { - // If a conflict was detected, then the function will not have been added to - // the symbol table. Try suffixes until we get to a unique name that works. - SmallString<128> nameBuffer(function->getName().begin(), - function->getName().end()); - unsigned originalLength = nameBuffer.size(); - - // Iteratively try suffixes until we find one that isn't used. We use a - // module level uniquing counter to avoid N^2 behavior. - do { - nameBuffer.resize(originalLength); - nameBuffer += '_'; - nameBuffer += std::to_string(module->uniquingCounter++); - function->name = Identifier::get(nameBuffer, module->getContext()); - } while ( - !module->symbolTable.insert({function->getName(), function}).second); - } + // Add this function to the symbol table of the module. + module->symbolTable.insert(function); } /// This is a trait method invoked when a Function is removed from a Module. @@ -90,7 +65,7 @@ void llvm::ilist_traits::removeNodeFromList(Function *function) { assert(function->module && "not already in a module!"); // Remove the symbol table entry. - function->module->symbolTable.erase(function->getName()); + function->module->symbolTable.erase(function); function->module = nullptr; } diff --git a/mlir/lib/IR/Module.cpp b/mlir/lib/IR/Module.cpp deleted file mode 100644 index 36b892a349fb..000000000000 --- a/mlir/lib/IR/Module.cpp +++ /dev/null @@ -1,34 +0,0 @@ -//===- Module.cpp - MLIR Module Class -------------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/IR/Module.h" -using namespace mlir; - -Module::Module(MLIRContext *context) : context(context) {} - -/// Look up a function with the specified name, returning null if no such -/// name exists. Function names never include the @ on them. -Function *Module::getNamedFunction(StringRef name) { - return getNamedFunction(Identifier::get(name, context)); -} - -/// Look up a function with the specified name, returning null if no such -/// name exists. Function names never include the @ on them. -Function *Module::getNamedFunction(Identifier name) { - auto it = symbolTable.find(name); - return it != symbolTable.end() ? it->second : nullptr; -} diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp new file mode 100644 index 000000000000..da9ff0fe7d1a --- /dev/null +++ b/mlir/lib/IR/SymbolTable.cpp @@ -0,0 +1,63 @@ +//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Function.h" + +using namespace mlir; + +/// Look up a symbol with the specified name, returning null if no such name +/// exists. Names never include the @ on them. +Function *SymbolTable::lookup(StringRef name) const { + return lookup(Identifier::get(name, context)); +} + +/// Look up a symbol with the specified name, returning null if no such name +/// exists. Names never include the @ on them. +Function *SymbolTable::lookup(Identifier name) const { + return symbolTable.lookup(name); +} + +/// Erase the given symbol from the table. +void SymbolTable::erase(Function *symbol) { + auto it = symbolTable.find(symbol->getName()); + if (it != symbolTable.end() && it->second == symbol) + symbolTable.erase(it); +} + +/// Insert a new symbol into the table, and rename it as necessary to avoid +/// collisions. +void SymbolTable::insert(Function *symbol) { + // Add this symbol to the symbol table, uniquing the name if a conflict is + // detected. + if (symbolTable.insert({symbol->getName(), symbol}).second) + return; + + // If a conflict was detected, then the function will not have been added to + // the symbol table. Try suffixes until we get to a unique name that works. + SmallString<128> nameBuffer(symbol->getName()); + unsigned originalLength = nameBuffer.size(); + + // Iteratively try suffixes until we find one that isn't used. We use a + // module level uniquing counter to avoid N^2 behavior. + do { + nameBuffer.resize(originalLength); + nameBuffer += '_'; + nameBuffer += std::to_string(uniquingCounter++); + symbol->setName(Identifier::get(nameBuffer, context)); + } while (!symbolTable.insert({symbol->getName(), symbol}).second); +}