[mlir] Replace usages of Identifier with StringAttr

Identifier and StringAttr essentially serve the same purpose, i.e. to hold a string value. Keeping these seemingly identical pieces of functionality separate has caused problems in certain situations:

* Identifier has nice accessors that StringAttr doesn't
* Identifier can't be used as an Attribute, meaning strings are often duplicated between Identifier/StringAttr (e.g. in PDL)

The only thing that Identifier has that StringAttr doesn't is support for caching a dialect that is referenced by the string (e.g. dialect.foo). This functionality is added to StringAttr, as this is useful for StringAttr in generally the same ways it was useful for Identifier.

Differential Revision: https://reviews.llvm.org/D113536
This commit is contained in:
River Riddle 2021-11-11 01:44:58 +00:00
parent 7f153e8ba1
commit 120591e126
38 changed files with 252 additions and 327 deletions

View File

@ -30,7 +30,7 @@ DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier)
DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr)
DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
DEFINE_C_API_METHODS(MlirType, mlir::Type)

View File

@ -15,6 +15,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StorageUniquerSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/Twine.h"
@ -118,7 +119,7 @@ class alignas(8) AttributeStorage : public StorageUniquer::BaseStorage {
public:
/// Get the type of this attribute.
Type getType() const;
Type getType() const { return type; }
/// Return the abstract descriptor for this attribute.
const AbstractAttribute &getAbstractAttribute() const {
@ -131,24 +132,27 @@ protected:
/// Note: All attributes require a valid type. If no type is provided here,
/// the type of the attribute will automatically default to NoneType
/// upon initialization in the uniquer.
AttributeStorage(Type type);
AttributeStorage();
AttributeStorage(Type type = nullptr) : type(type) {}
/// Set the type of this attribute.
void setType(Type type);
void setType(Type newType) { type = newType; }
// Set the abstract attribute for this storage instance. This is used by the
// AttributeUniquer when initializing a newly constructed storage object.
void initialize(const AbstractAttribute &abstractAttr) {
/// Set the abstract attribute for this storage instance. This is used by the
/// AttributeUniquer when initializing a newly constructed storage object.
void initializeAbstractAttribute(const AbstractAttribute &abstractAttr) {
abstractAttribute = &abstractAttr;
}
/// Default initialization for attribute storage classes that require no
/// additional initialization.
void initialize(MLIRContext *context) {}
private:
/// The type of the attribute value.
Type type;
/// The abstract descriptor for this attribute.
const AbstractAttribute *abstractAttribute;
/// The opaque type of the attribute value.
const void *type;
};
/// Default storage type for attributes that require no additional
@ -188,6 +192,10 @@ public:
return ctx->getAttributeUniquer().get<typename T::ImplType>(
[ctx](AttributeStorage *storage) {
initializeAttributeStorage(storage, ctx, T::getTypeID());
// Execute any additional attribute storage initialization with the
// context.
static_cast<typename T::ImplType *>(storage)->initialize(ctx);
},
T::getTypeID(), std::forward<Args>(args)...);
}

View File

@ -13,7 +13,10 @@
#include "llvm/Support/PointerLikeTypeTraits.h"
namespace mlir {
class Identifier;
class StringAttr;
// TODO: Remove this when all usages have been replaced with StringAttr.
using Identifier = StringAttr;
/// Attributes are known-constant values of operations.
///
@ -61,7 +64,7 @@ public:
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
/// Return the type of this attribute.
Type getType() const;
Type getType() const { return impl->getType(); }
/// Return the context this attribute belongs to.
MLIRContext *getContext() const;
@ -126,7 +129,7 @@ template <typename U> U Attribute::cast() const {
}
inline ::llvm::hash_code hash_value(Attribute arg) {
return ::llvm::hash_value(arg.impl);
return DenseMapInfo<const Attribute::ImplType *>::getHashValue(arg.impl);
}
//===----------------------------------------------------------------------===//

View File

@ -885,7 +885,35 @@ auto SparseElementsAttr::value_begin() const -> iterator<T> {
};
return iterator<T>(llvm::seq<ptrdiff_t>(0, getNumElements()).begin(), mapFn);
}
} // end namespace mlir.
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
/// Define comparisons for StringAttr against nullptr and itself to avoid the
/// StringRef overloads from being chosen when not desirable.
inline bool operator==(StringAttr lhs, std::nullptr_t) { return !lhs; }
inline bool operator!=(StringAttr lhs, std::nullptr_t) {
return static_cast<bool>(lhs);
}
inline bool operator==(StringAttr lhs, StringAttr rhs) {
return (Attribute)lhs == (Attribute)rhs;
}
inline bool operator!=(StringAttr lhs, StringAttr rhs) { return !(lhs == rhs); }
/// Allow direct comparison with StringRef.
inline bool operator==(StringAttr lhs, StringRef rhs) {
return lhs.getValue() == rhs;
}
inline bool operator!=(StringAttr lhs, StringRef rhs) { return !(lhs == rhs); }
inline bool operator==(StringRef lhs, StringAttr rhs) {
return rhs.getValue() == lhs;
}
inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); }
inline Type StringAttr::getType() const { return Attribute::getType(); }
} // end namespace mlir
//===----------------------------------------------------------------------===//
// Attribute Utilities
@ -893,12 +921,30 @@ auto SparseElementsAttr::value_begin() const -> iterator<T> {
namespace llvm {
template <>
struct DenseMapInfo<mlir::StringAttr> : public DenseMapInfo<mlir::Attribute> {
static mlir::StringAttr getEmptyKey() {
const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
return mlir::StringAttr::getFromOpaquePointer(pointer);
}
static mlir::StringAttr getTombstoneKey() {
const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
return mlir::StringAttr::getFromOpaquePointer(pointer);
}
};
template <>
struct PointerLikeTypeTraits<mlir::StringAttr>
: public PointerLikeTypeTraits<mlir::Attribute> {
static inline mlir::StringAttr getFromVoidPointer(void *p) {
return mlir::StringAttr::getFromOpaquePointer(p);
}
};
template <>
struct PointerLikeTypeTraits<mlir::SymbolRefAttr>
: public PointerLikeTypeTraits<mlir::Attribute> {
static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) {
return PointerLikeTypeTraits<mlir::Attribute>::getFromVoidPointer(ptr)
.cast<mlir::SymbolRefAttr>();
return mlir::SymbolRefAttr::getFromOpaquePointer(ptr);
}
};

View File

@ -915,6 +915,44 @@ def Builtin_StringAttr : Builtin_Attr<"String"> {
let extraClassDeclaration = [{
using ValueType = StringRef;
/// If the value of this string is prefixed with a dialect namespace,
/// returns the dialect corresponding to that namespace if it is loaded,
/// nullptr otherwise. For example, the string `llvm.fastmathflags` would
/// return the LLVM dialect, assuming it is loaded in the context.
Dialect *getReferencedDialect() const;
/// Enable conversion to StringRef.
operator StringRef() const { return getValue(); }
/// Returns the underlying string value
StringRef strref() const { return getValue(); }
/// Convert the underling value to an std::string.
std::string str() const { return getValue().str(); }
/// Return a pointer to the start of the string data.
const char *data() const { return getValue().data(); }
/// Return the number of bytes in this string.
size_t size() const { return getValue().size(); }
/// Iterate over the underlying string data.
StringRef::iterator begin() const { return getValue().begin(); }
StringRef::iterator end() const { return getValue().end(); }
/// Compare the underlying string value to the one in `rhs`.
int compare(StringAttr rhs) const {
if (*this == rhs)
return 0;
return getValue().compare(rhs.getValue());
}
/// FIXME: Defined as part of transition of Identifier->StringAttr. Prefer
/// using the other `get` methods instead.
static StringAttr get(const Twine &str, MLIRContext *context) {
return get(context, str);
}
private:
/// Return an empty StringAttr with NoneType type. This is a special variant
/// of the `get` method that is used by the MLIRContext to cache the
@ -923,6 +961,7 @@ def Builtin_StringAttr : Builtin_Attr<"String"> {
friend MLIRContext;
public:
}];
let genStorageClass = 0;
let skipDefaultBuilders = 1;
}

View File

@ -20,11 +20,14 @@ namespace mlir {
class AffineExpr;
class AffineMap;
class FloatType;
class Identifier;
class IndexType;
class IntegerType;
class StringAttr;
class TypeRange;
// TODO: Remove this when all usages have been replaced with StringAttr.
using Identifier = StringAttr;
//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//

View File

@ -24,7 +24,6 @@ class SourceMgr;
namespace mlir {
class DiagnosticEngine;
class Identifier;
struct LogicalResult;
class MLIRContext;
class Operation;
@ -196,6 +195,7 @@ public:
arguments.push_back(DiagnosticArgument(std::forward<Arg>(val)));
return *this;
}
Diagnostic &operator<<(StringAttr val);
/// Stream in a string literal.
Diagnostic &operator<<(const char *val) {
@ -208,9 +208,6 @@ public:
Diagnostic &operator<<(const Twine &val);
Diagnostic &operator<<(Twine &&val);
/// Stream in an Identifier.
Diagnostic &operator<<(Identifier val);
/// Stream in an OperationName.
Diagnostic &operator<<(OperationName val);

View File

@ -612,7 +612,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError(
"arguments may only have dialect attributes");
if (Dialect *dialect = attr.first.getDialect()) {
if (Dialect *dialect = attr.first.getReferencedDialect()) {
if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
/*argIndex=*/i, attr)))
return failure();
@ -645,7 +645,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
for (auto attr : resultAttrs) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError("results may only have dialect attributes");
if (Dialect *dialect = attr.first.getDialect()) {
if (Dialect *dialect = attr.first.getReferencedDialect()) {
if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
/*resultIndex=*/i,
attr)))

View File

@ -9,151 +9,12 @@
#ifndef MLIR_IR_IDENTIFIER_H
#define MLIR_IR_IDENTIFIER_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/StringMapEntry.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include "mlir/IR/BuiltinAttributes.h"
namespace mlir {
class Dialect;
class MLIRContext;
/// This class represents a uniqued string owned by an MLIRContext. Strings
/// represented by this type cannot contain nul characters, and may not have a
/// zero length.
///
/// This is a POD type with pointer size, so it should be passed around by
/// value. The underlying data is owned by MLIRContext and is thus immortal for
/// almost all clients.
///
/// An Identifier may be prefixed with a dialect namespace followed by a single
/// dot `.`. This is particularly useful when used as a key in a NamedAttribute
/// to differentiate a dependent attribute (specific to an operation) from a
/// generic attribute defined by the dialect (in general applicable to multiple
/// operations).
class Identifier {
using EntryType =
llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>>;
public:
/// Return an identifier for the specified string.
static Identifier get(const Twine &string, MLIRContext *context);
Identifier(const Identifier &) = default;
Identifier &operator=(const Identifier &other) = default;
/// Return a StringRef for the string.
StringRef strref() const { return entry->first(); }
/// Identifiers implicitly convert to StringRefs.
operator StringRef() const { return strref(); }
/// Return an std::string.
std::string str() const { return strref().str(); }
/// Return a null terminated C string.
const char *c_str() const { return entry->getKeyData(); }
/// Return a pointer to the start of the string data.
const char *data() const { return entry->getKeyData(); }
/// Return the number of bytes in this string.
unsigned size() const { return entry->getKeyLength(); }
/// Return the dialect loaded in the context for this identifier or nullptr if
/// this identifier isn't prefixed with a loaded dialect. For example the
/// `llvm.fastmathflags` identifier would return the LLVM dialect here,
/// assuming it is loaded in the context.
Dialect *getDialect();
/// Return the current MLIRContext associated with this identifier.
MLIRContext *getContext();
const char *begin() const { return data(); }
const char *end() const { return entry->getKeyData() + size(); }
bool operator==(Identifier other) const { return entry == other.entry; }
bool operator!=(Identifier rhs) const { return !(*this == rhs); }
void print(raw_ostream &os) const;
void dump() const;
const void *getAsOpaquePointer() const {
return static_cast<const void *>(entry);
}
static Identifier getFromOpaquePointer(const void *entry) {
return Identifier(static_cast<const EntryType *>(entry));
}
/// Compare the underlying StringRef.
int compare(Identifier rhs) const { return strref().compare(rhs.strref()); }
private:
/// This contains the bytes of the string, which is guaranteed to be nul
/// terminated.
const EntryType *entry;
explicit Identifier(const EntryType *entry) : entry(entry) {}
};
inline raw_ostream &operator<<(raw_ostream &os, Identifier identifier) {
identifier.print(os);
return os;
}
// Identifier/Identifier equality comparisons are defined inline.
inline bool operator==(Identifier lhs, StringRef rhs) {
return lhs.strref() == rhs;
}
inline bool operator!=(Identifier lhs, StringRef rhs) { return !(lhs == rhs); }
inline bool operator==(StringRef lhs, Identifier rhs) {
return rhs.strref() == lhs;
}
inline bool operator!=(StringRef lhs, Identifier rhs) { return !(lhs == rhs); }
// Make identifiers hashable.
inline llvm::hash_code hash_value(Identifier arg) {
// Identifiers are uniqued, so we can just hash the pointer they contain.
return llvm::hash_value(arg.getAsOpaquePointer());
}
/// NOTICE: Identifier is deprecated and usages of it should be replaced with
/// StringAttr.
using Identifier = StringAttr;
} // end namespace mlir
namespace llvm {
// Identifiers hash just like pointers, there is no need to hash the bytes.
template <>
struct DenseMapInfo<mlir::Identifier> {
static mlir::Identifier getEmptyKey() {
auto pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
return mlir::Identifier::getFromOpaquePointer(pointer);
}
static mlir::Identifier getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
return mlir::Identifier::getFromOpaquePointer(pointer);
}
static unsigned getHashValue(mlir::Identifier val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::Identifier lhs, mlir::Identifier rhs) {
return lhs == rhs;
}
};
/// The pointer inside of an identifier comes from a StringMap, so its alignment
/// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to
/// steal the low bits.
template <>
struct PointerLikeTypeTraits<mlir::Identifier> {
public:
static inline void *getAsVoidPointer(mlir::Identifier i) {
return const_cast<void *>(i.getAsOpaquePointer());
}
static inline mlir::Identifier getFromVoidPointer(void *p) {
return mlir::Identifier::getFromOpaquePointer(p);
}
static constexpr int NumLowBitsAvailable = 2;
};
} // end namespace llvm
#endif

View File

@ -19,7 +19,6 @@
namespace mlir {
class Identifier;
class Location;
class WalkResult;

View File

@ -456,7 +456,7 @@ public:
Dialect *getDialect() const {
if (const auto *abstractOp = getAbstractOperation())
return &abstractOp->dialect;
return representation.get<Identifier>().getDialect();
return representation.get<Identifier>().getReferencedDialect();
}
/// Return the operation name with dialect name stripped, if it has one.

View File

@ -164,8 +164,7 @@ public:
/// Get an instance of the concrete type from a void pointer.
static ConcreteT getFromOpaquePointer(const void *ptr) {
return ptr ? BaseT::getFromOpaquePointer(ptr).template cast<ConcreteT>()
: nullptr;
return ConcreteT((const typename BaseT::ImplType *)ptr);
}
protected:

View File

@ -15,8 +15,6 @@
#include "llvm/ADT/StringMap.h"
namespace mlir {
class Identifier;
class Operation;
/// This class allows for representing and managing the symbol table used by
/// operations with the 'SymbolTable' trait. Inserting into and erasing from

View File

@ -27,12 +27,15 @@ class Any;
namespace mlir {
class AnalysisManager;
class Identifier;
class MLIRContext;
class Operation;
class Pass;
class PassInstrumentation;
class PassInstrumentor;
class StringAttr;
// TODO: Remove this when all usages have been replaced with StringAttr.
using Identifier = StringAttr;
namespace detail {
struct OpPassManagerImpl;

View File

@ -105,8 +105,13 @@ public:
/// Copy the provided string into memory managed by our bump pointer
/// allocator.
StringRef copyInto(StringRef str) {
auto result = copyInto(ArrayRef<char>(str.data(), str.size()));
return StringRef(result.data(), str.size());
if (str.empty())
return StringRef();
char *result = allocator.Allocate<char>(str.size() + 1);
std::uninitialized_copy(str.begin(), str.end(), result);
result[str.size()] = 0;
return StringRef(result, str.size());
}
/// Allocate an instance of the provided type.

View File

@ -82,7 +82,7 @@ public:
amendOperation(Operation *op, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
if (const LLVMTranslationDialectInterface *iface =
getInterfaceFor(attribute.first.getDialect())) {
getInterfaceFor(attribute.first.getReferencedDialect())) {
return iface->amendOperation(op, attribute, moduleTranslation);
}
return success();

View File

@ -1845,7 +1845,8 @@ public:
mlirOperationGetAttribute(operation->get(), index);
return PyNamedAttribute(
namedAttr.attribute,
std::string(mlirIdentifierStr(namedAttr.name).data));
std::string(mlirIdentifierStr(namedAttr.name).data,
mlirIdentifierStr(namedAttr.name).length));
}
void dunderSetItem(const std::string &name, PyAttribute attr) {
@ -2601,7 +2602,8 @@ void mlir::python::populateIRCore(py::module &m) {
PyPrintAccumulator printAccum;
printAccum.parts.append("NamedAttribute(");
printAccum.parts.append(
mlirIdentifierStr(self.namedAttr.name).data);
py::str(mlirIdentifierStr(self.namedAttr.name).data,
mlirIdentifierStr(self.namedAttr.name).length));
printAccum.parts.append("=");
mlirAttributePrint(self.namedAttr.attribute,
printAccum.getCallback(),

View File

@ -186,11 +186,11 @@ bool mlirAttributeIsAString(MlirAttribute attr) {
}
MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str)));
}
MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
return wrap(StringAttr::get(unwrap(str), unwrap(type)));
return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type)));
}
MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {

View File

@ -805,7 +805,7 @@ MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
MlirOperation operation) {
return wrap(unwrap(symbolTable)->insert(unwrap(operation)));
return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation)));
}
void mlirSymbolTableErase(MlirSymbolTable symbolTable,

View File

@ -154,7 +154,7 @@ DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
} else {
auto id = entry.getKey().get<Identifier>();
if (!ids.insert(id).second)
return emitError() << "repeated layout entry key: " << id;
return emitError() << "repeated layout entry key: " << id.getValue();
}
}
return success();
@ -221,7 +221,7 @@ combineOneSpec(DataLayoutSpecInterface spec,
for (const auto &kvp : newEntriesForID) {
Identifier id = kvp.second.getKey().get<Identifier>();
Dialect *dialect = id.getDialect();
Dialect *dialect = id.getReferencedDialect();
if (!entriesForID.count(id)) {
entriesForID[id] = kvp.second;
continue;
@ -377,6 +377,6 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
return success();
}
return op->emitError() << "attribute '" << attr.first
return op->emitError() << "attribute '" << attr.first.getValue()
<< "' not supported by dialect";
}

View File

@ -753,7 +753,7 @@ struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
// Copy over unknown attributes. They might be load bearing for some flow.
ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
for (NamedAttribute kv : genericOp->getAttrs()) {
if (!llvm::is_contained(odsAttrs, kv.first.c_str())) {
if (!llvm::is_contained(odsAttrs, kv.first.getValue())) {
newOp->setAttr(kv.first, kv.second);
}
}

View File

@ -46,10 +46,6 @@
using namespace mlir;
using namespace mlir::detail;
void Identifier::print(raw_ostream &os) const { os << str(); }
void Identifier::dump() const { print(llvm::errs()); }
void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
void OperationName::dump() const { print(llvm::errs()); }
@ -1339,7 +1335,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
})
.Case<FileLineColLoc>([&](FileLineColLoc loc) {
if (pretty) {
os << loc.getFilename();
os << loc.getFilename().getValue();
} else {
os << "\"";
printEscapedString(loc.getFilename(), os);
@ -1693,7 +1689,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
printElidedElementsAttr(os);
} else {
os << "opaque<\"" << opaqueAttr.getDialect() << "\", \"0x"
os << "opaque<" << opaqueAttr.getDialect() << ", \"0x"
<< llvm::toHex(opaqueAttr.getValue()) << "\">";
}

View File

@ -319,6 +319,41 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
ArrayRef<StringRef> data;
};
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
struct StringAttrStorage : public AttributeStorage {
StringAttrStorage(StringRef value, Type type)
: AttributeStorage(type), value(value), referencedDialect(nullptr) {}
/// The hash key is a tuple of the parameter types.
using KeyTy = std::pair<StringRef, Type>;
bool operator==(const KeyTy &key) const {
return value == key.first && getType() == key.second;
}
static ::llvm::hash_code hashKey(const KeyTy &key) {
return DenseMapInfo<KeyTy>::getHashValue(key);
}
/// Define a construction method for creating a new instance of this
/// storage.
static StringAttrStorage *construct(AttributeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<StringAttrStorage>())
StringAttrStorage(allocator.copyInto(key.first), key.second);
}
/// Initialize the storage given an MLIRContext.
void initialize(MLIRContext *context);
/// The raw string value.
StringRef value;
/// If the string value contains a dialect namespace prefix (e.g.
/// dialect.blah), this is the dialect referenced.
Dialect *referencedDialect;
};
} // namespace detail
} // namespace mlir

