forked from OSchip/llvm-project
Rename allocator to identifierAllocator and add an identifierMutex to make identifier uniquing thread safe. This also adds a general purpose 'contextMutex' to protect access to the rest of the miscellaneous parts of the MLIRContext, e.g. diagnostics, dialect registration, etc. This is step 5/5 of making the MLIRContext thread-safe.
PiperOrigin-RevId: 238516697
This commit is contained in:
parent
c769f6b985
commit
087e599a3f
|
@ -477,12 +477,21 @@ public:
|
|||
using FusedLocations = DenseSet<FusedLocationStorage *, FusedLocKeyInfo>;
|
||||
FusedLocations fusedLocs;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Identifier uniquing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Identifier allocator and mutex for thread safety.
|
||||
llvm::BumpPtrAllocator identifierAllocator;
|
||||
llvm::sys::SmartRWMutex<true> identifierMutex;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Other
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// We put immortal objects into this allocator.
|
||||
llvm::BumpPtrAllocator allocator;
|
||||
/// A general purpose mutex to lock access to parts of the context that do not
|
||||
/// have a more specific mutex, e.g. registry operations, diagnostics, etc.
|
||||
llvm::sys::SmartRWMutex<true> contextMutex;
|
||||
|
||||
/// This is the handler to use to report diagnostics, or null if not
|
||||
/// registered.
|
||||
|
@ -569,7 +578,8 @@ public:
|
|||
sparseElementsAttrs;
|
||||
|
||||
public:
|
||||
MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {}
|
||||
MLIRContextImpl()
|
||||
: filenames(locationAllocator), identifiers(identifierAllocator) {}
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
|
@ -599,11 +609,15 @@ static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
|
|||
/// value that indicates the type of the diagnostic (e.g., Warning, Error).
|
||||
void MLIRContext::registerDiagnosticHandler(
|
||||
const DiagnosticHandlerTy &handler) {
|
||||
// Lock access to the context diagnostic handler.
|
||||
llvm::sys::SmartScopedWriter<true> contextLock(getImpl().contextMutex);
|
||||
getImpl().diagnosticHandler = handler;
|
||||
}
|
||||
|
||||
/// Return the current diagnostic handler, or null if none is present.
|
||||
auto MLIRContext::getDiagnosticHandler() const -> DiagnosticHandlerTy {
|
||||
// Lock access to the context diagnostic handler.
|
||||
llvm::sys::SmartScopedReader<true> contextLock(getImpl().contextMutex);
|
||||
return getImpl().diagnosticHandler;
|
||||
}
|
||||
|
||||
|
@ -625,6 +639,10 @@ void MLIRContext::emitDiagnostic(Location location, const llvm::Twine &message,
|
|||
return;
|
||||
}
|
||||
|
||||
// Lock access to the context so that no other threads emit diagnostics at
|
||||
// the same time.
|
||||
llvm::sys::SmartScopedWriter<true> contextLock(getImpl().contextMutex);
|
||||
|
||||
// If we had a handler registered, emit the diagnostic using it.
|
||||
auto handler = getImpl().diagnosticHandler;
|
||||
if (handler)
|
||||
|
@ -658,6 +676,9 @@ bool MLIRContext::emitError(Location location,
|
|||
|
||||
/// Return information about all registered IR dialects.
|
||||
std::vector<Dialect *> MLIRContext::getRegisteredDialects() const {
|
||||
// Lock access to the context registry.
|
||||
llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
|
||||
|
||||
std::vector<Dialect *> result;
|
||||
result.reserve(getImpl().dialects.size());
|
||||
for (auto &dialect : getImpl().dialects)
|
||||
|
@ -668,6 +689,8 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() const {
|
|||
/// Get a registered IR dialect with the given namespace. If none is found,
|
||||
/// then return nullptr.
|
||||
Dialect *MLIRContext::getRegisteredDialect(StringRef name) const {
|
||||
// Lock access to the context registry.
|
||||
llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
|
||||
for (auto &dialect : getImpl().dialects)
|
||||
if (name == dialect->getNamespace())
|
||||
return dialect.get();
|
||||
|
@ -677,22 +700,32 @@ Dialect *MLIRContext::getRegisteredDialect(StringRef name) const {
|
|||
/// Register this dialect object with the specified context. The context
|
||||
/// takes ownership of the heap allocated dialect.
|
||||
void Dialect::registerDialect(MLIRContext *context) {
|
||||
context->getImpl().dialects.push_back(std::unique_ptr<Dialect>(this));
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Lock access to the context registry.
|
||||
llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
|
||||
impl.dialects.push_back(std::unique_ptr<Dialect>(this));
|
||||
}
|
||||
|
||||
/// Return information about all registered operations. This isn't very
|
||||
/// efficient, typically you should ask the operations about their properties
|
||||
/// directly.
|
||||
std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() const {
|
||||
// We just have the operations in a non-deterministic hash table order. Dump
|
||||
// into a temporary array, then sort it by operation name to get a stable
|
||||
// ordering.
|
||||
StringMap<AbstractOperation> ®isteredOps = getImpl().registeredOperations;
|
||||
|
||||
std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
|
||||
opsToSort.reserve(registeredOps.size());
|
||||
for (auto &elt : registeredOps)
|
||||
opsToSort.push_back({elt.first(), &elt.second});
|
||||
|
||||
{ // Lock access to the context registry.
|
||||
llvm::sys::SmartScopedReader<true> registryLock(getImpl().contextMutex);
|
||||
|
||||
// We just have the operations in a non-deterministic hash table order. Dump
|
||||
// into a temporary array, then sort it by operation name to get a stable
|
||||
// ordering.
|
||||
StringMap<AbstractOperation> ®isteredOps =
|
||||
getImpl().registeredOperations;
|
||||
|
||||
opsToSort.reserve(registeredOps.size());
|
||||
for (auto &elt : registeredOps)
|
||||
opsToSort.push_back({elt.first(), &elt.second});
|
||||
}
|
||||
|
||||
llvm::array_pod_sort(opsToSort.begin(), opsToSort.end());
|
||||
|
||||
|
@ -707,8 +740,10 @@ void Dialect::addOperation(AbstractOperation opInfo) {
|
|||
assert((namePrefix.empty() || (opInfo.name.split('.').first == namePrefix)) &&
|
||||
"op name doesn't start with dialect prefix");
|
||||
assert(&opInfo.dialect == this && "Dialect object mismatch");
|
||||
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Lock access to the context registry.
|
||||
llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
|
||||
if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) {
|
||||
llvm::errs() << "error: ops named '" << opInfo.name
|
||||
<< "' is already registered.\n";
|
||||
|
@ -719,6 +754,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
|
|||
/// Register a dialect-specific type with the current context.
|
||||
void Dialect::addType(const TypeID *const typeID) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Lock access to the context registry.
|
||||
llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
|
||||
if (!impl.registeredTypes.insert({typeID, this}).second) {
|
||||
llvm::errs() << "error: type already registered.\n";
|
||||
abort();
|
||||
|
@ -730,6 +768,9 @@ void Dialect::addType(const TypeID *const typeID) {
|
|||
const AbstractOperation *AbstractOperation::lookup(StringRef opName,
|
||||
MLIRContext *context) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Lock access to the context registry.
|
||||
llvm::sys::SmartScopedReader<true> registryLock(impl.contextMutex);
|
||||
auto it = impl.registeredOperations.find(opName);
|
||||
if (it != impl.registeredOperations.end())
|
||||
return &it->second;
|
||||
|
@ -747,6 +788,16 @@ Identifier Identifier::get(StringRef str, const MLIRContext *context) {
|
|||
"Cannot create an identifier with a nul character");
|
||||
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
{ // Check for an existing identifier in read-only mode.
|
||||
llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
|
||||
auto it = impl.identifiers.find(str);
|
||||
if (it != impl.identifiers.end())
|
||||
return Identifier(it->getKeyData());
|
||||
}
|
||||
|
||||
// Aquire a writer-lock so that we can safely create the new instance.
|
||||
llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
|
||||
auto it = impl.identifiers.insert({str, char()}).first;
|
||||
return Identifier(it->getKeyData());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue