forked from OSchip/llvm-project
Give the TypeUniquer its own BumpPtrAllocator and a SmartRWMutex to make it thread-safe. This is step 1/N to making the MLIRContext thread-safe.
PiperOrigin-RevId: 238037814
This commit is contained in:
parent
739f3ef7ee
commit
7eee76b84c
|
@ -29,10 +29,6 @@ class MLIRContextImpl;
|
|||
class Location;
|
||||
class Dialect;
|
||||
|
||||
namespace detail {
|
||||
class TypeUniquer;
|
||||
}
|
||||
|
||||
/// MLIRContext is the top-level object for a collection of MLIR modules. It
|
||||
/// holds immortal uniqued objects like types, and the tables used to unique
|
||||
/// them.
|
||||
|
@ -97,9 +93,6 @@ public:
|
|||
// MLIRContextImpl type.
|
||||
MLIRContextImpl &getImpl() const { return *impl.get(); }
|
||||
|
||||
/// Get the type uniquer for this context.
|
||||
detail::TypeUniquer &getTypeUniquer() const;
|
||||
|
||||
private:
|
||||
const std::unique_ptr<MLIRContextImpl> impl;
|
||||
|
||||
|
|
|
@ -110,18 +110,16 @@ using DefaultTypeStorage = TypeStorage;
|
|||
// TypeStorageAllocator
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// This is a utility allocator used to allocate memory for derived types that
|
||||
// will be tied to the current MLIRContext.
|
||||
// This is a utility allocator used to allocate memory for instances of derived
|
||||
// Types.
|
||||
class TypeStorageAllocator {
|
||||
public:
|
||||
TypeStorageAllocator(MLIRContext *ctx) : ctx(ctx) {}
|
||||
|
||||
/// Copy the specified array of elements into memory managed by our bump
|
||||
/// pointer allocator. This assumes the elements are all PODs.
|
||||
template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
|
||||
if (elements.empty())
|
||||
return llvm::None;
|
||||
auto result = getAllocator().Allocate<T>(elements.size());
|
||||
auto result = allocator.Allocate<T>(elements.size());
|
||||
std::uninitialized_copy(elements.begin(), elements.end(), result);
|
||||
return ArrayRef<T>(result, elements.size());
|
||||
}
|
||||
|
@ -134,13 +132,11 @@ public:
|
|||
}
|
||||
|
||||
// Allocate an instance of the provided type.
|
||||
template <typename T> T *allocate() { return getAllocator().Allocate<T>(); }
|
||||
template <typename T> T *allocate() { return allocator.Allocate<T>(); }
|
||||
|
||||
private:
|
||||
/// Get a reference to the internal allocator.
|
||||
llvm::BumpPtrAllocator &getAllocator();
|
||||
|
||||
MLIRContext *ctx;
|
||||
/// The raw allocator for type storage objects.
|
||||
llvm::BumpPtrAllocator allocator;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -151,38 +147,12 @@ namespace detail {
|
|||
// MLIRContext. This class manages all creation and uniquing of types.
|
||||
class TypeUniquer {
|
||||
public:
|
||||
template <typename T, typename... Args>
|
||||
static T get(MLIRContext *ctx, Args &&... args) {
|
||||
TypeUniquer &instance = ctx->getTypeUniquer();
|
||||
return instance.getImpl<T>(ctx, args...);
|
||||
}
|
||||
|
||||
private:
|
||||
/// A utility wrapper object representing a hashed storage object. This class
|
||||
/// contains a storage object and an existing computed hash value.
|
||||
struct HashedStorageType {
|
||||
unsigned hashValue;
|
||||
TypeStorage *storage;
|
||||
};
|
||||
|
||||
/// A lookup key for derived instances of TypeStorage objects.
|
||||
struct TypeLookupKey {
|
||||
/// The known derived kind for the storage.
|
||||
unsigned kind;
|
||||
|
||||
/// The known hash value of the key.
|
||||
unsigned hashValue;
|
||||
|
||||
/// An equality function for comparing with an existing storage instance.
|
||||
llvm::function_ref<bool(const TypeStorage *)> isEqual;
|
||||
};
|
||||
|
||||
/// Get an uniqued instance of a type T. This overload is used for derived
|
||||
/// types that have complex storage or uniquing constraints.
|
||||
template <typename T, typename... Args>
|
||||
typename std::enable_if<
|
||||
static typename std::enable_if<
|
||||
!std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
|
||||
getImpl(MLIRContext *ctx, unsigned kind, Args &&... args) {
|
||||
get(MLIRContext *ctx, unsigned kind, Args &&... args) {
|
||||
using ImplType = typename T::ImplType;
|
||||
using KeyTy = typename ImplType::KeyTy;
|
||||
|
||||
|
@ -198,41 +168,46 @@ private:
|
|||
return static_cast<const ImplType &>(*existing) == derivedKey;
|
||||
};
|
||||
|
||||
// Look to see if the type has been created already.
|
||||
auto existing =
|
||||
storageTypes.insert_as({}, TypeLookupKey{kind, hashValue, isEqual});
|
||||
// Generate a constructor function for the derived storage.
|
||||
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn =
|
||||
[&](TypeStorageAllocator &allocator) {
|
||||
TypeStorage *storage = ImplType::construct(allocator, derivedKey);
|
||||
storage->initializeTypeInfo(lookupDialectForType<T>(ctx), kind);
|
||||
return storage;
|
||||
};
|
||||
|
||||
// If it has been created, return it.
|
||||
if (!existing.second)
|
||||
return T(existing.first->storage);
|
||||
|
||||
// Otherwise, construct and initialize the derived storage for this type
|
||||
// instance.
|
||||
TypeStorageAllocator allocator(ctx);
|
||||
TypeStorage *storage = ImplType::construct(allocator, derivedKey);
|
||||
storage->initializeTypeInfo(lookupDialectForType<T>(ctx), kind);
|
||||
*existing.first = HashedStorageType{hashValue, storage};
|
||||
return T(storage);
|
||||
// Get an instance for the derived storage.
|
||||
return T(getImpl(ctx, kind, hashValue, isEqual, constructorFn));
|
||||
}
|
||||
|
||||
/// Get an uniqued instance of a type T. This overload is used for derived
|
||||
/// types that use the DefaultTypeStorage and thus need no additional storage
|
||||
/// or uniquing.
|
||||
template <typename T, typename... Args>
|
||||
typename std::enable_if<
|
||||
static typename std::enable_if<
|
||||
std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
|
||||
getImpl(MLIRContext *ctx, unsigned kind) {
|
||||
// Check for an existing instance with this kind.
|
||||
auto *&result = simpleTypes[kind];
|
||||
if (!result) {
|
||||
// Otherwise, allocate and initialize one.
|
||||
TypeStorageAllocator allocator(ctx);
|
||||
result = new (allocator.allocate<DefaultTypeStorage>())
|
||||
get(MLIRContext *ctx, unsigned kind) {
|
||||
auto constructorFn = [=](TypeStorageAllocator &allocator) {
|
||||
return new (allocator.allocate<DefaultTypeStorage>())
|
||||
DefaultTypeStorage(lookupDialectForType<T>(ctx), kind);
|
||||
}
|
||||
return T(result);
|
||||
};
|
||||
return T(getImpl(ctx, kind, constructorFn));
|
||||
}
|
||||
|
||||
private:
|
||||
/// Implementation for getting/creating an instance of a derived type with
|
||||
/// complex storage.
|
||||
static TypeStorage *
|
||||
getImpl(MLIRContext *ctx, unsigned kind, unsigned hashValue,
|
||||
llvm::function_ref<bool(const TypeStorage *)> isEqual,
|
||||
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn);
|
||||
|
||||
/// Implementation for getting/creating an instance of a derived type with
|
||||
/// default storage.
|
||||
static TypeStorage *
|
||||
getImpl(MLIRContext *ctx, unsigned kind,
|
||||
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn);
|
||||
|
||||
/// Get the dialect that the type 'T' was registered with.
|
||||
template <typename T>
|
||||
static const Dialect &lookupDialectForType(MLIRContext *ctx) {
|
||||
|
@ -315,46 +290,6 @@ private:
|
|||
return llvm::hash_combine(
|
||||
kind, llvm::DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Instance Storage
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Storage info for derived TypeStorage objects.
|
||||
struct StorageKeyInfo : DenseMapInfo<HashedStorageType> {
|
||||
static HashedStorageType getEmptyKey() {
|
||||
return HashedStorageType{0, DenseMapInfo<TypeStorage *>::getEmptyKey()};
|
||||
}
|
||||
static HashedStorageType getTombstoneKey() {
|
||||
return HashedStorageType{0,
|
||||
DenseMapInfo<TypeStorage *>::getTombstoneKey()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const HashedStorageType &key) {
|
||||
return key.hashValue;
|
||||
}
|
||||
static unsigned getHashValue(TypeLookupKey key) { return key.hashValue; }
|
||||
|
||||
static bool isEqual(const HashedStorageType &lhs,
|
||||
const HashedStorageType &rhs) {
|
||||
return lhs.storage == rhs.storage;
|
||||
}
|
||||
static bool isEqual(const TypeLookupKey &lhs,
|
||||
const HashedStorageType &rhs) {
|
||||
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
|
||||
return false;
|
||||
// If the lookup kind matches the kind of the storage, then invoke the
|
||||
// equality function on the lookup key.
|
||||
return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
|
||||
}
|
||||
};
|
||||
|
||||
// Unique types with specific hashing or storage constraints.
|
||||
using StorageTypeSet = llvm::DenseSet<HashedStorageType, StorageKeyInfo>;
|
||||
StorageTypeSet storageTypes;
|
||||
|
||||
// Unique types with just the kind.
|
||||
DenseMap<unsigned, TypeStorage *> simpleTypes;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@
|
|||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
#include "llvm/Support/RWMutex.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <memory>
|
||||
|
||||
|
@ -280,6 +281,127 @@ struct FusedLocKeyInfo : DenseMapInfo<FusedLocationStorage *> {
|
|||
return lhs == std::make_pair(rhs->getLocations(), rhs->metadata);
|
||||
}
|
||||
};
|
||||
|
||||
/// This is the implementation of the TypeUniquer class.
|
||||
struct TypeUniquerImpl {
|
||||
/// A lookup key for derived instances of TypeStorage objects.
|
||||
struct TypeLookupKey {
|
||||
/// The known derived kind for the storage.
|
||||
unsigned kind;
|
||||
|
||||
/// The known hash value of the key.
|
||||
unsigned hashValue;
|
||||
|
||||
/// An equality function for comparing with an existing storage instance.
|
||||
llvm::function_ref<bool(const TypeStorage *)> isEqual;
|
||||
};
|
||||
|
||||
/// A utility wrapper object representing a hashed storage object. This class
|
||||
/// contains a storage object and an existing computed hash value.
|
||||
struct HashedStorageType {
|
||||
unsigned hashValue;
|
||||
TypeStorage *storage;
|
||||
};
|
||||
|
||||
/// Get or create an instance of a complex derived type.
|
||||
TypeStorage *getOrCreate(
|
||||
unsigned kind, unsigned hashValue,
|
||||
llvm::function_ref<bool(const TypeStorage *)> isEqual,
|
||||
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
|
||||
TypeLookupKey lookupKey{kind, hashValue, isEqual};
|
||||
|
||||
{ // Check for an existing instance in read-only mode.
|
||||
llvm::sys::SmartScopedReader<true> typeLock(typeMutex);
|
||||
auto it = storageTypes.find_as(lookupKey);
|
||||
if (it != storageTypes.end())
|
||||
return it->storage;
|
||||
}
|
||||
|
||||
// Aquire a writer-lock so that we can safely create the new type instance.
|
||||
llvm::sys::SmartScopedWriter<true> typeLock(typeMutex);
|
||||
|
||||
// Check for an existing instance again here, because another writer thread
|
||||
// may have already created one.
|
||||
auto existing = storageTypes.insert_as({}, lookupKey);
|
||||
if (!existing.second)
|
||||
return existing.first->storage;
|
||||
|
||||
// Otherwise, construct and initialize the derived storage for this type
|
||||
// instance.
|
||||
TypeStorage *storage = constructorFn(allocator);
|
||||
*existing.first = HashedStorageType{hashValue, storage};
|
||||
return storage;
|
||||
}
|
||||
|
||||
/// Get or create an instance of a simple derived type.
|
||||
TypeStorage *getOrCreate(
|
||||
unsigned kind,
|
||||
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
|
||||
{ // Check if the type already exists in read-only mode.
|
||||
llvm::sys::SmartScopedReader<true> typeLock(typeMutex);
|
||||
auto it = simpleTypes.find(kind);
|
||||
if (it != simpleTypes.end())
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Aquire the mutex in write mode so that we can safely construct the new
|
||||
// instance.
|
||||
llvm::sys::SmartScopedWriter<true> typeLock(typeMutex);
|
||||
|
||||
// Check for an existing instance again here, because another writer thread
|
||||
// may have already created one.
|
||||
auto *&result = simpleTypes[kind];
|
||||
if (!result)
|
||||
result = constructorFn(allocator);
|
||||
return result;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Instance Storage
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Storage info for derived TypeStorage objects.
|
||||
struct StorageKeyInfo : DenseMapInfo<HashedStorageType> {
|
||||
static HashedStorageType getEmptyKey() {
|
||||
return HashedStorageType{0, DenseMapInfo<TypeStorage *>::getEmptyKey()};
|
||||
}
|
||||
static HashedStorageType getTombstoneKey() {
|
||||
return HashedStorageType{0,
|
||||
DenseMapInfo<TypeStorage *>::getTombstoneKey()};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const HashedStorageType &key) {
|
||||
return key.hashValue;
|
||||
}
|
||||
static unsigned getHashValue(TypeLookupKey key) { return key.hashValue; }
|
||||
|
||||
static bool isEqual(const HashedStorageType &lhs,
|
||||
const HashedStorageType &rhs) {
|
||||
return lhs.storage == rhs.storage;
|
||||
}
|
||||
static bool isEqual(const TypeLookupKey &lhs,
|
||||
const HashedStorageType &rhs) {
|
||||
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
|
||||
return false;
|
||||
// If the lookup kind matches the kind of the storage, then invoke the
|
||||
// equality function on the lookup key.
|
||||
return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
|
||||
}
|
||||
};
|
||||
|
||||
// Unique types with specific hashing or storage constraints.
|
||||
using StorageTypeSet = llvm::DenseSet<HashedStorageType, StorageKeyInfo>;
|
||||
StorageTypeSet storageTypes;
|
||||
|
||||
// Unique types with just the kind.
|
||||
DenseMap<unsigned, TypeStorage *> simpleTypes;
|
||||
|
||||
// Allocator to use when constructing derived type instances.
|
||||
TypeStorageAllocator allocator;
|
||||
|
||||
// A mutex to keep type uniquing thread-safe.
|
||||
llvm::sys::SmartRWMutex<true> typeMutex;
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
namespace mlir {
|
||||
|
@ -354,7 +476,7 @@ public:
|
|||
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
|
||||
|
||||
/// Type uniquing.
|
||||
TypeUniquer typeUniquer;
|
||||
TypeUniquerImpl typeUniquer;
|
||||
|
||||
// Attribute uniquing.
|
||||
BoolAttributeStorage *boolAttrs[2] = {nullptr};
|
||||
|
@ -701,14 +823,22 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
|
|||
// Type uniquing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Get the type uniquer for this context.
|
||||
TypeUniquer &MLIRContext::getTypeUniquer() const {
|
||||
return getImpl().typeUniquer;
|
||||
/// Implementation for getting/creating an instance of a derived type with
|
||||
/// complex storage.
|
||||
TypeStorage *TypeUniquer::getImpl(
|
||||
MLIRContext *ctx, unsigned kind, unsigned hashValue,
|
||||
llvm::function_ref<bool(const TypeStorage *)> isEqual,
|
||||
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
|
||||
return ctx->getImpl().typeUniquer.getOrCreate(kind, hashValue, isEqual,
|
||||
constructorFn);
|
||||
}
|
||||
|
||||
/// Get a reference to the internal allocator.
|
||||
llvm::BumpPtrAllocator &TypeStorageAllocator::getAllocator() {
|
||||
return ctx->getImpl().allocator;
|
||||
/// Implementation for getting/creating an instance of a derived type with
|
||||
/// default storage.
|
||||
TypeStorage *TypeUniquer::getImpl(
|
||||
MLIRContext *ctx, unsigned kind,
|
||||
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
|
||||
return ctx->getImpl().typeUniquer.getOrCreate(kind, constructorFn);
|
||||
}
|
||||
|
||||
/// Get the dialect that registered the type with the provided typeid.
|
||||
|
|
Loading…
Reference in New Issue