View File

@ -12,28 +12,10 @@
using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
// AttributeStorage
//===----------------------------------------------------------------------===//
AttributeStorage::AttributeStorage(Type type)
: type(type.getAsOpaquePointer()) {}
AttributeStorage::AttributeStorage() : type(nullptr) {}
Type AttributeStorage::getType() const {
return Type::getFromOpaquePointer(type);
}
void AttributeStorage::setType(Type newType) {
type = newType.getAsOpaquePointer();
}
//===----------------------------------------------------------------------===//
// Attribute
//===----------------------------------------------------------------------===//
/// Return the type of this attribute.
Type Attribute::getType() const { return impl->getType(); }
/// Return the context this attribute belongs to.
MLIRContext *Attribute::getContext() const { return getDialect().getContext(); }
@ -42,13 +24,8 @@ MLIRContext *Attribute::getContext() const { return getDialect().getContext(); }
//===----------------------------------------------------------------------===//
bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) {
return strcmp(lhs.first.data(), rhs.first.data()) < 0;
return lhs.first.compare(rhs.first) < 0;
}
bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) {
// This is correct even when attr.first.data()[name.size()] is not a zero
// string terminator, because we only care about a less than comparison.
// This can't use memcmp, because it doesn't guarantee that it will stop
// reading both buffers if one is shorter than the other, even if there is
// a difference.
return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0;
return lhs.first.getValue().compare(rhs) < 0;
}

