Extract the function symbol table functionality, i.e. mapping and name uniquing, out of Module and into a new class SymbolTable. As modules become operations it is necessary to extract out this functionality that cannot be represented with a generic operation.

PiperOrigin-RevId: 254041734
This commit is contained in:
River Riddle 2019-06-19 11:55:27 -07:00 committed by Mehdi Amini
parent fd99b6ce97
commit 927b7074a8
6 changed files with 148 additions and 82 deletions

View File

@ -50,11 +50,6 @@ public:
/// Return the name of this function, without the @.
Identifier getName() { return name; }
/// Swap the name of the given function with this one. The caller must ensure
/// that all existing references to the current name of each function have
/// been properly updated.
void takeName(Function &rhs);
/// Return the type of this function.
FunctionType getType() { return type; }
@ -295,6 +290,9 @@ public:
void cloneInto(Function *dest, BlockAndValueMapping &mapper);
private:
/// Set the name of this function.
void setName(Identifier newName) { name = newName; }
/// The name of the function.
Identifier name;
@ -318,6 +316,9 @@ private:
void operator=(Function &) = delete;
friend struct llvm::ilist_traits<Function>;
// Allow access to 'setName'.
friend class SymbolTable;
};
//===--------------------------------------------------------------------===//

View File

@ -23,18 +23,16 @@
#define MLIR_IR_MODULE_H
#include "mlir/IR/Function.h"
#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/ilist.h"
namespace mlir {
class AffineMap;
class Module {
public:
explicit Module(MLIRContext *context);
explicit Module(MLIRContext *context) : symbolTable(context) {}
MLIRContext *getContext() { return context; }
MLIRContext *getContext() { return symbolTable.getContext(); }
/// This is the list of functions in the module.
using FunctionListType = llvm::iplist<Function>;
@ -53,11 +51,15 @@ public:
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them.
Function *getNamedFunction(StringRef name);
Function *getNamedFunction(StringRef name) {
return symbolTable.lookup(name);
}
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them.
Function *getNamedFunction(Identifier name);
Function *getNamedFunction(Identifier name) {
return symbolTable.lookup(name);
}
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs. On error, this reports the error through the MLIRContext
@ -76,17 +78,12 @@ private:
return &Module::functions;
}
MLIRContext *context;
/// This is a mapping from a name to the function with that name.
llvm::DenseMap<Identifier, Function *> symbolTable;
/// This is used when name conflicts are detected.
unsigned uniquingCounter = 0;
/// The symbol table used for functions.
SymbolTable symbolTable;
/// This is the actual list of functions the module contains.
FunctionListType functions;
};
} // end namespace mlir
#endif // MLIR_IR_FUNCTION_H
#endif // MLIR_IR_MODULE_H

View File

@ -0,0 +1,64 @@
//===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_IR_SYMBOLTABLE_H
#define MLIR_IR_SYMBOLTABLE_H
#include "mlir/IR/Identifier.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
class Function;
class MLIRContext;
/// This class represents the symbol table used by a module for function
/// symbols.
class SymbolTable {
public:
SymbolTable(MLIRContext *ctx) : context(ctx) {}
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
Function *lookup(StringRef name) const;
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
Function *lookup(Identifier name) const;
/// Erase the given symbol from the table.
void erase(Function *symbol);
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
void insert(Function *symbol);
/// Returns the context held by this symbol table.
MLIRContext *getContext() const { return context; }
private:
MLIRContext *context;
/// This is a mapping from a name to the function with that name.
llvm::DenseMap<Identifier, Function *> symbolTable;
/// This is used when name conflicts are detected.
unsigned uniquingCounter = 0;
};
} // end namespace mlir
#endif // MLIR_IR_SYMBOLTABLE_H

View File

@ -41,14 +41,6 @@ Function::Function(Location location, StringRef name, FunctionType type,
MLIRContext *Function::getContext() { return getType().getContext(); }
/// Swap the name of the given function with this one.
void Function::takeName(Function &rhs) {
auto *module = getModule();
assert(module && module == rhs.getModule() && "expected same parent module");
std::swap(module->symbolTable[name], module->symbolTable[rhs.getName()]);
std::swap(name, rhs.name);
}
Module *llvm::ilist_traits<Function>::getContainingModule() {
size_t Offset(
size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
@ -63,25 +55,8 @@ void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
auto *module = getContainingModule();
function->module = module;
// Add this function to the symbol table of the module, uniquing the name if
// a conflict is detected.
if (!module->symbolTable.insert({function->getName(), function}).second) {
// If a conflict was detected, then the function will not have been added to
// the symbol table. Try suffixes until we get to a unique name that works.
SmallString<128> nameBuffer(function->getName().begin(),
function->getName().end());
unsigned originalLength = nameBuffer.size();
// Iteratively try suffixes until we find one that isn't used. We use a
// module level uniquing counter to avoid N^2 behavior.
do {
nameBuffer.resize(originalLength);
nameBuffer += '_';
nameBuffer += std::to_string(module->uniquingCounter++);
function->name = Identifier::get(nameBuffer, module->getContext());
} while (
!module->symbolTable.insert({function->getName(), function}).second);
}
// Add this function to the symbol table of the module.
module->symbolTable.insert(function);
}
/// This is a trait method invoked when a Function is removed from a Module.
@ -90,7 +65,7 @@ void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
assert(function->module && "not already in a module!");
// Remove the symbol table entry.
function->module->symbolTable.erase(function->getName());
function->module->symbolTable.erase(function);
function->module = nullptr;
}

View File

@ -1,34 +0,0 @@
//===- Module.cpp - MLIR Module Class -------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/Module.h"
using namespace mlir;
Module::Module(MLIRContext *context) : context(context) {}
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them.
Function *Module::getNamedFunction(StringRef name) {
return getNamedFunction(Identifier::get(name, context));
}
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them.
Function *Module::getNamedFunction(Identifier name) {
auto it = symbolTable.find(name);
return it != symbolTable.end() ? it->second : nullptr;
}

View File

@ -0,0 +1,63 @@
//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Function.h"
using namespace mlir;
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
Function *SymbolTable::lookup(StringRef name) const {
return lookup(Identifier::get(name, context));
}
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
Function *SymbolTable::lookup(Identifier name) const {
return symbolTable.lookup(name);
}
/// Erase the given symbol from the table.
void SymbolTable::erase(Function *symbol) {
auto it = symbolTable.find(symbol->getName());
if (it != symbolTable.end() && it->second == symbol)
symbolTable.erase(it);
}
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
void SymbolTable::insert(Function *symbol) {
// Add this symbol to the symbol table, uniquing the name if a conflict is
// detected.
if (symbolTable.insert({symbol->getName(), symbol}).second)
return;
// If a conflict was detected, then the function will not have been added to
// the symbol table. Try suffixes until we get to a unique name that works.
SmallString<128> nameBuffer(symbol->getName());
unsigned originalLength = nameBuffer.size();
// Iteratively try suffixes until we find one that isn't used. We use a
// module level uniquing counter to avoid N^2 behavior.
do {
nameBuffer.resize(originalLength);
nameBuffer += '_';
nameBuffer += std::to_string(uniquingCounter++);
symbol->setName(Identifier::get(nameBuffer, context));
} while (!symbolTable.insert({symbol->getName(), symbol}).second);
}