forked from OSchip/llvm-project
[mlir] Remove locking for dialect/operation registration.
Moving forward dialects should only be registered in a thread safe context. This matches the existing usage we have today, but it allows for removing quite a bit of expensive locking from the context. This led to ~.5 a second compile time improvement when running one conversion pass on a very large .mlir file(hundreds of thousands of operations). Differential Revision: https://reviews.llvm.org/D82595
This commit is contained in:
parent
2e2cdd0a52
commit
5d699d18b3
|
@ -258,10 +258,12 @@ private:
|
|||
};
|
||||
/// Registers all dialects and hooks from the global registries with the
|
||||
/// specified MLIRContext.
|
||||
/// Note: This method is not thread-safe.
|
||||
void registerAllDialects(MLIRContext *context);
|
||||
|
||||
/// Utility to register a dialect. Client can register their dialect with the
|
||||
/// global registry by calling registerDialect<MyDialect>();
|
||||
/// Note: This method is not thread-safe.
|
||||
template <typename ConcreteDialect> void registerDialect() {
|
||||
Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
|
||||
[](MLIRContext *ctx) {
|
||||
|
|
|
@ -270,10 +270,6 @@ public:
|
|||
// Other
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// A general purpose mutex to lock access to parts of the context that do not
|
||||
/// have a more specific mutex, e.g. registry operations.
|
||||
llvm::sys::SmartRWMutex<true> contextMutex;
|
||||
|
||||
/// This is a list of dialects that are created referring to this context.
|
||||
/// The MLIRContext owns the objects.
|
||||
std::vector<std::unique_ptr<Dialect>> dialects;
|
||||
|
@ -425,8 +421,6 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
|
|||
|
||||
/// Return information about all registered IR dialects.
|
||||
std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
|
||||
// Lock access to the context registry.
|
||||
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
|
||||
std::vector<Dialect *> result;
|
||||
result.reserve(impl->dialects.size());
|
||||
for (auto &dialect : impl->dialects)
|
||||
|
@ -437,9 +431,6 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
|
|||
/// Get a registered IR dialect with the given namespace. If none is found,
|
||||
/// then return nullptr.
|
||||
Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
|
||||
// Lock access to the context registry.
|
||||
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
|
||||
|
||||
// Dialects are sorted by name, so we can use binary search for lookup.
|
||||
auto it = llvm::lower_bound(
|
||||
impl->dialects, name,
|
||||
|
@ -455,9 +446,6 @@ void Dialect::registerDialect(MLIRContext *context) {
|
|||
auto &impl = context->getImpl();
|
||||
std::unique_ptr<Dialect> dialect(this);
|
||||
|
||||
// Lock access to the context registry.
|
||||
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
||||
|
||||
// Get the correct insertion position sorted by namespace.
|
||||
auto insertPt = llvm::lower_bound(
|
||||
impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
|
||||
|
@ -524,35 +512,26 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
|
|||
/// efficient, typically you should ask the operations about their properties
|
||||
/// directly.
|
||||
std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
|
||||
std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
|
||||
|
||||
{ // Lock access to the context registry.
|
||||
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
|
||||
|
||||
// We just have the operations in a non-deterministic hash table order. Dump
|
||||
// into a temporary array, then sort it by operation name to get a stable
|
||||
// ordering.
|
||||
llvm::StringMap<AbstractOperation> ®isteredOps =
|
||||
impl->registeredOperations;
|
||||
|
||||
opsToSort.reserve(registeredOps.size());
|
||||
for (auto &elt : registeredOps)
|
||||
opsToSort.push_back({elt.first(), &elt.second});
|
||||
}
|
||||
|
||||
llvm::array_pod_sort(opsToSort.begin(), opsToSort.end());
|
||||
// We just have the operations in a non-deterministic hash table order. Dump
|
||||
// into a temporary array, then sort it by operation name to get a stable
|
||||
// ordering.
|
||||
llvm::StringMap<AbstractOperation> ®isteredOps =
|
||||
impl->registeredOperations;
|
||||
|
||||
std::vector<AbstractOperation *> result;
|
||||
result.reserve(opsToSort.size());
|
||||
for (auto &elt : opsToSort)
|
||||
result.push_back(elt.second);
|
||||
result.reserve(registeredOps.size());
|
||||
for (auto &elt : registeredOps)
|
||||
result.push_back(&elt.second);
|
||||
llvm::array_pod_sort(
|
||||
result.begin(), result.end(),
|
||||
[](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
|
||||
return (*lhs)->name.compare((*rhs)->name);
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool MLIRContext::isOperationRegistered(StringRef name) {
|
||||
// Lock access to the context registry.
|
||||
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
|
||||
|
||||
return impl->registeredOperations.count(name);
|
||||
}
|
||||
|
||||
|
@ -561,12 +540,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
|
|||
"op name doesn't start with dialect namespace");
|
||||
assert(&opInfo.dialect == this && "Dialect object mismatch");
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Lock access to the context registry.
|
||||
StringRef opName = opInfo.name;
|
||||
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
||||
if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
|
||||
llvm::errs() << "error: operation named '" << opName
|
||||
llvm::errs() << "error: operation named '" << opInfo.name
|
||||
<< "' is already registered.\n";
|
||||
abort();
|
||||
}
|
||||
|
@ -574,9 +550,6 @@ void Dialect::addOperation(AbstractOperation opInfo) {
|
|||
|
||||
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Lock access to the context registry.
|
||||
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
||||
auto *newInfo =
|
||||
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
|
||||
AbstractType(std::move(typeInfo));
|
||||
|
@ -586,9 +559,6 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
|
|||
|
||||
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Lock access to the context registry.
|
||||
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
||||
auto *newInfo =
|
||||
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
|
||||
AbstractAttribute(std::move(attrInfo));
|
||||
|
@ -612,9 +582,6 @@ const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
|
|||
const AbstractOperation *AbstractOperation::lookup(StringRef opName,
|
||||
MLIRContext *context) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Lock access to the context registry.
|
||||
ScopedReaderLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
||||
auto it = impl.registeredOperations.find(opName);
|
||||
if (it != impl.registeredOperations.end())
|
||||
return &it->second;
|
||||
|
|
Loading…
Reference in New Issue