forked from OSchip/llvm-project
[mlir] Remove locking for dialect/operation registration.
Moving forward dialects should only be registered in a thread safe context. This matches the existing usage we have today, but it allows for removing quite a bit of expensive locking from the context. This led to ~.5 a second compile time improvement when running one conversion pass on a very large .mlir file(hundreds of thousands of operations). Differential Revision: https://reviews.llvm.org/D82595
This commit is contained in:
parent
2e2cdd0a52
commit
5d699d18b3
|
@ -258,10 +258,12 @@ private:
|
||||||
};
|
};
|
||||||
/// Registers all dialects and hooks from the global registries with the
|
/// Registers all dialects and hooks from the global registries with the
|
||||||
/// specified MLIRContext.
|
/// specified MLIRContext.
|
||||||
|
/// Note: This method is not thread-safe.
|
||||||
void registerAllDialects(MLIRContext *context);
|
void registerAllDialects(MLIRContext *context);
|
||||||
|
|
||||||
/// Utility to register a dialect. Client can register their dialect with the
|
/// Utility to register a dialect. Client can register their dialect with the
|
||||||
/// global registry by calling registerDialect<MyDialect>();
|
/// global registry by calling registerDialect<MyDialect>();
|
||||||
|
/// Note: This method is not thread-safe.
|
||||||
template <typename ConcreteDialect> void registerDialect() {
|
template <typename ConcreteDialect> void registerDialect() {
|
||||||
Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
|
Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
|
||||||
[](MLIRContext *ctx) {
|
[](MLIRContext *ctx) {
|
||||||
|
|
|
@ -270,10 +270,6 @@ public:
|
||||||
// Other
|
// Other
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// A general purpose mutex to lock access to parts of the context that do not
|
|
||||||
/// have a more specific mutex, e.g. registry operations.
|
|
||||||
llvm::sys::SmartRWMutex<true> contextMutex;
|
|
||||||
|
|
||||||
/// This is a list of dialects that are created referring to this context.
|
/// This is a list of dialects that are created referring to this context.
|
||||||
/// The MLIRContext owns the objects.
|
/// The MLIRContext owns the objects.
|
||||||
std::vector<std::unique_ptr<Dialect>> dialects;
|
std::vector<std::unique_ptr<Dialect>> dialects;
|
||||||
|
@ -425,8 +421,6 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
|
||||||
|
|
||||||
/// Return information about all registered IR dialects.
|
/// Return information about all registered IR dialects.
|
||||||
std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
|
std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
|
||||||
// Lock access to the context registry.
|
|
||||||
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
|
|
||||||
std::vector<Dialect *> result;
|
std::vector<Dialect *> result;
|
||||||
result.reserve(impl->dialects.size());
|
result.reserve(impl->dialects.size());
|
||||||
for (auto &dialect : impl->dialects)
|
for (auto &dialect : impl->dialects)
|
||||||
|
@ -437,9 +431,6 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
|
||||||
/// Get a registered IR dialect with the given namespace. If none is found,
|
/// Get a registered IR dialect with the given namespace. If none is found,
|
||||||
/// then return nullptr.
|
/// then return nullptr.
|
||||||
Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
|
Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
|
||||||
// Lock access to the context registry.
|
|
||||||
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
|
|
||||||
|
|
||||||
// Dialects are sorted by name, so we can use binary search for lookup.
|
// Dialects are sorted by name, so we can use binary search for lookup.
|
||||||
auto it = llvm::lower_bound(
|
auto it = llvm::lower_bound(
|
||||||
impl->dialects, name,
|
impl->dialects, name,
|
||||||
|
@ -455,9 +446,6 @@ void Dialect::registerDialect(MLIRContext *context) {
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
std::unique_ptr<Dialect> dialect(this);
|
std::unique_ptr<Dialect> dialect(this);
|
||||||
|
|
||||||
// Lock access to the context registry.
|
|
||||||
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
|
||||||
|
|
||||||
// Get the correct insertion position sorted by namespace.
|
// Get the correct insertion position sorted by namespace.
|
||||||
auto insertPt = llvm::lower_bound(
|
auto insertPt = llvm::lower_bound(
|
||||||
impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
|
impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
|
||||||
|
@ -524,35 +512,26 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
|
||||||
/// efficient, typically you should ask the operations about their properties
|
/// efficient, typically you should ask the operations about their properties
|
||||||
/// directly.
|
/// directly.
|
||||||
std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
|
std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
|
||||||
std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
|
|
||||||
|
|
||||||
{ // Lock access to the context registry.
|
|
||||||
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
|
|
||||||
|
|
||||||
// We just have the operations in a non-deterministic hash table order. Dump
|
// We just have the operations in a non-deterministic hash table order. Dump
|
||||||
// into a temporary array, then sort it by operation name to get a stable
|
// into a temporary array, then sort it by operation name to get a stable
|
||||||
// ordering.
|
// ordering.
|
||||||
llvm::StringMap<AbstractOperation> ®isteredOps =
|
llvm::StringMap<AbstractOperation> ®isteredOps =
|
||||||
impl->registeredOperations;
|
impl->registeredOperations;
|
||||||
|
|
||||||
opsToSort.reserve(registeredOps.size());
|
|
||||||
for (auto &elt : registeredOps)
|
|
||||||
opsToSort.push_back({elt.first(), &elt.second});
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::array_pod_sort(opsToSort.begin(), opsToSort.end());
|
|
||||||
|
|
||||||
std::vector<AbstractOperation *> result;
|
std::vector<AbstractOperation *> result;
|
||||||
result.reserve(opsToSort.size());
|
result.reserve(registeredOps.size());
|
||||||
for (auto &elt : opsToSort)
|
for (auto &elt : registeredOps)
|
||||||
result.push_back(elt.second);
|
result.push_back(&elt.second);
|
||||||
|
llvm::array_pod_sort(
|
||||||
|
result.begin(), result.end(),
|
||||||
|
[](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
|
||||||
|
return (*lhs)->name.compare((*rhs)->name);
|
||||||
|
});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MLIRContext::isOperationRegistered(StringRef name) {
|
bool MLIRContext::isOperationRegistered(StringRef name) {
|
||||||
// Lock access to the context registry.
|
|
||||||
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
|
|
||||||
|
|
||||||
return impl->registeredOperations.count(name);
|
return impl->registeredOperations.count(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -561,12 +540,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
|
||||||
"op name doesn't start with dialect namespace");
|
"op name doesn't start with dialect namespace");
|
||||||
assert(&opInfo.dialect == this && "Dialect object mismatch");
|
assert(&opInfo.dialect == this && "Dialect object mismatch");
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Lock access to the context registry.
|
|
||||||
StringRef opName = opInfo.name;
|
StringRef opName = opInfo.name;
|
||||||
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
|
||||||
if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
|
if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
|
||||||
llvm::errs() << "error: operation named '" << opName
|
llvm::errs() << "error: operation named '" << opInfo.name
|
||||||
<< "' is already registered.\n";
|
<< "' is already registered.\n";
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
@ -574,9 +550,6 @@ void Dialect::addOperation(AbstractOperation opInfo) {
|
||||||
|
|
||||||
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
|
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Lock access to the context registry.
|
|
||||||
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
|
||||||
auto *newInfo =
|
auto *newInfo =
|
||||||
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
|
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
|
||||||
AbstractType(std::move(typeInfo));
|
AbstractType(std::move(typeInfo));
|
||||||
|
@ -586,9 +559,6 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
|
||||||
|
|
||||||
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
|
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Lock access to the context registry.
|
|
||||||
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
|
||||||
auto *newInfo =
|
auto *newInfo =
|
||||||
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
|
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
|
||||||
AbstractAttribute(std::move(attrInfo));
|
AbstractAttribute(std::move(attrInfo));
|
||||||
|
@ -612,9 +582,6 @@ const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
|
||||||
const AbstractOperation *AbstractOperation::lookup(StringRef opName,
|
const AbstractOperation *AbstractOperation::lookup(StringRef opName,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Lock access to the context registry.
|
|
||||||
ScopedReaderLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
|
|
||||||
auto it = impl.registeredOperations.find(opName);
|
auto it = impl.registeredOperations.find(opName);
|
||||||
if (it != impl.registeredOperations.end())
|
if (it != impl.registeredOperations.end())
|
||||||
return &it->second;
|
return &it->second;
|
||||||
|
|
Loading…
Reference in New Issue