Rename allocator to identifierAllocator and add an identifierMutex to make identifier uniquing thread safe. This also adds a general purpose 'contextMutex' to protect access to the rest of the miscellaneous parts of the MLIRContext, e.g. diagnostics, dialect registration, etc. This is step 5/5 of making the MLIRContext thread-safe.

PiperOrigin-RevId: 238516697
This commit is contained in:
River Riddle 2019-03-14 14:14:14 -07:00 committed by jpienaar
parent c769f6b985
commit 087e599a3f
1 changed files with 64 additions and 13 deletions

View File

@ -477,12 +477,21 @@ public:
using FusedLocations = DenseSet<FusedLocationStorage *, FusedLocKeyInfo>;
FusedLocations fusedLocs;
//===--------------------------------------------------------------------===//
// Identifier uniquing
//===--------------------------------------------------------------------===//
// Identifier allocator and mutex for thread safety.
llvm::BumpPtrAllocator identifierAllocator;
llvm::sys::SmartRWMutex<true> identifierMutex;
//===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//
/// We put immortal objects into this allocator.
llvm::BumpPtrAllocator allocator;
/// A general purpose mutex to lock access to parts of the context that do not
/// have a more specific mutex, e.g. registry operations, diagnostics, etc.
llvm::sys::SmartRWMutex<true> contextMutex;
/// This is the handler to use to report diagnostics, or null if not
/// registered.
@ -569,7 +578,8 @@ public:
sparseElementsAttrs;
public:
MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {}
MLIRContextImpl()
: filenames(locationAllocator), identifiers(identifierAllocator) {}
};
} // end namespace mlir
@ -599,11 +609,15 @@ static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
/// value that indicates the type of the diagnostic (e.g., Warning, Error).
void MLIRContext::registerDiagnosticHandler(
const DiagnosticHandlerTy &handler) {
// Lock access to the context diagnostic handler.
llvm::sys::SmartScopedWriter<true> contextLock(getImpl().contextMutex);
getImpl().diagnosticHandler = handler;
}
/// Return the current diagnostic handler, or null if none is present.
auto MLIRContext::getDiagnosticHandler() const -> DiagnosticHandlerTy {
// Lock access to the context diagnostic handler.
llvm::sys::SmartScopedReader<true> contextLock(getImpl().contextMutex);
return getImpl().diagnosticHandler;
}
@ -625,6 +639,10 @@ void MLIRContext::emitDiagnostic(Location location, const llvm::Twine &message,
return;
}
// Lock access to the context so that no other threads emit diagnostics at
// the same time.
llvm::sys::SmartScopedWriter<true> contextLock(getImpl().contextMutex);
// If we had a handler registered, emit the diagnostic using it.
auto handler = getImpl().diagnosticHandler;
if (handler)
@ -658,6 +676,9 @@ bool MLIRContext::emitError(Location location,
/// Return information about all registered IR dialects.
std::vector<Dialect *> MLIRContext::getRegisteredDialects() const {
// Lock access to the context registry.
llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
std::vector<Dialect *> result;
result.reserve(getImpl().dialects.size());
for (auto &dialect : getImpl().dialects)
@ -668,6 +689,8 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() const {
/// Get a registered IR dialect with the given namespace. If none is found,
/// then return nullptr.
Dialect *MLIRContext::getRegisteredDialect(StringRef name) const {
// Lock access to the context registry.
llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
for (auto &dialect : getImpl().dialects)
if (name == dialect->getNamespace())
return dialect.get();
@ -677,22 +700,32 @@ Dialect *MLIRContext::getRegisteredDialect(StringRef name) const {
/// Register this dialect object with the specified context. The context
/// takes ownership of the heap allocated dialect.
void Dialect::registerDialect(MLIRContext *context) {
context->getImpl().dialects.push_back(std::unique_ptr<Dialect>(this));
auto &impl = context->getImpl();
// Lock access to the context registry.
llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
impl.dialects.push_back(std::unique_ptr<Dialect>(this));
}
/// Return information about all registered operations. This isn't very
/// efficient, typically you should ask the operations about their properties
/// directly.
std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() const {
// We just have the operations in a non-deterministic hash table order. Dump
// into a temporary array, then sort it by operation name to get a stable
// ordering.
StringMap<AbstractOperation> &registeredOps = getImpl().registeredOperations;
std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
opsToSort.reserve(registeredOps.size());
for (auto &elt : registeredOps)
opsToSort.push_back({elt.first(), &elt.second});
{ // Lock access to the context registry.
llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
// We just have the operations in a non-deterministic hash table order. Dump
// into a temporary array, then sort it by operation name to get a stable
// ordering.
StringMap<AbstractOperation> &registeredOps =
getImpl().registeredOperations;
opsToSort.reserve(registeredOps.size());
for (auto &elt : registeredOps)
opsToSort.push_back({elt.first(), &elt.second});
}
llvm::array_pod_sort(opsToSort.begin(), opsToSort.end());
@ -707,8 +740,10 @@ void Dialect::addOperation(AbstractOperation opInfo) {
assert((namePrefix.empty() || (opInfo.name.split('.').first == namePrefix)) &&
"op name doesn't start with dialect prefix");
assert(&opInfo.dialect == this && "Dialect object mismatch");
auto &impl = context->getImpl();
// Lock access to the context registry.
llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) {
llvm::errs() << "error: ops named '" << opInfo.name
<< "' is already registered.\n";
@ -719,6 +754,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
/// Register a dialect-specific type with the current context.
void Dialect::addType(const TypeID *const typeID) {
auto &impl = context->getImpl();
// Lock access to the context registry.
llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
if (!impl.registeredTypes.insert({typeID, this}).second) {
llvm::errs() << "error: type already registered.\n";
abort();
@ -730,6 +768,9 @@ void Dialect::addType(const TypeID *const typeID) {
const AbstractOperation *AbstractOperation::lookup(StringRef opName,
MLIRContext *context) {
auto &impl = context->getImpl();
// Lock access to the context registry.
llvm::sys::SmartScopedReader<true> registryLock(impl.contextMutex);
auto it = impl.registeredOperations.find(opName);
if (it != impl.registeredOperations.end())
return &it->second;
@ -747,6 +788,16 @@ Identifier Identifier::get(StringRef str, const MLIRContext *context) {
"Cannot create an identifier with a nul character");
auto &impl = context->getImpl();
{ // Check for an existing identifier in read-only mode.
llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
auto it = impl.identifiers.find(str);
if (it != impl.identifiers.end())
return Identifier(it->getKeyData());
}
// Aquire a writer-lock so that we can safely create the new instance.
llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
auto it = impl.identifiers.insert({str, char()}).first;
return Identifier(it->getKeyData());
}