Introduce a new DialectIdentifier structure, extending Identifier with a Dialect information

This class is looking up a dialect prefix on the identifier on initialization
and keeping a pointer to the Dialect when found.

The NamedAttribute key is now a DialectIdentifier.

Reviewed By: rriddle, jpienaar

Differential Revision: https://reviews.llvm.org/D95418
This commit is contained in:
Mehdi Amini 2021-01-29 00:05:26 +00:00
parent ab2d3ce47d
commit e9dc94291e
2 changed files with 51 additions and 6 deletions

View File

@ -11,10 +11,12 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/StringMapEntry.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
namespace mlir {
class Dialect;
class MLIRContext;
/// This class represents a uniqued string owned by an MLIRContext. Strings
@ -22,10 +24,17 @@ class MLIRContext;
/// zero length.
///
/// This is a POD type with pointer size, so it should be passed around by
/// value. The underlying data is owned by MLIRContext and is thus immortal for
/// value. The underlying data is owned by MLIRContext and is thus immortal for
/// almost all clients.
///
/// An Identifier may be prefixed with a dialect namespace followed by a single
/// dot `.`. This is particularly useful when used as a key in a NamedAttribute
/// to differentiate a dependent attribute (specific to an operation) from a
/// generic attribute defined by the dialect (in general applicable to multiple
/// operations).
class Identifier {
using EntryType = llvm::StringMapEntry<llvm::NoneType>;
using EntryType =
llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>>;
public:
/// Return an identifier for the specified string.
@ -51,6 +60,15 @@ public:
/// Return the number of bytes in this string.
unsigned size() const { return entry->getKeyLength(); }
/// Return the dialect loaded in the context for this identifier or nullptr if
/// this identifier isn't prefixed with a loaded dialect. For example the
/// `llvm.fastmathflags` identifier would return the LLVM dialect here,
/// assuming it is loaded in the context.
Dialect *getDialect();
/// Return the current MLIRContext associated with this identifier.
MLIRContext *getContext();
const char *begin() const { return data(); }
const char *end() const { return entry->getKeyData() + size(); }

View File

@ -264,9 +264,12 @@ public:
/// Identifiers are uniqued by string value and use the internal string set
/// for storage.
llvm::StringSet<llvm::BumpPtrAllocator &> identifiers;
llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>,
llvm::BumpPtrAllocator &>
identifiers;
/// A thread local cache of identifiers to reduce lock contention.
ThreadLocalCache<llvm::StringMap<llvm::StringMapEntry<llvm::NoneType> *>>
ThreadLocalCache<llvm::StringMap<
llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>> *>>
localIdentifierCache;
/// An allocator used for AbstractAttribute and AbstractType objects.
@ -481,6 +484,14 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
#endif
dialect = ctor();
assert(dialect && "dialect ctor failed");
// Refresh all the identifiers dialect field, this catches cases where a
// dialect may be loaded after identifier prefixed with this dialect name
// were already created.
for (auto &identifierEntry : impl.identifiers)
if (identifierEntry.first().startswith(dialectNamespace))
identifierEntry.second = dialect.get();
return dialect.get();
}
@ -707,9 +718,15 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
assert(str.find('\0') == StringRef::npos &&
"Cannot create an identifier with a nul character");
PointerUnion<Dialect *, MLIRContext *> dialectOrContext = context;
auto dialectNamePair = str.split('.');
if (!dialectNamePair.first.empty())
if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
dialectOrContext = dialect;
auto &impl = context->getImpl();
if (!context->isMultithreadingEnabled())
return Identifier(&*impl.identifiers.insert(str).first);
return Identifier(&*impl.identifiers.insert({str, dialectOrContext}).first);
// Check for an existing instance in the local cache.
auto *&localEntry = (*impl.localIdentifierCache)[str];
@ -728,11 +745,21 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
// Acquire a writer-lock so that we can safely create the new instance.
llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
auto it = impl.identifiers.insert(str).first;
auto it = impl.identifiers.insert({str, dialectOrContext}).first;
localEntry = &*it;
return Identifier(localEntry);
}
Dialect *Identifier::getDialect() {
return entry->second.dyn_cast<Dialect *>();
}
MLIRContext *Identifier::getContext() {
if (Dialect *dialect = getDialect())
return dialect->getContext();
return entry->second.get<MLIRContext *>();
}
//===----------------------------------------------------------------------===//
// Type uniquing
//===----------------------------------------------------------------------===//