View File

@ -264,6 +264,12 @@ StringAttr StringAttr::get(const Twine &twine, Type type) {
return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
}
StringRef StringAttr::getValue() const { return getImpl()->value; }
Dialect *StringAttr::getReferencedDialect() const {
return getImpl()->referencedDialect;
}
//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//
@ -1250,7 +1256,7 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
//===----------------------------------------------------------------------===//
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
Dialect *dialect = getDialect().getDialect();
Dialect *dialect = getContext()->getLoadedDialect(getDialect());
if (!dialect)
return true;
auto *interface =

View File

@ -253,7 +253,7 @@ static LogicalResult verify(ModuleOp op) {
attr.first.strref()))
return op.emitOpError() << "can only contain attributes with "
"dialect-prefixed names, found: '"
<< attr.first << "'";
<< attr.first.getValue() << "'";
}
// Check that there is at most one data layout spec attribute.
@ -266,7 +266,8 @@ static LogicalResult verify(ModuleOp op) {
op.emitOpError() << "expects at most one data layout attribute";
diag.attachNote() << "'" << layoutSpecAttrName
<< "' is a data layout attribute";
diag.attachNote() << "'" << na.first << "' is a data layout attribute";
diag.attachNote() << "'" << na.first.getValue()
<< "' is a data layout attribute";
}
layoutSpecAttrName = na.first.strref();
layoutSpec = spec;

