diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index 6f35ac24180e..53ed89230bdf 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -29,10 +29,6 @@ class MLIRContextImpl; class Location; class Dialect; -namespace detail { -class TypeUniquer; -} - /// MLIRContext is the top-level object for a collection of MLIR modules. It /// holds immortal uniqued objects like types, and the tables used to unique /// them. @@ -97,9 +93,6 @@ public: // MLIRContextImpl type. MLIRContextImpl &getImpl() const { return *impl.get(); } - /// Get the type uniquer for this context. - detail::TypeUniquer &getTypeUniquer() const; - private: const std::unique_ptr impl; diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h index d1521f4689e3..d92efd6e27e2 100644 --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -110,18 +110,16 @@ using DefaultTypeStorage = TypeStorage; // TypeStorageAllocator //===----------------------------------------------------------------------===// -// This is a utility allocator used to allocate memory for derived types that -// will be tied to the current MLIRContext. +// This is a utility allocator used to allocate memory for instances of derived +// Types. class TypeStorageAllocator { public: - TypeStorageAllocator(MLIRContext *ctx) : ctx(ctx) {} - /// Copy the specified array of elements into memory managed by our bump /// pointer allocator. This assumes the elements are all PODs. template ArrayRef copyInto(ArrayRef elements) { if (elements.empty()) return llvm::None; - auto result = getAllocator().Allocate(elements.size()); + auto result = allocator.Allocate(elements.size()); std::uninitialized_copy(elements.begin(), elements.end(), result); return ArrayRef(result, elements.size()); } @@ -134,13 +132,11 @@ public: } // Allocate an instance of the provided type. - template T *allocate() { return getAllocator().Allocate(); } + template T *allocate() { return allocator.Allocate(); } private: - /// Get a reference to the internal allocator. - llvm::BumpPtrAllocator &getAllocator(); - - MLIRContext *ctx; + /// The raw allocator for type storage objects. + llvm::BumpPtrAllocator allocator; }; //===----------------------------------------------------------------------===// @@ -151,38 +147,12 @@ namespace detail { // MLIRContext. This class manages all creation and uniquing of types. class TypeUniquer { public: - template - static T get(MLIRContext *ctx, Args &&... args) { - TypeUniquer &instance = ctx->getTypeUniquer(); - return instance.getImpl(ctx, args...); - } - -private: - /// A utility wrapper object representing a hashed storage object. This class - /// contains a storage object and an existing computed hash value. - struct HashedStorageType { - unsigned hashValue; - TypeStorage *storage; - }; - - /// A lookup key for derived instances of TypeStorage objects. - struct TypeLookupKey { - /// The known derived kind for the storage. - unsigned kind; - - /// The known hash value of the key. - unsigned hashValue; - - /// An equality function for comparing with an existing storage instance. - llvm::function_ref isEqual; - }; - /// Get an uniqued instance of a type T. This overload is used for derived /// types that have complex storage or uniquing constraints. template - typename std::enable_if< + static typename std::enable_if< !std::is_same::value, T>::type - getImpl(MLIRContext *ctx, unsigned kind, Args &&... args) { + get(MLIRContext *ctx, unsigned kind, Args &&... args) { using ImplType = typename T::ImplType; using KeyTy = typename ImplType::KeyTy; @@ -198,41 +168,46 @@ private: return static_cast(*existing) == derivedKey; }; - // Look to see if the type has been created already. - auto existing = - storageTypes.insert_as({}, TypeLookupKey{kind, hashValue, isEqual}); + // Generate a constructor function for the derived storage. + std::function constructorFn = + [&](TypeStorageAllocator &allocator) { + TypeStorage *storage = ImplType::construct(allocator, derivedKey); + storage->initializeTypeInfo(lookupDialectForType(ctx), kind); + return storage; + }; - // If it has been created, return it. - if (!existing.second) - return T(existing.first->storage); - - // Otherwise, construct and initialize the derived storage for this type - // instance. - TypeStorageAllocator allocator(ctx); - TypeStorage *storage = ImplType::construct(allocator, derivedKey); - storage->initializeTypeInfo(lookupDialectForType(ctx), kind); - *existing.first = HashedStorageType{hashValue, storage}; - return T(storage); + // Get an instance for the derived storage. + return T(getImpl(ctx, kind, hashValue, isEqual, constructorFn)); } /// Get an uniqued instance of a type T. This overload is used for derived /// types that use the DefaultTypeStorage and thus need no additional storage /// or uniquing. template - typename std::enable_if< + static typename std::enable_if< std::is_same::value, T>::type - getImpl(MLIRContext *ctx, unsigned kind) { - // Check for an existing instance with this kind. - auto *&result = simpleTypes[kind]; - if (!result) { - // Otherwise, allocate and initialize one. - TypeStorageAllocator allocator(ctx); - result = new (allocator.allocate()) + get(MLIRContext *ctx, unsigned kind) { + auto constructorFn = [=](TypeStorageAllocator &allocator) { + return new (allocator.allocate()) DefaultTypeStorage(lookupDialectForType(ctx), kind); - } - return T(result); + }; + return T(getImpl(ctx, kind, constructorFn)); } +private: + /// Implementation for getting/creating an instance of a derived type with + /// complex storage. + static TypeStorage * + getImpl(MLIRContext *ctx, unsigned kind, unsigned hashValue, + llvm::function_ref isEqual, + std::function constructorFn); + + /// Implementation for getting/creating an instance of a derived type with + /// default storage. + static TypeStorage * + getImpl(MLIRContext *ctx, unsigned kind, + std::function constructorFn); + /// Get the dialect that the type 'T' was registered with. template static const Dialect &lookupDialectForType(MLIRContext *ctx) { @@ -315,46 +290,6 @@ private: return llvm::hash_combine( kind, llvm::DenseMapInfo::getHashValue(derivedKey)); } - - //===--------------------------------------------------------------------===// - // Instance Storage - //===--------------------------------------------------------------------===// - - /// Storage info for derived TypeStorage objects. - struct StorageKeyInfo : DenseMapInfo { - static HashedStorageType getEmptyKey() { - return HashedStorageType{0, DenseMapInfo::getEmptyKey()}; - } - static HashedStorageType getTombstoneKey() { - return HashedStorageType{0, - DenseMapInfo::getTombstoneKey()}; - } - - static unsigned getHashValue(const HashedStorageType &key) { - return key.hashValue; - } - static unsigned getHashValue(TypeLookupKey key) { return key.hashValue; } - - static bool isEqual(const HashedStorageType &lhs, - const HashedStorageType &rhs) { - return lhs.storage == rhs.storage; - } - static bool isEqual(const TypeLookupKey &lhs, - const HashedStorageType &rhs) { - if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) - return false; - // If the lookup kind matches the kind of the storage, then invoke the - // equality function on the lookup key. - return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage); - } - }; - - // Unique types with specific hashing or storage constraints. - using StorageTypeSet = llvm::DenseSet; - StorageTypeSet storageTypes; - - // Unique types with just the kind. - DenseMap simpleTypes; }; } // namespace detail diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index c8bbb3d644b9..4a042a6658fc 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -37,6 +37,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Allocator.h" +#include "llvm/Support/RWMutex.h" #include "llvm/Support/raw_ostream.h" #include @@ -280,6 +281,127 @@ struct FusedLocKeyInfo : DenseMapInfo { return lhs == std::make_pair(rhs->getLocations(), rhs->metadata); } }; + +/// This is the implementation of the TypeUniquer class. +struct TypeUniquerImpl { + /// A lookup key for derived instances of TypeStorage objects. + struct TypeLookupKey { + /// The known derived kind for the storage. + unsigned kind; + + /// The known hash value of the key. + unsigned hashValue; + + /// An equality function for comparing with an existing storage instance. + llvm::function_ref isEqual; + }; + + /// A utility wrapper object representing a hashed storage object. This class + /// contains a storage object and an existing computed hash value. + struct HashedStorageType { + unsigned hashValue; + TypeStorage *storage; + }; + + /// Get or create an instance of a complex derived type. + TypeStorage *getOrCreate( + unsigned kind, unsigned hashValue, + llvm::function_ref isEqual, + std::function constructorFn) { + TypeLookupKey lookupKey{kind, hashValue, isEqual}; + + { // Check for an existing instance in read-only mode. + llvm::sys::SmartScopedReader typeLock(typeMutex); + auto it = storageTypes.find_as(lookupKey); + if (it != storageTypes.end()) + return it->storage; + } + + // Aquire a writer-lock so that we can safely create the new type instance. + llvm::sys::SmartScopedWriter typeLock(typeMutex); + + // Check for an existing instance again here, because another writer thread + // may have already created one. + auto existing = storageTypes.insert_as({}, lookupKey); + if (!existing.second) + return existing.first->storage; + + // Otherwise, construct and initialize the derived storage for this type + // instance. + TypeStorage *storage = constructorFn(allocator); + *existing.first = HashedStorageType{hashValue, storage}; + return storage; + } + + /// Get or create an instance of a simple derived type. + TypeStorage *getOrCreate( + unsigned kind, + std::function constructorFn) { + { // Check if the type already exists in read-only mode. + llvm::sys::SmartScopedReader typeLock(typeMutex); + auto it = simpleTypes.find(kind); + if (it != simpleTypes.end()) + return it->second; + } + + // Aquire the mutex in write mode so that we can safely construct the new + // instance. + llvm::sys::SmartScopedWriter typeLock(typeMutex); + + // Check for an existing instance again here, because another writer thread + // may have already created one. + auto *&result = simpleTypes[kind]; + if (!result) + result = constructorFn(allocator); + return result; + } + + //===--------------------------------------------------------------------===// + // Instance Storage + //===--------------------------------------------------------------------===// + + /// Storage info for derived TypeStorage objects. + struct StorageKeyInfo : DenseMapInfo { + static HashedStorageType getEmptyKey() { + return HashedStorageType{0, DenseMapInfo::getEmptyKey()}; + } + static HashedStorageType getTombstoneKey() { + return HashedStorageType{0, + DenseMapInfo::getTombstoneKey()}; + } + + static unsigned getHashValue(const HashedStorageType &key) { + return key.hashValue; + } + static unsigned getHashValue(TypeLookupKey key) { return key.hashValue; } + + static bool isEqual(const HashedStorageType &lhs, + const HashedStorageType &rhs) { + return lhs.storage == rhs.storage; + } + static bool isEqual(const TypeLookupKey &lhs, + const HashedStorageType &rhs) { + if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) + return false; + // If the lookup kind matches the kind of the storage, then invoke the + // equality function on the lookup key. + return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage); + } + }; + + // Unique types with specific hashing or storage constraints. + using StorageTypeSet = llvm::DenseSet; + StorageTypeSet storageTypes; + + // Unique types with just the kind. + DenseMap simpleTypes; + + // Allocator to use when constructing derived type instances. + TypeStorageAllocator allocator; + + // A mutex to keep type uniquing thread-safe. + llvm::sys::SmartRWMutex typeMutex; +}; } // end anonymous namespace. namespace mlir { @@ -354,7 +476,7 @@ public: DenseMap constExprs; /// Type uniquing. - TypeUniquer typeUniquer; + TypeUniquerImpl typeUniquer; // Attribute uniquing. BoolAttributeStorage *boolAttrs[2] = {nullptr}; @@ -701,14 +823,22 @@ Location FusedLoc::get(ArrayRef locs, Attribute metadata, // Type uniquing //===----------------------------------------------------------------------===// -/// Get the type uniquer for this context. -TypeUniquer &MLIRContext::getTypeUniquer() const { - return getImpl().typeUniquer; +/// Implementation for getting/creating an instance of a derived type with +/// complex storage. +TypeStorage *TypeUniquer::getImpl( + MLIRContext *ctx, unsigned kind, unsigned hashValue, + llvm::function_ref isEqual, + std::function constructorFn) { + return ctx->getImpl().typeUniquer.getOrCreate(kind, hashValue, isEqual, + constructorFn); } -/// Get a reference to the internal allocator. -llvm::BumpPtrAllocator &TypeStorageAllocator::getAllocator() { - return ctx->getImpl().allocator; +/// Implementation for getting/creating an instance of a derived type with +/// default storage. +TypeStorage *TypeUniquer::getImpl( + MLIRContext *ctx, unsigned kind, + std::function constructorFn) { + return ctx->getImpl().typeUniquer.getOrCreate(kind, constructorFn); } /// Get the dialect that registered the type with the provided typeid.