Give the TypeUniquer its own BumpPtrAllocator and a SmartRWMutex to make it thread-safe. This is step 1/N to making the MLIRContext thread-safe.

PiperOrigin-RevId: 238037814
This commit is contained in:
River Riddle 2019-03-12 10:00:21 -07:00 committed by jpienaar
parent 739f3ef7ee
commit 7eee76b84c
3 changed files with 174 additions and 116 deletions

View File

@ -29,10 +29,6 @@ class MLIRContextImpl;
class Location;
class Dialect;
namespace detail {
class TypeUniquer;
}
/// MLIRContext is the top-level object for a collection of MLIR modules. It
/// holds immortal uniqued objects like types, and the tables used to unique
/// them.
@ -97,9 +93,6 @@ public:
// MLIRContextImpl type.
MLIRContextImpl &getImpl() const { return *impl.get(); }
/// Get the type uniquer for this context.
detail::TypeUniquer &getTypeUniquer() const;
private:
const std::unique_ptr<MLIRContextImpl> impl;

View File

@ -110,18 +110,16 @@ using DefaultTypeStorage = TypeStorage;
// TypeStorageAllocator
//===----------------------------------------------------------------------===//
// This is a utility allocator used to allocate memory for derived types that
// will be tied to the current MLIRContext.
// This is a utility allocator used to allocate memory for instances of derived
// Types.
class TypeStorageAllocator {
public:
TypeStorageAllocator(MLIRContext *ctx) : ctx(ctx) {}
/// Copy the specified array of elements into memory managed by our bump
/// pointer allocator. This assumes the elements are all PODs.
template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
if (elements.empty())
return llvm::None;
auto result = getAllocator().Allocate<T>(elements.size());
auto result = allocator.Allocate<T>(elements.size());
std::uninitialized_copy(elements.begin(), elements.end(), result);
return ArrayRef<T>(result, elements.size());
}
@ -134,13 +132,11 @@ public:
}
// Allocate an instance of the provided type.
template <typename T> T *allocate() { return getAllocator().Allocate<T>(); }
template <typename T> T *allocate() { return allocator.Allocate<T>(); }
private:
/// Get a reference to the internal allocator.
llvm::BumpPtrAllocator &getAllocator();
MLIRContext *ctx;
/// The raw allocator for type storage objects.
llvm::BumpPtrAllocator allocator;
};
//===----------------------------------------------------------------------===//
@ -151,38 +147,12 @@ namespace detail {
// MLIRContext. This class manages all creation and uniquing of types.
class TypeUniquer {
public:
template <typename T, typename... Args>
static T get(MLIRContext *ctx, Args &&... args) {
TypeUniquer &instance = ctx->getTypeUniquer();
return instance.getImpl<T>(ctx, args...);
}
private:
/// A utility wrapper object representing a hashed storage object. This class
/// contains a storage object and an existing computed hash value.
struct HashedStorageType {
unsigned hashValue;
TypeStorage *storage;
};
/// A lookup key for derived instances of TypeStorage objects.
struct TypeLookupKey {
/// The known derived kind for the storage.
unsigned kind;
/// The known hash value of the key.
unsigned hashValue;
/// An equality function for comparing with an existing storage instance.
llvm::function_ref<bool(const TypeStorage *)> isEqual;
};
/// Get an uniqued instance of a type T. This overload is used for derived
/// types that have complex storage or uniquing constraints.
template <typename T, typename... Args>
typename std::enable_if<
static typename std::enable_if<
!std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
getImpl(MLIRContext *ctx, unsigned kind, Args &&... args) {
get(MLIRContext *ctx, unsigned kind, Args &&... args) {
using ImplType = typename T::ImplType;
using KeyTy = typename ImplType::KeyTy;
@ -198,41 +168,46 @@ private:
return static_cast<const ImplType &>(*existing) == derivedKey;
};
// Look to see if the type has been created already.
auto existing =
storageTypes.insert_as({}, TypeLookupKey{kind, hashValue, isEqual});
// Generate a constructor function for the derived storage.
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn =
[&](TypeStorageAllocator &allocator) {
TypeStorage *storage = ImplType::construct(allocator, derivedKey);
storage->initializeTypeInfo(lookupDialectForType<T>(ctx), kind);
return storage;
};
// If it has been created, return it.
if (!existing.second)
return T(existing.first->storage);
// Otherwise, construct and initialize the derived storage for this type
// instance.
TypeStorageAllocator allocator(ctx);
TypeStorage *storage = ImplType::construct(allocator, derivedKey);
storage->initializeTypeInfo(lookupDialectForType<T>(ctx), kind);
*existing.first = HashedStorageType{hashValue, storage};
return T(storage);
// Get an instance for the derived storage.
return T(getImpl(ctx, kind, hashValue, isEqual, constructorFn));
}
/// Get an uniqued instance of a type T. This overload is used for derived
/// types that use the DefaultTypeStorage and thus need no additional storage
/// or uniquing.
template <typename T, typename... Args>
typename std::enable_if<
static typename std::enable_if<
std::is_same<typename T::ImplType, DefaultTypeStorage>::value, T>::type
getImpl(MLIRContext *ctx, unsigned kind) {
// Check for an existing instance with this kind.
auto *&result = simpleTypes[kind];
if (!result) {
// Otherwise, allocate and initialize one.
TypeStorageAllocator allocator(ctx);
result = new (allocator.allocate<DefaultTypeStorage>())
get(MLIRContext *ctx, unsigned kind) {
auto constructorFn = [=](TypeStorageAllocator &allocator) {
return new (allocator.allocate<DefaultTypeStorage>())
DefaultTypeStorage(lookupDialectForType<T>(ctx), kind);
}
return T(result);
};
return T(getImpl(ctx, kind, constructorFn));
}
private:
/// Implementation for getting/creating an instance of a derived type with
/// complex storage.
static TypeStorage *
getImpl(MLIRContext *ctx, unsigned kind, unsigned hashValue,
llvm::function_ref<bool(const TypeStorage *)> isEqual,
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn);
/// Implementation for getting/creating an instance of a derived type with
/// default storage.
static TypeStorage *
getImpl(MLIRContext *ctx, unsigned kind,
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn);
/// Get the dialect that the type 'T' was registered with.
template <typename T>
static const Dialect &lookupDialectForType(MLIRContext *ctx) {
@ -315,46 +290,6 @@ private:
return llvm::hash_combine(
kind, llvm::DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
}
//===--------------------------------------------------------------------===//
// Instance Storage
//===--------------------------------------------------------------------===//
/// Storage info for derived TypeStorage objects.
struct StorageKeyInfo : DenseMapInfo<HashedStorageType> {
static HashedStorageType getEmptyKey() {
return HashedStorageType{0, DenseMapInfo<TypeStorage *>::getEmptyKey()};
}
static HashedStorageType getTombstoneKey() {
return HashedStorageType{0,
DenseMapInfo<TypeStorage *>::getTombstoneKey()};
}
static unsigned getHashValue(const HashedStorageType &key) {
return key.hashValue;
}
static unsigned getHashValue(TypeLookupKey key) { return key.hashValue; }
static bool isEqual(const HashedStorageType &lhs,
const HashedStorageType &rhs) {
return lhs.storage == rhs.storage;
}
static bool isEqual(const TypeLookupKey &lhs,
const HashedStorageType &rhs) {
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
return false;
// If the lookup kind matches the kind of the storage, then invoke the
// equality function on the lookup key.
return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
}
};
// Unique types with specific hashing or storage constraints.
using StorageTypeSet = llvm::DenseSet<HashedStorageType, StorageKeyInfo>;
StorageTypeSet storageTypes;
// Unique types with just the kind.
DenseMap<unsigned, TypeStorage *> simpleTypes;
};
} // namespace detail

View File

@ -37,6 +37,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
@ -280,6 +281,127 @@ struct FusedLocKeyInfo : DenseMapInfo<FusedLocationStorage *> {
return lhs == std::make_pair(rhs->getLocations(), rhs->metadata);
}
};
/// This is the implementation of the TypeUniquer class.
struct TypeUniquerImpl {
/// A lookup key for derived instances of TypeStorage objects.
struct TypeLookupKey {
/// The known derived kind for the storage.
unsigned kind;
/// The known hash value of the key.
unsigned hashValue;
/// An equality function for comparing with an existing storage instance.
llvm::function_ref<bool(const TypeStorage *)> isEqual;
};
/// A utility wrapper object representing a hashed storage object. This class
/// contains a storage object and an existing computed hash value.
struct HashedStorageType {
unsigned hashValue;
TypeStorage *storage;
};
/// Get or create an instance of a complex derived type.
TypeStorage *getOrCreate(
unsigned kind, unsigned hashValue,
llvm::function_ref<bool(const TypeStorage *)> isEqual,
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
TypeLookupKey lookupKey{kind, hashValue, isEqual};
{ // Check for an existing instance in read-only mode.
llvm::sys::SmartScopedReader<true> typeLock(typeMutex);
auto it = storageTypes.find_as(lookupKey);
if (it != storageTypes.end())
return it->storage;
}
// Aquire a writer-lock so that we can safely create the new type instance.
llvm::sys::SmartScopedWriter<true> typeLock(typeMutex);
// Check for an existing instance again here, because another writer thread
// may have already created one.
auto existing = storageTypes.insert_as({}, lookupKey);
if (!existing.second)
return existing.first->storage;
// Otherwise, construct and initialize the derived storage for this type
// instance.
TypeStorage *storage = constructorFn(allocator);
*existing.first = HashedStorageType{hashValue, storage};
return storage;
}
/// Get or create an instance of a simple derived type.
TypeStorage *getOrCreate(
unsigned kind,
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
{ // Check if the type already exists in read-only mode.
llvm::sys::SmartScopedReader<true> typeLock(typeMutex);
auto it = simpleTypes.find(kind);
if (it != simpleTypes.end())
return it->second;
}
// Aquire the mutex in write mode so that we can safely construct the new
// instance.
llvm::sys::SmartScopedWriter<true> typeLock(typeMutex);
// Check for an existing instance again here, because another writer thread
// may have already created one.
auto *&result = simpleTypes[kind];
if (!result)
result = constructorFn(allocator);
return result;
}
//===--------------------------------------------------------------------===//
// Instance Storage
//===--------------------------------------------------------------------===//
/// Storage info for derived TypeStorage objects.
struct StorageKeyInfo : DenseMapInfo<HashedStorageType> {
static HashedStorageType getEmptyKey() {
return HashedStorageType{0, DenseMapInfo<TypeStorage *>::getEmptyKey()};
}
static HashedStorageType getTombstoneKey() {
return HashedStorageType{0,
DenseMapInfo<TypeStorage *>::getTombstoneKey()};
}
static unsigned getHashValue(const HashedStorageType &key) {
return key.hashValue;
}
static unsigned getHashValue(TypeLookupKey key) { return key.hashValue; }
static bool isEqual(const HashedStorageType &lhs,
const HashedStorageType &rhs) {
return lhs.storage == rhs.storage;
}
static bool isEqual(const TypeLookupKey &lhs,
const HashedStorageType &rhs) {
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
return false;
// If the lookup kind matches the kind of the storage, then invoke the
// equality function on the lookup key.
return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
}
};
// Unique types with specific hashing or storage constraints.
using StorageTypeSet = llvm::DenseSet<HashedStorageType, StorageKeyInfo>;
StorageTypeSet storageTypes;
// Unique types with just the kind.
DenseMap<unsigned, TypeStorage *> simpleTypes;
// Allocator to use when constructing derived type instances.
TypeStorageAllocator allocator;
// A mutex to keep type uniquing thread-safe.
llvm::sys::SmartRWMutex<true> typeMutex;
};
} // end anonymous namespace.
namespace mlir {
@ -354,7 +476,7 @@ public:
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
/// Type uniquing.
TypeUniquer typeUniquer;
TypeUniquerImpl typeUniquer;
// Attribute uniquing.
BoolAttributeStorage *boolAttrs[2] = {nullptr};
@ -701,14 +823,22 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
// Type uniquing
//===----------------------------------------------------------------------===//
/// Get the type uniquer for this context.
TypeUniquer &MLIRContext::getTypeUniquer() const {
return getImpl().typeUniquer;
/// Implementation for getting/creating an instance of a derived type with
/// complex storage.
TypeStorage *TypeUniquer::getImpl(
MLIRContext *ctx, unsigned kind, unsigned hashValue,
llvm::function_ref<bool(const TypeStorage *)> isEqual,
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
return ctx->getImpl().typeUniquer.getOrCreate(kind, hashValue, isEqual,
constructorFn);
}
/// Get a reference to the internal allocator.
llvm::BumpPtrAllocator &TypeStorageAllocator::getAllocator() {
return ctx->getImpl().allocator;
/// Implementation for getting/creating an instance of a derived type with
/// default storage.
TypeStorage *TypeUniquer::getImpl(
MLIRContext *ctx, unsigned kind,
std::function<TypeStorage *(TypeStorageAllocator &)> constructorFn) {
return ctx->getImpl().typeUniquer.getOrCreate(kind, constructorFn);
}
/// Get the dialect that registered the type with the provided typeid.