View File

@ -8,7 +8,6 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
@ -109,11 +108,8 @@ Diagnostic &Diagnostic::operator<<(Twine &&val) {
return *this;
}
/// Stream in an Identifier.
Diagnostic &Diagnostic::operator<<(Identifier val) {
// An identifier is stored in the context, so we don't need to worry about the
// lifetime of its data.
arguments.push_back(DiagnosticArgument(val.strref()));
Diagnostic &Diagnostic::operator<<(StringAttr val) {
arguments.push_back(DiagnosticArgument(val));
return *this;
}
@ -469,7 +465,7 @@ void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
// the constructor of SMDiagnostic that takes a location.
std::string locStr;
llvm::raw_string_ostream locOS(locStr);
locOS << fileLoc->getFilename() << ":" << fileLoc->getLine() << ":"
locOS << fileLoc->getFilename().getValue() << ":" << fileLoc->getLine() << ":"
<< fileLoc->getColumn();
llvm::SMDiagnostic diag(locOS.str(), getDiagKind(kind), message.str());
diag.print(nullptr, os);

View File

@ -18,7 +18,6 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpImplementation.h"
@ -33,6 +32,7 @@
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/raw_ostream.h"
@ -227,14 +227,6 @@ public:
/// An action manager for use within the context.
DebugActionManager debugActionManager;
//===--------------------------------------------------------------------===//
// Identifier uniquing
//===--------------------------------------------------------------------===//
// Identifier allocator and mutex for thread safety.
llvm::BumpPtrAllocator identifierAllocator;
llvm::sys::SmartRWMutex<true> identifierMutex;
//===--------------------------------------------------------------------===//
// Diagnostics
//===--------------------------------------------------------------------===//
@ -289,12 +281,6 @@ public:
/// operations.
llvm::StringMap<AbstractOperation> registeredOperations;
/// Identifiers are uniqued by string value and use the internal string set
/// for storage.
llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>,
llvm::BumpPtrAllocator &>
identifiers;
/// An allocator used for AbstractAttribute and AbstractType objects.
llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
@ -349,10 +335,15 @@ public:
DictionaryAttr emptyDictionaryAttr;
StringAttr emptyStringAttr;
/// Map of string attributes that may reference a dialect, that are awaiting
/// that dialect to be loaded.
llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
DenseMap<StringRef, SmallVector<StringAttrStorage *>>
dialectReferencingStrAttrs;
public:
MLIRContextImpl(bool threadingIsEnabled)
: threadingIsEnabled(threadingIsEnabled),
identifiers(identifierAllocator) {
: threadingIsEnabled(threadingIsEnabled) {
if (threadingIsEnabled) {
ownedThreadPool = std::make_unique<llvm::ThreadPool>();
threadPool = ownedThreadPool.get();
@ -541,12 +532,12 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
// Refresh all the identifiers dialect field, this catches cases where a
// dialect may be loaded after identifier prefixed with this dialect name
// were already created.
llvm::SmallString<32> dialectPrefix(dialectNamespace);
dialectPrefix.push_back('.');
for (auto &identifierEntry : impl.identifiers)
if (identifierEntry.second.is<MLIRContext *>() &&
identifierEntry.first().startswith(dialectPrefix))
identifierEntry.second = dialect.get();
auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
for (StringAttrStorage *storage : stringAttrsIt->second)
storage->referencedDialect = dialect.get();
impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
}
// Actually register the interfaces with delayed registration.
impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
@ -784,7 +775,8 @@ void AbstractOperation::insert(
MutableArrayRef<Identifier> cachedAttrNames;
if (!attrNames.empty()) {
cachedAttrNames = MutableArrayRef<Identifier>(
impl.identifierAllocator.Allocate<Identifier>(attrNames.size()),
impl.abstractDialectSymbolAllocator.Allocate<Identifier>(
attrNames.size()),
attrNames.size());
for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
new (&cachedAttrNames[i]) Identifier(Identifier::get(attrNames[i], ctx));
@ -840,63 +832,6 @@ AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
return it->second;
}
//===----------------------------------------------------------------------===//
// Identifier uniquing
//===----------------------------------------------------------------------===//
/// Return an identifier for the specified string.
Identifier Identifier::get(const Twine &string, MLIRContext *context) {
SmallString<32> tempStr;
StringRef str = string.toStringRef(tempStr);
// Check invariants after seeing if we already have something in the
// identifier table - if we already had it in the table, then it already
// passed invariant checks.
assert(!str.empty() && "Cannot create an empty identifier");
assert(!str.contains('\0') &&
"Cannot create an identifier with a nul character");
auto getDialectOrContext = [&]() {
PointerUnion<Dialect *, MLIRContext *> dialectOrContext = context;
auto dialectNamePair = str.split('.');
if (!dialectNamePair.first.empty())
if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
dialectOrContext = dialect;
return dialectOrContext;
};
auto &impl = context->getImpl();
if (!context->isMultithreadingEnabled()) {
auto insertedIt = impl.identifiers.insert({str, nullptr});
if (insertedIt.second)
insertedIt.first->second = getDialectOrContext();
return Identifier(&*insertedIt.first);
}
// Check for an existing identifier in read-only mode.
{
llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
auto it = impl.identifiers.find(str);
if (it != impl.identifiers.end())
return Identifier(&*it);
}
// Acquire a writer-lock so that we can safely create the new instance.
llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
auto it = impl.identifiers.insert({str, getDialectOrContext()}).first;
return Identifier(&*it);
}
Dialect *Identifier::getDialect() {
return entry->second.dyn_cast<Dialect *>();
}
MLIRContext *Identifier::getContext() {
if (Dialect *dialect = getDialect())
return dialect->getContext();
return entry->second.get<MLIRContext *>();
}
//===----------------------------------------------------------------------===//
// Type uniquing
//===----------------------------------------------------------------------===//
@ -995,7 +930,7 @@ StorageUniquer &MLIRContext::getAttributeUniquer() {
void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
MLIRContext *ctx,
TypeID attrID) {
storage->initialize(AbstractAttribute::lookup(attrID, ctx));
storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
// If the attribute did not provide a type, then default to NoneType.
if (!storage->getType())
@ -1019,6 +954,24 @@ DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
return context->getImpl().emptyDictionaryAttr;
}
void StringAttrStorage::initialize(MLIRContext *context) {
// Check for a dialect namespace prefix, if there isn't one we don't need to
// do any additional initialization.
auto dialectNamePair = value.split('.');
if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
return;
// If one exists, we check to see if this dialect is loaded. If it is, we set
// the dialect now, if it isn't we record this storage for initialization
// later if the dialect ever gets loaded.
if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
return;
MLIRContextImpl &impl = context->getImpl();
llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
}
/// Return an empty string.
StringAttr StringAttr::get(MLIRContext *context) {
return context->getImpl().emptyStringAttr;

View File

@ -73,10 +73,10 @@ void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) {
void NamedAttrList::push_back(NamedAttribute newAttribute) {
assert(newAttribute.second && "unexpected null attribute");
if (isSorted())
dictionarySorted.setInt(
attrs.empty() ||
strcmp(attrs.back().first.data(), newAttribute.first.data()) < 0);
if (isSorted()) {
dictionarySorted.setInt(attrs.empty() ||
attrs.back().first.compare(newAttribute.first) < 0);
}
dictionarySorted.setPointer(nullptr);
attrs.push_back(newAttribute);
}

View File

@ -170,7 +170,7 @@ LogicalResult OperationVerifier::verifyOperation(
/// Verify that all of the attributes are okay.
for (auto attr : op.getAttrs()) {
// Check for any optional dialect specific attributes.
if (auto *dialect = attr.first.getDialect())
if (auto *dialect = attr.first.getReferencedDialect())
if (failed(dialect->verifyOperationAttribute(&op, attr)))
return failure();
}

View File

@ -431,7 +431,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
for (const auto &kvp : ids) {
Identifier identifier = kvp.second.getKey().get<Identifier>();
Dialect *dialect = identifier.getDialect();
Dialect *dialect = identifier.getReferencedDialect();
// Ignore attributes that belong to an unknown dialect, the dialect may
// actually implement the relevant interface but we don't know about that.

View File

@ -273,7 +273,7 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return emitError("expected attribute name");
if (!seenKeys.insert(*nameId).second)
return emitError("duplicate key '")
<< *nameId << "' in dictionary attribute";
<< nameId->getValue() << "' in dictionary attribute";
consumeToken();
// Lazy load a dialect in the context if there is a possible namespace.

View File

@ -1127,7 +1127,7 @@ public:
Optional<NamedAttribute> duplicate = opState.attributes.findDuplicate();
if (duplicate)
return emitError(getNameLoc(), "attribute '")
<< duplicate->first
<< duplicate->first.getValue()
<< "' occurs more than once in the attribute list";
return success();
}

View File

@ -822,7 +822,7 @@ CppEmitter::emitOperandsAndAttributes(Operation &op,
auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
if (llvm::is_contained(exclude, attr.first.strref()))
return success();
os << "/* " << attr.first << " */";
os << "/* " << attr.first.getValue() << " */";
if (failed(emitAttribute(op.getLoc(), attr.second)))
return failure();
return success();

View File

@ -221,7 +221,7 @@ private:
if (printAttrs) {
os << "\n";
for (const NamedAttribute &attr : op->getAttrs()) {
os << '\n' << attr.first << ": ";
os << '\n' << attr.first.getValue() << ": ";
emitMlirAttr(os, attr.second);
}
}

View File

@ -494,7 +494,7 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
LinalgTilingOptions()
.setTileSizes({8, 8, 4})
.setLoopType(LinalgTilingLoopType::Loops)
.setDistributionOptions(cyclicNprocsEqNiters),
.setDistributionOptions(cyclicNprocsEqNiters),
LinalgTransformationFilter(
Identifier::get("tensors_distribute1", context),
Identifier::get("tensors_after_distribute1", context)));
@ -508,8 +508,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
MLIRContext *ctx = funcOp.getContext();
SmallVector<RewritePatternSet, 4> stage1Patterns;
if (testMatmulToVectorPatterns1dTiling) {
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
stage1Patterns);
fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
stage1Patterns.emplace_back(
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
@ -519,8 +518,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
.setInterchange({1, 2, 0}),
LinalgTransformationFilter(Identifier::get("START", ctx),
Identifier::get("L2", ctx))));
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
stage1Patterns);
fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
}
{
// Canonicalization patterns

View File

@ -243,7 +243,7 @@ void TestDerivedAttributeDriver::runOnFunction() {
if (!dAttr)
return;
for (auto d : dAttr)
dOp.emitRemark() << d.first << " = " << d.second;
dOp.emitRemark() << d.first.getValue() << " = " << d.second;
});
}

View File

@ -37,8 +37,8 @@ struct TestPrintNestingPass
if (!op->getAttrs().empty()) {
printIndent() << op->getAttrs().size() << " attributes:\n";
for (NamedAttribute attr : op->getAttrs())
printIndent() << " - '" << attr.first << "' : '" << attr.second
<< "'\n";
printIndent() << " - '" << attr.first.getValue() << "' : '"
<< attr.second << "'\n";
}
// Recurse into each of the regions attached to the operation.