[mlir] Remove the use of "kinds" from Attributes and Types

This greatly simplifies a large portion of the underlying infrastructure, allows for lookups of singleton classes to be much more efficient and always thread-safe(no locking). As a result of this, the dialect symbol registry has been removed as it is no longer necessary.

For users broken by this change, an alert was sent out(https://llvm.discourse.group/t/removing-kinds-from-attributes-and-types) that helps prevent a majority of the breakage surface area. All that should be necessary, if the advice in that alert was followed, is removing the kind passed to the ::get methods.

Differential Revision: https://reviews.llvm.org/D86121
This commit is contained in:
River Riddle 2020-08-18 15:59:53 -07:00
parent a7d0b7a786
commit 250f43d3ec
42 changed files with 596 additions and 897 deletions

View File

@ -25,17 +25,6 @@ struct RealAttributeStorage;
struct TypeAttributeStorage;
} // namespace detail
enum AttributeKind {
FIR_ATTR = mlir::Attribute::FIRST_FIR_ATTR,
FIR_EXACTTYPE, // instance_of, precise type relation
FIR_SUBCLASS, // subsumed_by, is-a (subclass) relation
FIR_POINT,
FIR_CLOSEDCLOSED_INTERVAL,
FIR_OPENCLOSED_INTERVAL,
FIR_CLOSEDOPEN_INTERVAL,
FIR_REAL_ATTR,
};
class ExactTypeAttr
: public mlir::Attribute::AttrBase<ExactTypeAttr, mlir::Attribute,
detail::TypeAttributeStorage> {
@ -47,8 +36,6 @@ public:
static ExactTypeAttr get(mlir::Type value);
mlir::Type getType() const;
static constexpr unsigned getId() { return AttributeKind::FIR_EXACTTYPE; }
};
class SubclassAttr
@ -62,8 +49,6 @@ public:
static SubclassAttr get(mlir::Type value);
mlir::Type getType() const;
static constexpr unsigned getId() { return AttributeKind::FIR_SUBCLASS; }
};
// Attributes for building SELECT CASE multiway branches
@ -80,9 +65,6 @@ public:
static constexpr llvm::StringRef getAttrName() { return "interval"; }
static ClosedIntervalAttr get(mlir::MLIRContext *ctxt);
static constexpr unsigned getId() {
return AttributeKind::FIR_CLOSEDCLOSED_INTERVAL;
}
};
/// An upper bound is an open interval (including the bound value) as given as
@ -97,9 +79,6 @@ public:
static constexpr llvm::StringRef getAttrName() { return "upper"; }
static UpperBoundAttr get(mlir::MLIRContext *ctxt);
static constexpr unsigned getId() {
return AttributeKind::FIR_OPENCLOSED_INTERVAL;
}
};
/// A lower bound is an open interval (including the bound value) as given as
@ -114,9 +93,6 @@ public:
static constexpr llvm::StringRef getAttrName() { return "lower"; }
static LowerBoundAttr get(mlir::MLIRContext *ctxt);
static constexpr unsigned getId() {
return AttributeKind::FIR_CLOSEDOPEN_INTERVAL;
}
};
/// A pointer interval is a closed interval as given as an ssa-value. The
@ -131,7 +107,6 @@ public:
static constexpr llvm::StringRef getAttrName() { return "point"; }
static PointIntervalAttr get(mlir::MLIRContext *ctxt);
static constexpr unsigned getId() { return AttributeKind::FIR_POINT; }
};
/// A real attribute is used to workaround MLIR's default parsing of a real
@ -150,8 +125,6 @@ public:
int getFKind() const;
llvm::APFloat getValue() const;
static constexpr unsigned getId() { return AttributeKind::FIR_REAL_ATTR; }
};
mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,

View File

@ -54,29 +54,6 @@ struct SequenceTypeStorage;
struct TypeDescTypeStorage;
} // namespace detail
/// Integral identifier for all the types comprising the FIR type system
enum TypeKind {
// The enum starts at the range reserved for this dialect.
FIR_TYPE = mlir::Type::FIRST_FIR_TYPE,
FIR_BOX, // (static) descriptor
FIR_BOXCHAR, // CHARACTER pointer and length
FIR_BOXPROC, // procedure with host association
FIR_CHARACTER, // intrinsic type
FIR_COMPLEX, // intrinsic type
FIR_DERIVED, // derived
FIR_DIMS,
FIR_FIELD,
FIR_HEAP,
FIR_INT, // intrinsic type
FIR_LEN,
FIR_LOGICAL, // intrinsic type
FIR_POINTER, // POINTER attr
FIR_REAL, // intrinsic type
FIR_REFERENCE,
FIR_SEQUENCE, // DIMENSION attr
FIR_TYPEDESC,
};
// These isa_ routines follow the precedent of llvm::isa_or_null<>
/// Is `t` any of the FIR dialect types?
@ -111,12 +88,6 @@ bool isa_aggregate(mlir::Type t);
/// not a memory reference type, then returns a null `Type`.
mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
/// Boilerplate mixin template
template <typename A, unsigned Id>
struct IntrinsicTypeMixin {
static constexpr unsigned getId() { return Id; }
};
// Intrinsic types
/// Model of the Fortran CHARACTER intrinsic type, including the KIND type
@ -124,8 +95,7 @@ struct IntrinsicTypeMixin {
/// is thus the type of a single character value.
class CharacterType
: public mlir::Type::TypeBase<CharacterType, mlir::Type,
detail::CharacterTypeStorage>,
public IntrinsicTypeMixin<CharacterType, TypeKind::FIR_CHARACTER> {
detail::CharacterTypeStorage> {
public:
using Base::Base;
static CharacterType get(mlir::MLIRContext *ctxt, KindTy kind);
@ -136,8 +106,7 @@ public:
/// parameter. COMPLEX is a floating point type with a real and imaginary
/// member.
class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
detail::CplxTypeStorage>,
public IntrinsicTypeMixin<CplxType, TypeKind::FIR_COMPLEX> {
detail::CplxTypeStorage> {
public:
using Base::Base;
static CplxType get(mlir::MLIRContext *ctxt, KindTy kind);
@ -151,8 +120,7 @@ public:
/// Model of a Fortran INTEGER intrinsic type, including the KIND type
/// parameter.
class IntType
: public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage>,
public IntrinsicTypeMixin<IntType, TypeKind::FIR_INT> {
: public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage> {
public:
using Base::Base;
static IntType get(mlir::MLIRContext *ctxt, KindTy kind);
@ -163,8 +131,7 @@ public:
/// parameter.
class LogicalType
: public mlir::Type::TypeBase<LogicalType, mlir::Type,
detail::LogicalTypeStorage>,
public IntrinsicTypeMixin<LogicalType, TypeKind::FIR_LOGICAL> {
detail::LogicalTypeStorage> {
public:
using Base::Base;
static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind);
@ -174,8 +141,7 @@ public:
/// Model of a Fortran REAL (and DOUBLE PRECISION) intrinsic type, including the
/// KIND type parameter.
class RealType : public mlir::Type::TypeBase<RealType, mlir::Type,
detail::RealTypeStorage>,
public IntrinsicTypeMixin<RealType, TypeKind::FIR_REAL> {
detail::RealTypeStorage> {
public:
using Base::Base;
static RealType get(mlir::MLIRContext *ctxt, KindTy kind);
@ -400,7 +366,6 @@ public:
static RecordType get(mlir::MLIRContext *ctxt, llvm::StringRef name);
void finalize(llvm::ArrayRef<TypePair> lenPList,
llvm::ArrayRef<TypePair> typeList);
static constexpr unsigned getId() { return TypeKind::FIR_DERIVED; }
detail::RecordTypeStorage const *uniqueKey() const;

View File

@ -74,13 +74,13 @@ private:
} // namespace detail
ExactTypeAttr ExactTypeAttr::get(mlir::Type value) {
return Base::get(value.getContext(), FIR_EXACTTYPE, value);
return Base::get(value.getContext(), value);
}
mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); }
SubclassAttr SubclassAttr::get(mlir::Type value) {
return Base::get(value.getContext(), FIR_SUBCLASS, value);
return Base::get(value.getContext(), value);
}
mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
@ -88,26 +88,26 @@ mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
using AttributeUniquer = mlir::detail::AttributeUniquer;
ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
return AttributeUniquer::get<ClosedIntervalAttr>(ctxt, getId());
return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
}
UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
return AttributeUniquer::get<UpperBoundAttr>(ctxt, getId());
return AttributeUniquer::get<UpperBoundAttr>(ctxt);
}
LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
return AttributeUniquer::get<LowerBoundAttr>(ctxt, getId());
return AttributeUniquer::get<LowerBoundAttr>(ctxt);
}
PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
return AttributeUniquer::get<PointIntervalAttr>(ctxt, getId());
return AttributeUniquer::get<PointIntervalAttr>(ctxt);
}
// RealAttr
RealAttr RealAttr::get(mlir::MLIRContext *ctxt,
const RealAttr::ValueType &key) {
return Base::get(ctxt, getId(), key);
return Base::get(ctxt, key);
}
int RealAttr::getFKind() const { return getImpl()->getFKind(); }

View File

@ -824,13 +824,11 @@ bool inbounds(A v, B lb, B ub) {
}
bool isa_fir_type(mlir::Type t) {
return inbounds(t.getKind(), mlir::Type::FIRST_FIR_TYPE,
mlir::Type::LAST_FIR_TYPE);
return llvm::isa<FIROpsDialect>(t.getDialect());
}
bool isa_std_type(mlir::Type t) {
return inbounds(t.getKind(), mlir::Type::FIRST_STANDARD_TYPE,
mlir::Type::LAST_STANDARD_TYPE);
return t.getDialect().getNamespace().empty();
}
bool isa_fir_or_std_type(mlir::Type t) {
@ -868,7 +866,7 @@ mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
// CHARACTER
CharacterType fir::CharacterType::get(mlir::MLIRContext *ctxt, KindTy kind) {
return Base::get(ctxt, FIR_CHARACTER, kind);
return Base::get(ctxt, kind);
}
int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
@ -876,7 +874,7 @@ int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
// Dims
DimsType fir::DimsType::get(mlir::MLIRContext *ctxt, unsigned rank) {
return Base::get(ctxt, FIR_DIMS, rank);
return Base::get(ctxt, rank);
}
unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
@ -884,19 +882,19 @@ unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
// Field
FieldType fir::FieldType::get(mlir::MLIRContext *ctxt) {
return Base::get(ctxt, FIR_FIELD, 0);
return Base::get(ctxt, 0);
}
// Len
LenType fir::LenType::get(mlir::MLIRContext *ctxt) {
return Base::get(ctxt, FIR_LEN, 0);
return Base::get(ctxt, 0);
}
// LOGICAL
LogicalType fir::LogicalType::get(mlir::MLIRContext *ctxt, KindTy kind) {
return Base::get(ctxt, FIR_LOGICAL, kind);
return Base::get(ctxt, kind);
}
int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
@ -904,7 +902,7 @@ int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
// INTEGER
IntType fir::IntType::get(mlir::MLIRContext *ctxt, KindTy kind) {
return Base::get(ctxt, FIR_INT, kind);
return Base::get(ctxt, kind);
}
int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
@ -912,7 +910,7 @@ int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
// COMPLEX
CplxType fir::CplxType::get(mlir::MLIRContext *ctxt, KindTy kind) {
return Base::get(ctxt, FIR_COMPLEX, kind);
return Base::get(ctxt, kind);
}
mlir::Type fir::CplxType::getElementType() const {
@ -924,7 +922,7 @@ KindTy fir::CplxType::getFKind() const { return getImpl()->getFKind(); }
// REAL
RealType fir::RealType::get(mlir::MLIRContext *ctxt, KindTy kind) {
return Base::get(ctxt, FIR_REAL, kind);
return Base::get(ctxt, kind);
}
int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
@ -932,7 +930,7 @@ int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
// Box<T>
BoxType fir::BoxType::get(mlir::Type elementType, mlir::AffineMapAttr map) {
return Base::get(elementType.getContext(), FIR_BOX, elementType, map);
return Base::get(elementType.getContext(), elementType, map);
}
mlir::Type fir::BoxType::getEleTy() const {
@ -953,7 +951,7 @@ fir::BoxType::verifyConstructionInvariants(mlir::Location, mlir::Type eleTy,
// BoxChar<C>
BoxCharType fir::BoxCharType::get(mlir::MLIRContext *ctxt, KindTy kind) {
return Base::get(ctxt, FIR_BOXCHAR, kind);
return Base::get(ctxt, kind);
}
CharacterType fir::BoxCharType::getEleTy() const {
@ -963,7 +961,7 @@ CharacterType fir::BoxCharType::getEleTy() const {
// BoxProc<T>
BoxProcType fir::BoxProcType::get(mlir::Type elementType) {
return Base::get(elementType.getContext(), FIR_BOXPROC, elementType);
return Base::get(elementType.getContext(), elementType);
}
mlir::Type fir::BoxProcType::getEleTy() const {
@ -984,7 +982,7 @@ fir::BoxProcType::verifyConstructionInvariants(mlir::Location loc,
// Reference<T>
ReferenceType fir::ReferenceType::get(mlir::Type elementType) {
return Base::get(elementType.getContext(), FIR_REFERENCE, elementType);
return Base::get(elementType.getContext(), elementType);
}
mlir::Type fir::ReferenceType::getEleTy() const {
@ -1005,7 +1003,7 @@ fir::ReferenceType::verifyConstructionInvariants(mlir::Location loc,
PointerType fir::PointerType::get(mlir::Type elementType) {
assert(singleIndirectionLevel(elementType) && "invalid element type");
return Base::get(elementType.getContext(), FIR_POINTER, elementType);
return Base::get(elementType.getContext(), elementType);
}
mlir::Type fir::PointerType::getEleTy() const {
@ -1033,7 +1031,7 @@ fir::PointerType::verifyConstructionInvariants(mlir::Location loc,
HeapType fir::HeapType::get(mlir::Type elementType) {
assert(singleIndirectionLevel(elementType) && "invalid element type");
return Base::get(elementType.getContext(), FIR_HEAP, elementType);
return Base::get(elementType.getContext(), elementType);
}
mlir::Type fir::HeapType::getEleTy() const {
@ -1054,7 +1052,7 @@ fir::HeapType::verifyConstructionInvariants(mlir::Location loc,
SequenceType fir::SequenceType::get(const Shape &shape, mlir::Type elementType,
mlir::AffineMapAttr map) {
auto *ctxt = elementType.getContext();
return Base::get(ctxt, FIR_SEQUENCE, shape, elementType, map);
return Base::get(ctxt, shape, elementType, map);
}
mlir::Type fir::SequenceType::getEleTy() const {
@ -1136,7 +1134,7 @@ llvm::hash_code fir::hash_value(const SequenceType::Shape &sh) {
/// This type captures a Fortran "derived type"
RecordType fir::RecordType::get(mlir::MLIRContext *ctxt, llvm::StringRef name) {
return Base::get(ctxt, FIR_DERIVED, name);
return Base::get(ctxt, name);
}
void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
@ -1179,7 +1177,7 @@ mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
TypeDescType fir::TypeDescType::get(mlir::Type ofType) {
assert(!ofType.isa<ReferenceType>());
return Base::get(ofType.getContext(), FIR_TYPEDESC, ofType);
return Base::get(ofType.getContext(), ofType);
}
mlir::Type fir::TypeDescType::getOfTy() const { return getImpl()->getOfType(); }
@ -1222,9 +1220,7 @@ void fir::verifyIntegralType(mlir::Type type) {
void fir::printFirType(FIROpsDialect *, mlir::Type ty,
mlir::DialectAsmPrinter &p) {
auto &os = p.getStream();
switch (ty.getKind()) {
case fir::FIR_BOX: {
auto type = ty.cast<BoxType>();
if (auto type = ty.dyn_cast<BoxType>()) {
os << "box<";
p.printType(type.getEleTy());
if (auto map = type.getLayoutMap()) {
@ -1232,24 +1228,28 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
p.printAttribute(map);
}
os << '>';
} break;
case fir::FIR_BOXCHAR: {
auto type = ty.cast<BoxCharType>().getEleTy();
os << "boxchar<" << type.cast<fir::CharacterType>().getFKind() << '>';
} break;
case fir::FIR_BOXPROC:
return;
}
if (auto type = ty.dyn_cast<BoxCharType>()) {
os << "boxchar<" << type.getEleTy().cast<fir::CharacterType>().getFKind()
<< '>';
return;
}
if (auto type = ty.dyn_cast<BoxProcType>()) {
os << "boxproc<";
p.printType(ty.cast<BoxProcType>().getEleTy());
p.printType(type.getEleTy());
os << '>';
break;
case fir::FIR_CHARACTER: // intrinsic
os << "char<" << ty.cast<CharacterType>().getFKind() << '>';
break;
case fir::FIR_COMPLEX: // intrinsic
os << "complex<" << ty.cast<CplxType>().getFKind() << '>';
break;
case fir::FIR_DERIVED: { // derived
auto type = ty.cast<fir::RecordType>();
return;
}
if (auto type = ty.dyn_cast<CharacterType>()) {
os << "char<" << type.getFKind() << '>';
return;
}
if (auto type = ty.dyn_cast<CplxType>()) {
os << "complex<" << type.getFKind() << '>';
return;
}
if (auto type = ty.dyn_cast<RecordType>()) {
os << "type<" << type.getName();
if (!recordTypeVisited.count(type.uniqueKey())) {
recordTypeVisited.insert(type.uniqueKey());
@ -1274,43 +1274,52 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
recordTypeVisited.erase(type.uniqueKey());
}
os << '>';
} break;
case fir::FIR_DIMS:
os << "dims<" << ty.cast<DimsType>().getRank() << '>';
break;
case fir::FIR_FIELD:
return;
}
if (auto type = ty.dyn_cast<DimsType>()) {
os << "dims<" << type.getRank() << '>';
return;
}
if (ty.isa<FieldType>()) {
os << "field";
break;
case fir::FIR_HEAP:
return;
}
if (auto type = ty.dyn_cast<HeapType>()) {
os << "heap<";
p.printType(ty.cast<HeapType>().getEleTy());
p.printType(type.getEleTy());
os << '>';
break;
case fir::FIR_INT: // intrinsic
os << "int<" << ty.cast<fir::IntType>().getFKind() << '>';
break;
case fir::FIR_LEN:
return;
}
if (auto type = ty.dyn_cast<fir::IntType>()) {
os << "int<" << type.getFKind() << '>';
return;
}
if (auto type = ty.dyn_cast<LenType>()) {
os << "len";
break;
case fir::FIR_LOGICAL: // intrinsic
os << "logical<" << ty.cast<LogicalType>().getFKind() << '>';
break;
case fir::FIR_POINTER:
return;
}
if (auto type = ty.dyn_cast<LogicalType>()) {
os << "logical<" << type.getFKind() << '>';
return;
}
if (auto type = ty.dyn_cast<PointerType>()) {
os << "ptr<";
p.printType(ty.cast<PointerType>().getEleTy());
p.printType(type.getEleTy());
os << '>';
break;
case fir::FIR_REAL: // intrinsic
os << "real<" << ty.cast<fir::RealType>().getFKind() << '>';
break;
case fir::FIR_REFERENCE:
return;
}
if (auto type = ty.dyn_cast<fir::RealType>()) {
os << "real<" << type.getFKind() << '>';
return;
}
if (auto type = ty.dyn_cast<ReferenceType>()) {
os << "ref<";
p.printType(ty.cast<ReferenceType>().getEleTy());
p.printType(type.getEleTy());
os << '>';
break;
case fir::FIR_SEQUENCE: {
return;
}
if (auto type = ty.dyn_cast<SequenceType>()) {
os << "array";
auto type = ty.cast<SequenceType>();
auto shape = type.getShape();
if (shape.size()) {
printBounds(os, shape);
@ -1323,11 +1332,12 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
map.print(os);
}
os << '>';
} break;
case fir::FIR_TYPEDESC:
return;
}
if (auto type = ty.dyn_cast<TypeDescType>()) {
os << "tdesc<";
p.printType(ty.cast<TypeDescType>().getOfTy());
p.printType(type.getOfTy());
os << '>';
break;
return;
}
}

View File

@ -190,11 +190,10 @@ public:
assert(!elementTypes.empty() && "expected at least 1 element type");
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
// of this type. The first two parameters are the context to unique in and
// the kind of the type. The parameters after the type kind are forwarded to
// the storage instance.
// of this type. The first parameter is the context to unique in. The
// parameters after the type kind are forwarded to the storage instance.
mlir::MLIRContext *ctx = elementTypes.front().getContext();
return Base::get(ctx, ToyTypes::Struct, elementTypes);
return Base::get(ctx, elementTypes);
}
/// Returns the element types of this struct type.

View File

@ -63,13 +63,6 @@ public:
// Toy Types
//===----------------------------------------------------------------------===//
/// Create a local enumeration with all of the types that are defined by Toy.
namespace ToyTypes {
enum Types {
Struct = mlir::Type::FIRST_TOY_TYPE,
};
} // end namespace ToyTypes
/// This class defines the Toy struct type. It represents a collection of
/// element types. All derived types in MLIR must inherit from the CRTP class
/// 'Type::TypeBase'. It takes as template parameters the concrete type

View File

@ -474,11 +474,10 @@ StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
assert(!elementTypes.empty() && "expected at least 1 element type");
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
// of this type. The first two parameters are the context to unique in and the
// kind of the type. The parameters after the type kind are forwarded to the
// storage instance.
// of this type. The first parameter is the context to unique in. The
// parameters after the type kind are forwarded to the storage instance.
mlir::MLIRContext *ctx = elementTypes.front().getContext();
return Base::get(ctx, ToyTypes::Struct, elementTypes);
return Base::get(ctx, elementTypes);
}
/// Returns the element types of this struct type.

View File

@ -64,34 +64,6 @@ class LLVMIntegerType;
/// structs, the entire type is the identifier) and are thread-safe.
class LLVMType : public Type {
public:
enum Kind {
// Keep non-parametric types contiguous in the enum.
VoidType = FIRST_LLVM_TYPE + 1,
HalfType,
BFloatType,
FloatType,
DoubleType,
FP128Type,
X86FP80Type,
PPCFP128Type,
X86MMXType,
LabelType,
TokenType,
MetadataType,
// End of non-parametric types.
FunctionType,
IntegerType,
PointerType,
FixedVectorType,
ScalableVectorType,
ArrayType,
StructType,
FIRST_NEW_LLVM_TYPE = VoidType,
LAST_NEW_LLVM_TYPE = StructType,
FIRST_TRIVIAL_TYPE = VoidType,
LAST_TRIVIAL_TYPE = MetadataType
};
/// Inherit base constructors.
using Type::Type;
@ -256,27 +228,24 @@ public:
//===----------------------------------------------------------------------===//
// Batch-define trivial types.
#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, Kind) \
#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \
class ClassName : public Type::TypeBase<ClassName, LLVMType, TypeStorage> { \
public: \
using Base::Base; \
static ClassName get(MLIRContext *context) { \
return Base::get(context, Kind); \
} \
}
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, LLVMType::VoidType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType, LLVMType::HalfType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType, LLVMType::BFloatType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType, LLVMType::FloatType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType, LLVMType::DoubleType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type, LLVMType::FP128Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type, LLVMType::X86FP80Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, LLVMType::PPCFP128Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, LLVMType::X86MMXType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, LLVMType::TokenType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, LLVMType::LabelType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, LLVMType::MetadataType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
#undef DEFINE_TRIVIAL_LLVM_TYPE

View File

@ -16,11 +16,6 @@ namespace mlir {
class MLIRContext;
namespace linalg {
enum LinalgTypes {
Range = Type::FIRST_LINALG_TYPE,
LAST_USED_LINALG_TYPE = Range,
};
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
/// A RangeType represents a minimal range abstraction (min, max, step).
@ -36,11 +31,6 @@ class RangeType : public Type::TypeBase<RangeType, Type, TypeStorage> {
public:
// Used for generic hooks in TypeBase.
using Base::Base;
/// Construction hook.
static RangeType get(MLIRContext *context) {
/// Custom, uniq'ed construction in the MLIRContext.
return Base::get(context, LinalgTypes::Range);
}
};
} // namespace linalg

View File

@ -31,15 +31,6 @@ struct UniformQuantizedPerAxisTypeStorage;
} // namespace detail
namespace QuantizationTypes {
enum Kind {
Any = Type::FIRST_QUANTIZATION_TYPE,
UniformQuantized,
UniformQuantizedPerAxis,
LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
};
} // namespace QuantizationTypes
/// Enumeration of bit-mapped flags related to quantized types.
namespace QuantizationFlags {
enum FlagValue {

View File

@ -32,15 +32,6 @@ struct TargetEnvAttributeStorage;
struct VerCapExtAttributeStorage;
} // namespace detail
/// SPIR-V dialect-specific attribute kinds.
namespace AttrKind {
enum Kind {
InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI
TargetEnv, /// Target environment
VerCapExt, /// (version, extension, capability) triple
};
} // namespace AttrKind
/// An attribute that specifies the information regarding the interface
/// variable: descriptor set, binding, storage class.
class InterfaceVarABIAttr

View File

@ -65,19 +65,6 @@ struct StructTypeStorage;
} // namespace detail
namespace TypeKind {
enum Kind {
Array = Type::FIRST_SPIRV_TYPE,
CooperativeMatrix,
Image,
Matrix,
Pointer,
RuntimeArray,
Struct,
LAST_SPIRV_TYPE = Struct,
};
}
// Base SPIR-V type for providing availability queries.
class SPIRVType : public Type {
public:

View File

@ -29,56 +29,28 @@ namespace shape {
/// Alias type for extent tensors.
RankedTensorType getExtentTensorType(MLIRContext *ctx);
namespace ShapeTypes {
enum Kind {
Component = Type::FIRST_SHAPE_TYPE,
Element,
Shape,
Size,
ValueShape,
Witness,
LAST_SHAPE_TYPE = Witness
};
} // namespace ShapeTypes
/// The component type corresponding to shape, element type and attribute.
class ComponentType : public Type::TypeBase<ComponentType, Type, TypeStorage> {
public:
using Base::Base;
static ComponentType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Component);
}
};
/// The element type of the shaped type.
class ElementType : public Type::TypeBase<ElementType, Type, TypeStorage> {
public:
using Base::Base;
static ElementType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Element);
}
};
/// The shape descriptor type represents rank and dimension sizes.
class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
public:
using Base::Base;
static ShapeType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Shape);
}
};
/// The type of a single dimension.
class SizeType : public Type::TypeBase<SizeType, Type, TypeStorage> {
public:
using Base::Base;
static SizeType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Size);
}
};
/// The ValueShape represents a (potentially unknown) runtime value and shape.
@ -86,10 +58,6 @@ class ValueShapeType
: public Type::TypeBase<ValueShapeType, Type, TypeStorage> {
public:
using Base::Base;
static ValueShapeType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::ValueShape);
}
};
/// The Witness represents a runtime constraint, to be used as shape related
@ -97,10 +65,6 @@ public:
class WitnessType : public Type::TypeBase<WitnessType, Type, TypeStorage> {
public:
using Base::Base;
static WitnessType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Witness);
}
};
#define GET_OP_CLASSES

View File

@ -137,15 +137,23 @@ namespace detail {
// MLIRContext. This class manages all creation and uniquing of attributes.
class AttributeUniquer {
public:
/// Get an uniqued instance of attribute T.
/// Get an uniqued instance of a parametric attribute T.
template <typename T, typename... Args>
static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
static typename std::enable_if_t<
!std::is_same<typename T::ImplType, AttributeStorage>::value, T>
get(MLIRContext *ctx, Args &&...args) {
return ctx->getAttributeUniquer().get<typename T::ImplType>(
T::getTypeID(),
[ctx](AttributeStorage *storage) {
initializeAttributeStorage(storage, ctx, T::getTypeID());
},
kind, std::forward<Args>(args)...);
T::getTypeID(), std::forward<Args>(args)...);
}
/// Get an uniqued instance of a singleton attribute T.
template <typename T>
static typename std::enable_if_t<
std::is_same<typename T::ImplType, AttributeStorage>::value, T>
get(MLIRContext *ctx) {
return ctx->getAttributeUniquer().get<typename T::ImplType>(T::getTypeID());
}
template <typename T, typename... Args>
@ -156,6 +164,26 @@ public:
std::forward<Args>(args)...);
}
/// Register a parametric attribute instance T with the uniquer.
template <typename T>
static typename std::enable_if_t<
!std::is_same<typename T::ImplType, AttributeStorage>::value>
registerAttribute(MLIRContext *ctx) {
ctx->getAttributeUniquer()
.registerParametricStorageType<typename T::ImplType>(T::getTypeID());
}
/// Register a singleton attribute instance T with the uniquer.
template <typename T>
static typename std::enable_if_t<
std::is_same<typename T::ImplType, AttributeStorage>::value>
registerAttribute(MLIRContext *ctx) {
ctx->getAttributeUniquer()
.registerSingletonStorageType<typename T::ImplType>(
T::getTypeID(), [ctx](AttributeStorage *storage) {
initializeAttributeStorage(storage, ctx, T::getTypeID());
});
}
private:
/// Initialize the given attribute storage instance.
static void initializeAttributeStorage(AttributeStorage *storage,

View File

@ -54,14 +54,6 @@ struct SparseElementsAttributeStorage;
/// passed by value.
class Attribute {
public:
/// Integer identifier for all the concrete attribute kinds.
enum Kind {
// Reserve attribute kinds for dialect specific extensions.
#define DEFINE_SYM_KIND_RANGE(Dialect) \
FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff,
#include "DialectSymbolRegistry.def"
};
/// Utility class for implementing attributes.
template <typename ConcreteType, typename BaseType, typename StorageType,
template <typename T> class... Traits>
@ -94,9 +86,6 @@ public:
// Support dyn_cast'ing Attribute to itself.
static bool classof(Attribute) { return true; }
/// Return the classification for this attribute.
unsigned getKind() const { return impl->getKind(); }
/// Return a unique identifier for the concrete attribute type. This is used
/// to support dynamic type casting.
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
@ -173,54 +162,6 @@ private:
friend InterfaceBase;
};
//===----------------------------------------------------------------------===//
// StandardAttributes
//===----------------------------------------------------------------------===//
namespace StandardAttributes {
enum Kind {
AffineMap = Attribute::FIRST_STANDARD_ATTR,
Array,
Dictionary,
Float,
Integer,
IntegerSet,
Opaque,
String,
SymbolRef,
Type,
Unit,
/// Elements Attributes.
DenseIntOrFPElements,
DenseStringElements,
OpaqueElements,
SparseElements,
FIRST_ELEMENTS_ATTR = DenseIntOrFPElements,
LAST_ELEMENTS_ATTR = SparseElements,
/// Locations.
CallSiteLocation,
FileLineColLocation,
FusedLocation,
NameLocation,
OpaqueLocation,
UnknownLocation,
// Represents a location as a 'void*' pointer to a front-end's opaque
// location information, which must live longer than the MLIR objects that
// refer to it. OpaqueLocation's are never serialized.
//
// TODO: OpaqueLocation,
// Represents a value inlined through a function call.
// TODO: InlinedLocation,
FIRST_LOCATION_ATTR = CallSiteLocation,
LAST_LOCATION_ATTR = UnknownLocation,
};
} // namespace StandardAttributes
//===----------------------------------------------------------------------===//
// AffineMapAttr
//===----------------------------------------------------------------------===//

View File

@ -154,21 +154,15 @@ protected:
void addOperation(AbstractOperation opInfo);
/// This method is used by derived classes to add their types to the set.
/// Register a set of type classes with this dialect.
template <typename... Args> void addTypes() {
(void)std::initializer_list<int>{
0, (addType(Args::getTypeID(), AbstractType::get<Args>(*this)), 0)...};
(void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
}
void addType(TypeID typeID, AbstractType &&typeInfo);
/// This method is used by derived classes to add their attributes to the set.
/// Register a set of attribute classes with this dialect.
template <typename... Args> void addAttributes() {
(void)std::initializer_list<int>{
0,
(addAttribute(Args::getTypeID(), AbstractAttribute::get<Args>(*this)),
0)...};
(void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
}
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
/// Enable support for unregistered operations.
void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
@ -189,6 +183,22 @@ private:
Dialect(const Dialect &) = delete;
void operator=(Dialect &) = delete;
/// Register an attribute instance with this dialect.
template <typename T> void addAttribute() {
// Add this attribute to the dialect and register it with the uniquer.
addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
detail::AttributeUniquer::registerAttribute<T>(context);
}
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
/// Register a type instance with this dialect.
template <typename T> void addType() {
// Add this type to the dialect and register it with the uniquer.
addType(T::getTypeID(), AbstractType::get<T>(*this));
detail::TypeUniquer::registerType<T>(context);
}
void addType(TypeID typeID, AbstractType &&typeInfo);
/// The namespace of this dialect.
StringRef name;

View File

@ -1,44 +0,0 @@
//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file enumerates the different dialects that define custom classes
// within the attribute or type system.
//
//===----------------------------------------------------------------------===//
DEFINE_SYM_KIND_RANGE(STANDARD)
DEFINE_SYM_KIND_RANGE(TENSORFLOW_CONTROL)
DEFINE_SYM_KIND_RANGE(TENSORFLOW_EXECUTOR)
DEFINE_SYM_KIND_RANGE(TENSORFLOW)
DEFINE_SYM_KIND_RANGE(LLVM)
DEFINE_SYM_KIND_RANGE(QUANTIZATION)
DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect
DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect
DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect
DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect
DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect
DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect
DEFINE_SYM_KIND_RANGE(TF_FRAMEWORK) // TF Framework dialect
// The following ranges are reserved for experimenting with MLIR dialects in a
// private context without having to register them here.
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_1)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_2)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_3)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_4)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_5)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_6)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_7)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_8)
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_9)
#undef DEFINE_SYM_KIND_RANGE

View File

@ -756,7 +756,7 @@ public:
/// all attributes of the given kind in the form : <alias>[0-9]+. These
/// aliases must not contain `.`.
virtual void getAttributeKindAliases(
SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) const {}
SmallVectorImpl<std::pair<TypeID, StringRef>> &aliases) const {}
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
/// end with a numeric digit([0-9]+).
virtual void getAttributeAliases(

View File

@ -38,33 +38,6 @@ struct TupleTypeStorage;
} // namespace detail
namespace StandardTypes {
enum Kind {
// Floating point.
BF16 = Type::Kind::FIRST_STANDARD_TYPE,
F16,
F32,
F64,
FIRST_FLOATING_POINT_TYPE = BF16,
LAST_FLOATING_POINT_TYPE = F64,
// Target pointer sized integer, used (e.g.) in affine mappings.
Index,
// Derived types.
Integer,
Vector,
RankedTensor,
UnrankedTensor,
MemRef,
UnrankedMemRef,
Complex,
Tuple,
None,
};
} // namespace StandardTypes
//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

View File

@ -82,29 +82,29 @@ public:
return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
}
protected:
/// Get or create a new ConcreteT instance within the ctx. This
/// function is guaranteed to return a non null object and will assert if
/// the arguments provided are invalid.
template <typename... Args>
static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) {
static ConcreteT get(MLIRContext *ctx, Args... args) {
// Ensure that the invariants are correct for construction.
assert(succeeded(ConcreteT::verifyConstructionInvariants(
generateUnknownStorageLocation(ctx), args...)));
return UniquerT::template get<ConcreteT>(ctx, kind, args...);
return UniquerT::template get<ConcreteT>(ctx, args...);
}
/// Get or create a new ConcreteT instance within the ctx, defined at
/// the given, potentially unknown, location. If the arguments provided are
/// invalid then emit errors and return a null object.
template <typename LocationT, typename... Args>
static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) {
static ConcreteT getChecked(LocationT loc, Args... args) {
// If the construction invariants fail then we return a null attribute.
if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
return ConcreteT();
return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
}
protected:
/// Mutate the current storage instance. This will not change the unique key.
/// The arguments are forwarded to 'ConcreteT::mutate'.
template <typename... Args> LogicalResult mutate(Args &&...args) {

View File

@ -121,15 +121,23 @@ namespace detail {
/// A utility class to get, or create, unique instances of types within an
/// MLIRContext. This class manages all creation and uniquing of types.
struct TypeUniquer {
/// Get an uniqued instance of a type T.
/// Get an uniqued instance of a parametric type T.
template <typename T, typename... Args>
static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
static typename std::enable_if_t<
!std::is_same<typename T::ImplType, TypeStorage>::value, T>
get(MLIRContext *ctx, Args &&...args) {
return ctx->getTypeUniquer().get<typename T::ImplType>(
T::getTypeID(),
[&](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
},
kind, std::forward<Args>(args)...);
T::getTypeID(), std::forward<Args>(args)...);
}
/// Get an uniqued instance of a singleton type T.
template <typename T>
static typename std::enable_if_t<
std::is_same<typename T::ImplType, TypeStorage>::value, T>
get(MLIRContext *ctx) {
return ctx->getTypeUniquer().get<typename T::ImplType>(T::getTypeID());
}
/// Change the mutable component of the given type instance in the provided
@ -141,6 +149,25 @@ struct TypeUniquer {
return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
std::forward<Args>(args)...);
}
/// Register a parametric type instance T with the uniquer.
template <typename T>
static typename std::enable_if_t<
!std::is_same<typename T::ImplType, TypeStorage>::value>
registerType(MLIRContext *ctx) {
ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
T::getTypeID());
}
/// Register a singleton type instance T with the uniquer.
template <typename T>
static typename std::enable_if_t<
std::is_same<typename T::ImplType, TypeStorage>::value>
registerType(MLIRContext *ctx) {
ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
T::getTypeID(), [&](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
});
}
};
} // namespace detail

View File

@ -34,11 +34,11 @@ struct OpaqueTypeStorage;
///
/// Some types are "primitives" meaning they do not have any parameters, for
/// example the Index type. Parametric types have additional information that
/// differentiates the types of the same kind between them, for example the
/// Integer type has bitwidth, making i8 and i16 belong to the same kind by be
/// different instances of the IntegerType. Type parameters are part of the
/// unique immutable key. The mutable component of the type can be modified
/// after the type is created, but cannot affect the identity of the type.
/// differentiates the types of the same class, for example the Integer type has
/// bitwidth, making i8 and i16 belong to the same kind by be different
/// instances of the IntegerType. Type parameters are part of the unique
/// immutable key. The mutable component of the type can be modified after the
/// type is created, but cannot affect the identity of the type.
///
/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
///
@ -53,20 +53,19 @@ struct OpaqueTypeStorage;
/// * This method is expected to return failure if a type cannot be
/// constructed with 'args', success otherwise.
/// * 'args' must correspond with the arguments passed into the
/// 'TypeBase::get' call after the type kind.
/// 'TypeBase::get' call.
///
///
/// Type storage objects inherit from TypeStorage and contain the following:
/// - The type kind (for LLVM-style RTTI).
/// - The dialect that defined the type.
/// - Any parameters of the type.
/// - An optional mutable component.
/// For non-parametric types, a convenience DefaultTypeStorage is provided.
/// Parametric storage types must derive TypeStorage and respect the following:
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
/// instance of the type within its kind.
/// instance of the type.
/// * The key type must be constructible from the values passed into the
/// detail::TypeUniquer::get call after the type kind.
/// detail::TypeUniquer::get call.
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
/// storage class must define a hashing method:
/// 'static unsigned hashKey(const KeyTy &)'
@ -84,23 +83,6 @@ struct OpaqueTypeStorage;
// the key.
class Type {
public:
/// Integer identifier for all the concrete type kinds.
/// Note: This is not an enum class as each dialect will likely define a
/// separate enumeration for the specific types that they define. Not being an
/// enum class also simplifies the handling of type kinds by not requiring
/// casts for each use.
enum Kind {
// Builtin types.
Function,
Opaque,
LAST_BUILTIN_TYPE = Opaque,
// Reserve type kinds for dialect specific type system extensions.
#define DEFINE_SYM_KIND_RANGE(Dialect) \
FIRST_##Dialect##_TYPE, LAST_##Dialect##_TYPE = FIRST_##Dialect##_TYPE + 0xff,
#include "DialectSymbolRegistry.def"
};
/// Utility class for implementing types.
template <typename ConcreteType, typename BaseType, typename StorageType,
template <typename T> class... Traits>
@ -136,9 +118,6 @@ public:
/// dynamic type casting.
TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
/// Return the classification for this type.
unsigned getKind() const;
/// Return the LLVMContext in which this type was uniqued.
MLIRContext *getContext() const;

View File

@ -11,12 +11,11 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Allocator.h"
namespace mlir {
class TypeID;
namespace detail {
struct StorageUniquerImpl;
@ -29,22 +28,19 @@ template <typename ImplTy, typename T>
using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
} // namespace detail
/// A utility class to get, or create instances of storage classes. These
/// storage classes must respect the following constraints:
/// - Derive from StorageUniquer::BaseStorage.
/// - Provide an unsigned 'kind' value to be used as part of the unique'ing
/// process.
/// A utility class to get or create instances of "storage classes". These
/// storage classes must derive from 'StorageUniquer::BaseStorage'.
///
/// For non-parametric storage classes, i.e. those that are solely uniqued by
/// their kind, nothing else is needed. Instances of these classes can be
/// created by calling `get` without trailing arguments.
/// For non-parametric storage classes, i.e. singleton classes, nothing else is
/// needed. Instances of these classes can be created by calling `get` without
/// trailing arguments.
///
/// Otherwise, the parametric storage classes may be created with `get`,
/// and must respect the following:
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
/// instance of the storage class within its kind.
/// instance of the storage class.
/// * The key type must be constructible from the values passed into the
/// getComplex call after the kind.
/// getComplex call.
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
/// storage class must define a hashing method:
/// 'static unsigned hashKey(const KeyTy &)'
@ -83,32 +79,11 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
/// class.
class StorageUniquer {
public:
StorageUniquer();
~StorageUniquer();
/// Set the flag specifying if multi-threading is disabled within the uniquer.
void disableMultithreading(bool disable = true);
/// Register a new storage object with this uniquer using the given unique
/// type id.
void registerStorageType(TypeID id);
/// This class acts as the base storage that all storage classes must derived
/// from.
class BaseStorage {
public:
/// Get the kind classification of this storage.
unsigned getKind() const { return kind; }
protected:
BaseStorage() : kind(0) {}
private:
/// Allow access to the kind field.
friend detail::StorageUniquerImpl;
/// Classification of the subclass, used for type checking.
unsigned kind;
BaseStorage() = default;
};
/// This is a utility allocator used to allocate memory for instances of
@ -145,19 +120,61 @@ public:
llvm::BumpPtrAllocator allocator;
};
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
/// that can be used to initialize a newly inserted storage instance. This
/// function is used for derived types that have complex storage or uniquing
/// constraints.
template <typename Storage, typename Arg, typename... Args>
Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
unsigned kind, Arg &&arg, Args &&...args) {
// Construct a value of the derived key type.
auto derivedKey =
getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
StorageUniquer();
~StorageUniquer();
// Create a hash of the kind and the derived key.
unsigned hashValue = getHash<Storage>(kind, derivedKey);
/// Set the flag specifying if multi-threading is disabled within the uniquer.
void disableMultithreading(bool disable = true);
/// Register a new parametric storage class, this is necessary to create
/// instances of this class type. `id` is the type identifier that will be
/// used to identify this type when creating instances of it via 'get'.
template <typename Storage> void registerParametricStorageType(TypeID id) {
registerParametricStorageTypeImpl(id);
}
/// Utility override when the storage type represents the type id.
template <typename Storage> void registerParametricStorageType() {
registerParametricStorageType<Storage>(TypeID::get<Storage>());
}
/// Register a new singleton storage class, this is necessary to get the
/// singletone instance. `id` is the type identifier that will be used to
/// access the singleton instance via 'get'. An optional initialization
/// function may also be provided to initialize the newly created storage
/// instance, and used when the singleton instance is created.
template <typename Storage>
void registerSingletonStorageType(TypeID id,
function_ref<void(Storage *)> initFn) {
auto ctorFn = [&](StorageAllocator &allocator) {
auto *storage = new (allocator.allocate<Storage>()) Storage();
if (initFn)
initFn(storage);
return storage;
};
registerSingletonImpl(id, ctorFn);
}
template <typename Storage> void registerSingletonStorageType(TypeID id) {
registerSingletonStorageType<Storage>(id, llvm::None);
}
/// Utility override when the storage type represents the type id.
template <typename Storage>
void registerSingletonStorageType(
function_ref<void(Storage *)> initFn = llvm::None) {
registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn);
}
/// Gets a uniqued instance of 'Storage'. 'id' is the type id used when
/// registering the storage instance. 'initFn' is an optional parameter that
/// can be used to initialize a newly inserted storage instance. This function
/// is used for derived types that have complex storage or uniquing
/// constraints.
template <typename Storage, typename... Args>
Storage *get(function_ref<void(Storage *)> initFn, TypeID id,
Args &&...args) {
// Construct a value of the derived key type.
auto derivedKey = getKey<Storage>(std::forward<Args>(args)...);
// Create a hash of the derived key.
unsigned hashValue = getHash<Storage>(derivedKey);
// Generate an equality function for the derived storage.
auto isEqual = [&derivedKey](const BaseStorage *existing) {
@ -174,29 +191,29 @@ public:
// Get an instance for the derived storage.
return static_cast<Storage *>(
getImpl(id, kind, hashValue, isEqual, ctorFn));
getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn));
}
/// Utility override when the storage type represents the type id.
template <typename Storage, typename... Args>
Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) {
return get<Storage>(initFn, TypeID::get<Storage>(),
std::forward<Args>(args)...);
}
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
/// that can be used to initialize a newly inserted storage instance. This
/// function is used for derived types that use no additional storage or
/// uniquing outside of the kind.
template <typename Storage>
Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
unsigned kind) {
auto ctorFn = [&](StorageAllocator &allocator) {
auto *storage = new (allocator.allocate<Storage>()) Storage();
if (initFn)
initFn(storage);
return storage;
};
return static_cast<Storage *>(getImpl(id, kind, ctorFn));
/// Gets a uniqued instance of 'Storage' which is a singleton storage type.
/// 'id' is the type id used when registering the storage instance.
template <typename Storage> Storage *get(TypeID id) {
return static_cast<Storage *>(getSingletonImpl(id));
}
/// Utility override when the storage type represents the type id.
template <typename Storage> Storage *get() {
return get<Storage>(TypeID::get<Storage>());
}
/// Changes the mutable component of 'storage' by forwarding the trailing
/// arguments to the 'mutate' function of the derived class.
template <typename Storage, typename... Args>
LogicalResult mutate(const TypeID &id, Storage *storage, Args &&...args) {
LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) {
auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
return static_cast<Storage &>(*storage).mutate(
allocator, std::forward<Args>(args)...);
@ -207,13 +224,13 @@ public:
/// Erases a uniqued instance of 'Storage'. This function is used for derived
/// types that have complex storage or uniquing constraints.
template <typename Storage, typename Arg, typename... Args>
void erase(const TypeID &id, unsigned kind, Arg &&arg, Args &&...args) {
void erase(TypeID id, Arg &&arg, Args &&...args) {
// Construct a value of the derived key type.
auto derivedKey =
getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
// Create a hash of the kind and the derived key.
unsigned hashValue = getHash<Storage>(kind, derivedKey);
// Create a hash of the derived key.
unsigned hashValue = getHash<Storage>(derivedKey);
// Generate an equality function for the derived storage.
auto isEqual = [&derivedKey](const BaseStorage *existing) {
@ -221,32 +238,42 @@ public:
};
// Attempt to erase the storage instance.
eraseImpl(id, kind, hashValue, isEqual, [](BaseStorage *storage) {
eraseImpl(id, hashValue, isEqual, [](BaseStorage *storage) {
static_cast<Storage *>(storage)->cleanup();
});
}
private:
/// Implementation for getting/creating an instance of a derived type with
/// complex storage.
BaseStorage *getImpl(const TypeID &id, unsigned kind, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
/// parametric storage.
BaseStorage *getParametricStorageTypeImpl(
TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
/// Implementation for getting/creating an instance of a derived type with
/// default storage.
BaseStorage *getImpl(const TypeID &id, unsigned kind,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
/// Implementation for registering an instance of a derived type with
/// parametric storage.
void registerParametricStorageTypeImpl(TypeID id);
/// Implementation for getting an instance of a derived type with default
/// storage.
BaseStorage *getSingletonImpl(TypeID id);
/// Implementation for registering an instance of a derived type with default
/// storage.
void
registerSingletonImpl(TypeID id,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
/// Implementation for erasing an instance of a derived type with complex
/// storage.
void eraseImpl(const TypeID &id, unsigned kind, unsigned hashValue,
void eraseImpl(TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<void(BaseStorage *)> cleanupFn);
/// Implementation for mutating an instance of a derived storage.
LogicalResult
mutateImpl(const TypeID &id,
mutateImpl(TypeID id,
function_ref<LogicalResult(StorageAllocator &)> mutationFn);
/// The internal implementation class.
@ -276,27 +303,26 @@ private:
}
//===--------------------------------------------------------------------===//
// Key and Kind Hashing
// Key Hashing
//===--------------------------------------------------------------------===//
/// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
/// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
/// Used to generate a hash for the 'ImplTy::KeyTy' of a storage instance if
/// there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
template <typename ImplTy, typename DerivedKey>
static typename std::enable_if<
llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
::llvm::hash_code>::type
getHash(unsigned kind, const DerivedKey &derivedKey) {
return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
getHash(const DerivedKey &derivedKey) {
return ImplTy::hashKey(derivedKey);
}
/// If there is no 'ImplTy::hashKey' default to using the
/// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
/// If there is no 'ImplTy::hashKey' default to using the 'llvm::DenseMapInfo'
/// definition for 'DerivedKey' for generating a hash.
template <typename ImplTy, typename DerivedKey>
static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
ImplTy, DerivedKey>::value,
::llvm::hash_code>::type
getHash(unsigned kind, const DerivedKey &derivedKey) {
return llvm::hash_combine(
kind, DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
getHash(const DerivedKey &derivedKey) {
return DenseMapInfo<DerivedKey>::getHashValue(derivedKey);
}
};
} // end namespace mlir

View File

@ -264,14 +264,13 @@ bool LLVMArrayType::isValidElementType(LLVMType type) {
LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::get(elementType.getContext(), LLVMType::ArrayType, elementType,
numElements);
return Base::get(elementType.getContext(), elementType, numElements);
}
LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(loc, LLVMType::ArrayType, elementType, numElements);
return Base::getChecked(loc, elementType, numElements);
}
LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
@ -301,16 +300,14 @@ LLVMFunctionType LLVMFunctionType::get(LLVMType result,
ArrayRef<LLVMType> arguments,
bool isVarArg) {
assert(result && "expected non-null result");
return Base::get(result.getContext(), LLVMType::FunctionType, result,
arguments, isVarArg);
return Base::get(result.getContext(), result, arguments, isVarArg);
}
LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
ArrayRef<LLVMType> arguments,
bool isVarArg) {
assert(result && "expected non-null result");
return Base::getChecked(loc, LLVMType::FunctionType, result, arguments,
isVarArg);
return Base::getChecked(loc, result, arguments, isVarArg);
}
LLVMType LLVMFunctionType::getReturnType() {
@ -347,11 +344,11 @@ LogicalResult LLVMFunctionType::verifyConstructionInvariants(
// Integer type.
LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
return Base::get(ctx, LLVMType::IntegerType, bitwidth);
return Base::get(ctx, bitwidth);
}
LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) {
return Base::getChecked(loc, LLVMType::IntegerType, bitwidth);
return Base::getChecked(loc, bitwidth);
}
unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; }
@ -374,13 +371,12 @@ bool LLVMPointerType::isValidElementType(LLVMType type) {
LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
assert(pointee && "expected non-null subtype");
return Base::get(pointee.getContext(), LLVMType::PointerType, pointee,
addressSpace);
return Base::get(pointee.getContext(), pointee, addressSpace);
}
LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
unsigned addressSpace) {
return Base::getChecked(loc, LLVMType::PointerType, pointee, addressSpace);
return Base::getChecked(loc, pointee, addressSpace);
}
LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
@ -405,32 +401,32 @@ bool LLVMStructType::isValidElementType(LLVMType type) {
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
StringRef name) {
return Base::get(context, LLVMType::StructType, name, /*opaque=*/false);
return Base::get(context, name, /*opaque=*/false);
}
LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
StringRef name) {
return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/false);
return Base::getChecked(loc, name, /*opaque=*/false);
}
LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
ArrayRef<LLVMType> types,
bool isPacked) {
return Base::get(context, LLVMType::StructType, types, isPacked);
return Base::get(context, types, isPacked);
}
LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
ArrayRef<LLVMType> types,
bool isPacked) {
return Base::getChecked(loc, LLVMType::StructType, types, isPacked);
return Base::getChecked(loc, types, isPacked);
}
LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
return Base::get(context, LLVMType::StructType, name, /*opaque=*/true);
return Base::get(context, name, /*opaque=*/true);
}
LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/true);
return Base::getChecked(loc, name, /*opaque=*/true);
}
LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
@ -508,16 +504,14 @@ LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::get(elementType.getContext(), LLVMType::FixedVectorType,
elementType, numElements);
return Base::get(elementType.getContext(), elementType, numElements);
}
LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
LLVMType elementType,
unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(loc, LLVMType::FixedVectorType, elementType,
numElements);
return Base::getChecked(loc, elementType, numElements);
}
unsigned LLVMFixedVectorType::getNumElements() {
@ -527,16 +521,14 @@ unsigned LLVMFixedVectorType::getNumElements() {
LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
unsigned minNumElements) {
assert(elementType && "expected non-null subtype");
return Base::get(elementType.getContext(), LLVMType::ScalableVectorType,
elementType, minNumElements);
return Base::get(elementType.getContext(), elementType, minNumElements);
}
LLVMScalableVectorType
LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
unsigned minNumElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(loc, LLVMType::ScalableVectorType, elementType,
minNumElements);
return Base::getChecked(loc, elementType, minNumElements);
}
unsigned LLVMScalableVectorType::getMinNumElements() {

View File

@ -204,8 +204,8 @@ AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
Type expressedType,
int64_t storageTypeMin,
int64_t storageTypeMax) {
return Base::get(storageType.getContext(), QuantizationTypes::Any, flags,
storageType, expressedType, storageTypeMin, storageTypeMax);
return Base::get(storageType.getContext(), flags, storageType, expressedType,
storageTypeMin, storageTypeMax);
}
AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
@ -213,8 +213,8 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
int64_t storageTypeMin,
int64_t storageTypeMax,
Location location) {
return Base::getChecked(location, QuantizationTypes::Any, flags, storageType,
expressedType, storageTypeMin, storageTypeMax);
return Base::getChecked(location, flags, storageType, expressedType,
storageTypeMin, storageTypeMax);
}
LogicalResult AnyQuantizedType::verifyConstructionInvariants(
@ -240,10 +240,8 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
int64_t zeroPoint,
int64_t storageTypeMin,
int64_t storageTypeMax) {
return Base::get(storageType.getContext(),
QuantizationTypes::UniformQuantized, flags, storageType,
expressedType, scale, zeroPoint, storageTypeMin,
storageTypeMax);
return Base::get(storageType.getContext(), flags, storageType, expressedType,
scale, zeroPoint, storageTypeMin, storageTypeMax);
}
UniformQuantizedType
@ -251,9 +249,8 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax, Location location) {
return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags,
storageType, expressedType, scale, zeroPoint,
storageTypeMin, storageTypeMax);
return Base::getChecked(location, flags, storageType, expressedType, scale,
zeroPoint, storageTypeMin, storageTypeMax);
}
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
@ -295,10 +292,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax) {
return Base::get(storageType.getContext(),
QuantizationTypes::UniformQuantizedPerAxis, flags,
storageType, expressedType, scales, zeroPoints,
quantizedDimension, storageTypeMin, storageTypeMax);
return Base::get(storageType.getContext(), flags, storageType, expressedType,
scales, zeroPoints, quantizedDimension, storageTypeMin,
storageTypeMax);
}
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
@ -306,9 +302,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
Location location) {
return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis,
flags, storageType, expressedType, scales, zeroPoints,
quantizedDimension, storageTypeMin, storageTypeMax);
return Base::getChecked(location, flags, storageType, expressedType, scales,
zeroPoints, quantizedDimension, storageTypeMin,
storageTypeMax);
}
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(

View File

@ -13,11 +13,11 @@ using namespace mlir;
SDBMDialect::SDBMDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
uniquer.registerStorageType(TypeID::get<detail::SDBMBinaryExprStorage>());
uniquer.registerStorageType(TypeID::get<detail::SDBMConstantExprStorage>());
uniquer.registerStorageType(TypeID::get<detail::SDBMDiffExprStorage>());
uniquer.registerStorageType(TypeID::get<detail::SDBMNegExprStorage>());
uniquer.registerStorageType(TypeID::get<detail::SDBMTermExprStorage>());
uniquer.registerParametricStorageType<detail::SDBMBinaryExprStorage>();
uniquer.registerParametricStorageType<detail::SDBMConstantExprStorage>();
uniquer.registerParametricStorageType<detail::SDBMDiffExprStorage>();
uniquer.registerParametricStorageType<detail::SDBMNegExprStorage>();
uniquer.registerParametricStorageType<detail::SDBMTermExprStorage>();
}
SDBMDialect::~SDBMDialect() = default;

View File

@ -246,7 +246,6 @@ SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
return uniquer.get<detail::SDBMBinaryExprStorage>(
TypeID::get<detail::SDBMBinaryExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
}
@ -533,9 +532,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
assert(rhs && "expected SDBM dimension");
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
return uniquer.get<detail::SDBMDiffExprStorage>(
TypeID::get<detail::SDBMDiffExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
}
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
@ -575,7 +572,6 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
StorageUniquer &uniquer = var.getDialect()->getUniquer();
return uniquer.get<detail::SDBMBinaryExprStorage>(
TypeID::get<detail::SDBMBinaryExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
stripeFactor);
}
@ -611,8 +607,7 @@ SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
StorageUniquer &uniquer = dialect->getUniquer();
return uniquer.get<detail::SDBMTermExprStorage>(
TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
static_cast<unsigned>(SDBMExprKind::DimId), position);
assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
}
//===----------------------------------------------------------------------===//
@ -628,8 +623,7 @@ SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
StorageUniquer &uniquer = dialect->getUniquer();
return uniquer.get<detail::SDBMTermExprStorage>(
TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
static_cast<unsigned>(SDBMExprKind::SymbolId), position);
assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
}
//===----------------------------------------------------------------------===//
@ -644,9 +638,7 @@ SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
};
StorageUniquer &uniquer = dialect->getUniquer();
return uniquer.get<detail::SDBMConstantExprStorage>(
TypeID::get<detail::SDBMConstantExprStorage>(), assignCtx,
static_cast<unsigned>(SDBMExprKind::Constant), value);
return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
}
int64_t SDBMConstantExpr::getValue() const {
@ -661,9 +653,7 @@ SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
assert(var && "expected non-null SDBM direct expression");
StorageUniquer &uniquer = var.getDialect()->getUniquer();
return uniquer.get<detail::SDBMNegExprStorage>(
TypeID::get<detail::SDBMNegExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
}
SDBMDirectExpr SDBMNegExpr::getVar() const {

View File

@ -25,27 +25,28 @@ namespace detail {
// Base storage class for SDBMExpr.
struct SDBMExprStorage : public StorageUniquer::BaseStorage {
SDBMExprKind getKind() {
return static_cast<SDBMExprKind>(BaseStorage::getKind());
}
SDBMExprKind getKind() { return kind; }
SDBMDialect *dialect;
SDBMExprKind kind;
};
// Storage class for SDBM sum and stripe expressions.
struct SDBMBinaryExprStorage : public SDBMExprStorage {
using KeyTy = std::pair<SDBMDirectExpr, SDBMConstantExpr>;
using KeyTy = std::tuple<unsigned, SDBMDirectExpr, SDBMConstantExpr>;
bool operator==(const KeyTy &key) const {
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
return static_cast<SDBMExprKind>(std::get<0>(key)) == kind &&
std::get<1>(key) == lhs && std::get<2>(key) == rhs;
}
static SDBMBinaryExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMBinaryExprStorage>();
result->lhs = std::get<0>(key);
result->rhs = std::get<1>(key);
result->lhs = std::get<1>(key);
result->rhs = std::get<2>(key);
result->dialect = result->lhs.getDialect();
result->kind = static_cast<SDBMExprKind>(std::get<0>(key));
return result;
}
@ -67,6 +68,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage {
result->lhs = std::get<0>(key);
result->rhs = std::get<1>(key);
result->dialect = result->lhs.getDialect();
result->kind = SDBMExprKind::Diff;
return result;
}
@ -84,6 +86,7 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMConstantExprStorage>();
result->constant = key;
result->kind = SDBMExprKind::Constant;
return result;
}
@ -92,14 +95,18 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
// Storage class for SDBM dimension and symbol expressions.
struct SDBMTermExprStorage : public SDBMExprStorage {
using KeyTy = unsigned;
using KeyTy = std::pair<unsigned, unsigned>;
bool operator==(const KeyTy &key) const { return position == key; }
bool operator==(const KeyTy &key) const {
return kind == static_cast<SDBMExprKind>(key.first) &&
position == key.second;
}
static SDBMTermExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMTermExprStorage>();
result->position = key;
result->kind = static_cast<SDBMExprKind>(key.first);
result->position = key.second;
return result;
}
@ -117,6 +124,7 @@ struct SDBMNegExprStorage : public SDBMExprStorage {
auto *result = allocator.allocate<SDBMNegExprStorage>();
result->expr = key;
result->dialect = key.getDialect();
result->kind = SDBMExprKind::Neg;
return result;
}

View File

@ -120,8 +120,7 @@ spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass) {
assert(descriptorSet && binding);
MLIRContext *context = descriptorSet.getContext();
return Base::get(context, spirv::AttrKind::InterfaceVarABI, descriptorSet,
binding, storageClass);
return Base::get(context, descriptorSet, binding, storageClass);
}
StringRef spirv::InterfaceVarABIAttr::getKindName() {
@ -195,8 +194,7 @@ spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
ArrayAttr extensions) {
assert(version && capabilities && extensions);
MLIRContext *context = version.getContext();
return Base::get(context, spirv::AttrKind::VerCapExt, version, capabilities,
extensions);
return Base::get(context, version, capabilities, extensions);
}
StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
@ -272,7 +270,7 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
DictionaryAttr limits) {
assert(triple && limits && "expected valid triple and limits");
MLIRContext *context = triple.getContext();
return Base::get(context, spirv::AttrKind::TargetEnv, triple, limits);
return Base::get(context, triple, limits);
}
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }

View File

@ -124,15 +124,14 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
assert(elementCount && "ArrayType needs at least one element");
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
elementCount, /*stride=*/0);
return Base::get(elementType.getContext(), elementType, elementCount,
/*stride=*/0);
}
ArrayType ArrayType::get(Type elementType, unsigned elementCount,
unsigned stride) {
assert(elementCount && "ArrayType needs at least one element");
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
elementCount, stride);
return Base::get(elementType.getContext(), elementType, elementCount, stride);
}
unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
@ -285,8 +284,7 @@ struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
Scope scope, unsigned rows,
unsigned columns) {
return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix,
elementType, scope, rows, columns);
return Base::get(elementType.getContext(), elementType, scope, rows, columns);
}
Type CooperativeMatrixNVType::getElementType() const {
@ -389,7 +387,7 @@ ImageType
ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
value) {
return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value);
return Base::get(std::get<0>(value).getContext(), value);
}
Type ImageType::getElementType() const { return getImpl()->elementType; }
@ -453,8 +451,7 @@ struct spirv::detail::PointerTypeStorage : public TypeStorage {
};
PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
return Base::get(pointeeType.getContext(), TypeKind::Pointer, pointeeType,
storageClass);
return Base::get(pointeeType.getContext(), pointeeType, storageClass);
}
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
@ -511,13 +508,11 @@ struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
};
RuntimeArrayType RuntimeArrayType::get(Type elementType) {
return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
elementType, /*stride=*/0);
return Base::get(elementType.getContext(), elementType, /*stride=*/0);
}
RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
elementType, stride);
return Base::get(elementType.getContext(), elementType, stride);
}
Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
@ -846,12 +841,12 @@ StructType::get(ArrayRef<Type> memberTypes,
SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
memberDecorations.begin(), memberDecorations.end());
llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
memberTypes, offsetInfo, sortedDecorations);
return Base::get(memberTypes.vec().front().getContext(), memberTypes,
offsetInfo, sortedDecorations);
}
StructType StructType::getEmpty(MLIRContext *context) {
return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
return Base::get(context, ArrayRef<Type>(),
ArrayRef<StructType::OffsetInfo>(),
ArrayRef<StructType::MemberDecorationInfo>());
}
@ -946,13 +941,12 @@ struct spirv::detail::MatrixTypeStorage : public TypeStorage {
};
MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
return Base::get(columnType.getContext(), TypeKind::Matrix, columnType,
columnCount);
return Base::get(columnType.getContext(), columnType, columnCount);
}
MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
Location location) {
return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount);
return Base::getChecked(location, columnType, columnCount);
}
LogicalResult MatrixType::verifyConstructionInvariants(Location loc,

View File

@ -20,9 +20,7 @@ using namespace mlir::detail;
MLIRContext *AffineExpr::getContext() const { return expr->context; }
AffineExprKind AffineExpr::getKind() const {
return static_cast<AffineExprKind>(expr->getKind());
}
AffineExprKind AffineExpr::getKind() const { return expr->kind; }
/// Walk all of the AffineExprs in this subgraph in postorder.
void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
@ -449,8 +447,7 @@ static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
StorageUniquer &uniquer = context->getAffineUniquer();
return uniquer.get<AffineDimExprStorage>(
TypeID::get<AffineDimExprStorage>(), assignCtx,
static_cast<unsigned>(kind), position);
assignCtx, static_cast<unsigned>(kind), position);
}
AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
@ -484,9 +481,7 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
};
StorageUniquer &uniquer = context->getAffineUniquer();
return uniquer.get<AffineConstantExprStorage>(
TypeID::get<AffineConstantExprStorage>(), assignCtx,
static_cast<unsigned>(AffineExprKind::Constant), constant);
return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
}
/// Simplify add expression. Return nullptr if it can't be simplified.
@ -594,7 +589,6 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
}
@ -655,7 +649,6 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
}
@ -722,7 +715,6 @@ AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
other);
}
@ -766,7 +758,6 @@ AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
other);
}
@ -814,7 +805,6 @@ AffineExpr AffineExpr::operator%(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
}

View File

@ -27,21 +27,24 @@ namespace detail {
/// Base storage class appearing in an affine expression.
struct AffineExprStorage : public StorageUniquer::BaseStorage {
MLIRContext *context;
AffineExprKind kind;
};
/// A binary operation appearing in an affine expression.
struct AffineBinaryOpExprStorage : public AffineExprStorage {
using KeyTy = std::pair<AffineExpr, AffineExpr>;
using KeyTy = std::tuple<unsigned, AffineExpr, AffineExpr>;
bool operator==(const KeyTy &key) const {
return key.first == lhs && key.second == rhs;
return static_cast<AffineExprKind>(std::get<0>(key)) == kind &&
std::get<1>(key) == lhs && std::get<2>(key) == rhs;
}
static AffineBinaryOpExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<AffineBinaryOpExprStorage>();
result->lhs = key.first;
result->rhs = key.second;
result->kind = static_cast<AffineExprKind>(std::get<0>(key));
result->lhs = std::get<1>(key);
result->rhs = std::get<2>(key);
result->context = result->lhs.getContext();
return result;
}
@ -52,14 +55,18 @@ struct AffineBinaryOpExprStorage : public AffineExprStorage {
/// A dimensional or symbolic identifier appearing in an affine expression.
struct AffineDimExprStorage : public AffineExprStorage {
using KeyTy = unsigned;
using KeyTy = std::pair<unsigned, unsigned>;
bool operator==(const KeyTy &key) const { return position == key; }
bool operator==(const KeyTy &key) const {
return kind == static_cast<AffineExprKind>(key.first) &&
position == key.second;
}
static AffineDimExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<AffineDimExprStorage>();
result->position = key;
result->kind = static_cast<AffineExprKind>(key.first);
result->position = key.second;
return result;
}
@ -76,6 +83,7 @@ struct AffineConstantExprStorage : public AffineExprStorage {
static AffineConstantExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<AffineConstantExprStorage>();
result->kind = AffineExprKind::Constant;
result->constant = key;
return result;
}

View File

@ -271,7 +271,7 @@ private:
/// Mapping between attribute kind and a pair comprised of a base alias name
/// and a unique list of attributes belonging to this kind sorted by location
/// seen in the module.
llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
llvm::MapVector<TypeID, std::pair<StringRef, std::vector<Attribute>>>
attrKindToAlias;
/// Set of types known to be used within the module.
@ -301,13 +301,13 @@ void AliasState::initialize(
llvm::StringSet<> usedAliases;
// Collect the set of aliases from each dialect.
SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
SmallVector<std::pair<TypeID, StringRef>, 8> attributeKindAliases;
SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
// AffineMap/Integer set have specific kind aliases.
attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
attributeKindAliases.emplace_back(AffineMapAttr::getTypeID(), "map");
attributeKindAliases.emplace_back(IntegerSetAttr::getTypeID(), "set");
for (auto &interface : interfaces) {
interface.getAttributeKindAliases(attributeKindAliases);
@ -317,7 +317,7 @@ void AliasState::initialize(
// Setup the attribute kind aliases.
StringRef alias;
unsigned attrKind;
TypeID attrKind;
for (auto &attrAliasPair : attributeKindAliases) {
std::tie(attrKind, alias) = attrAliasPair;
assert(!alias.empty() && "expected non-empty alias string");
@ -420,7 +420,7 @@ void AliasState::recordAttributeReference(Attribute attr) {
return;
// If this attribute kind has an alias, then record one for this attribute.
auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
auto alias = attrKindToAlias.find(attr.getTypeID());
if (alias == attrKindToAlias.end())
return;
std::pair<StringRef, int> attrAlias(alias->second.first,

View File

@ -57,7 +57,7 @@ Dialect &Attribute::getDialect() const {
//===----------------------------------------------------------------------===//
AffineMapAttr AffineMapAttr::get(AffineMap value) {
return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
return Base::get(value.getContext(), value);
}
AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
@ -67,7 +67,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
return Base::get(context, StandardAttributes::Array, value);
return Base::get(context, value);
}
ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
@ -156,7 +156,7 @@ DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
value = storage;
return Base::get(context, StandardAttributes::Dictionary, value);
return Base::get(context, value);
}
/// Construct a dictionary with an array of values that is known to already be
/// sorted by name and uniqued.
@ -175,7 +175,7 @@ DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
return l.first == r.first;
}) == value.end() &&
"DictionaryAttr element names must be unique");
return Base::get(context, StandardAttributes::Dictionary, value);
return Base::get(context, value);
}
ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
@ -219,19 +219,19 @@ size_t DictionaryAttr::size() const { return getValue().size(); }
//===----------------------------------------------------------------------===//
FloatAttr FloatAttr::get(Type type, double value) {
return Base::get(type.getContext(), StandardAttributes::Float, type, value);
return Base::get(type.getContext(), type, value);
}
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
return Base::getChecked(loc, StandardAttributes::Float, type, value);
return Base::getChecked(loc, type, value);
}
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
return Base::get(type.getContext(), StandardAttributes::Float, type, value);
return Base::get(type.getContext(), type, value);
}
FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
return Base::getChecked(loc, StandardAttributes::Float, type, value);
return Base::getChecked(loc, type, value);
}
APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
@ -279,14 +279,13 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
//===----------------------------------------------------------------------===//
FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
.cast<FlatSymbolRefAttr>();
return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
}
SymbolRefAttr SymbolRefAttr::get(StringRef value,
ArrayRef<FlatSymbolRefAttr> nestedReferences,
MLIRContext *ctx) {
return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
return Base::get(ctx, value, nestedReferences);
}
StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
@ -307,7 +306,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
if (type.isSignlessInteger(1))
return BoolAttr::get(value.getBoolValue(), type.getContext());
return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
return Base::get(type.getContext(), type, value);
}
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
@ -380,8 +379,7 @@ bool BoolAttr::classof(Attribute attr) {
//===----------------------------------------------------------------------===//
IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
return Base::get(value.getConstraint(0).getContext(),
StandardAttributes::IntegerSet, value);
return Base::get(value.getConstraint(0).getContext(), value);
}
IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
@ -392,14 +390,12 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
MLIRContext *context) {
return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
type);
return Base::get(context, dialect, attrData, type);
}
OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
Type type, Location location) {
return Base::getChecked(location, StandardAttributes::Opaque, dialect,
attrData, type);
return Base::getChecked(location, dialect, attrData, type);
}
/// Returns the dialect namespace of the opaque attribute.
@ -430,7 +426,7 @@ StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
/// Get an instance of a StringAttr with the given string and Type.
StringAttr StringAttr::get(StringRef bytes, Type type) {
return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
return Base::get(type.getContext(), bytes, type);
}
StringRef StringAttr::getValue() const { return getImpl()->value; }
@ -440,7 +436,7 @@ StringRef StringAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
TypeAttr TypeAttr::get(Type value) {
return Base::get(value.getContext(), StandardAttributes::Type, value);
return Base::get(value.getContext(), value);
}
Type TypeAttr::getValue() const { return getImpl()->value; }
@ -1036,8 +1032,7 @@ DenseElementsAttr DenseElementsAttr::mapValues(
DenseStringElementsAttr
DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
type, values, (values.size() == 1));
return Base::get(type.getContext(), type, values, (values.size() == 1));
}
//===----------------------------------------------------------------------===//
@ -1088,8 +1083,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
type, data, isSplat);
return Base::get(type.getContext(), type, data, isSplat);
}
/// Overload of the raw 'get' method that asserts that the given type is of
@ -1210,8 +1204,7 @@ OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
StringRef bytes) {
assert(TensorType::isValidElementType(type.getElementType()) &&
"Input element type should be a valid tensor element type");
return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
dialect, bytes);
return Base::get(type.getContext(), type, dialect, bytes);
}
StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
@ -1248,7 +1241,7 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
return Base::get(type.getContext(), type,
indices.cast<DenseIntElementsAttr>(), values);
}

View File

@ -28,8 +28,7 @@ bool LocationAttr::classof(Attribute attr) {
//===----------------------------------------------------------------------===//
Location CallSiteLoc::get(Location callee, Location caller) {
return Base::get(callee->getContext(), StandardAttributes::CallSiteLocation,
callee, caller);
return Base::get(callee->getContext(), callee, caller);
}
Location CallSiteLoc::get(Location name, ArrayRef<Location> frames) {
@ -50,8 +49,7 @@ Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
Location FileLineColLoc::get(Identifier filename, unsigned line,
unsigned column, MLIRContext *context) {
return Base::get(context, StandardAttributes::FileLineColLocation, filename,
line, column);
return Base::get(context, filename, line, column);
}
Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
@ -95,7 +93,7 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
return UnknownLoc::get(context);
if (locs.size() == 1)
return locs.front();
return Base::get(context, StandardAttributes::FusedLocation, locs, metadata);
return Base::get(context, locs, metadata);
}
ArrayRef<Location> FusedLoc::getLocations() const {
@ -111,8 +109,7 @@ Attribute FusedLoc::getMetadata() const { return getImpl()->metadata; }
Location NameLoc::get(Identifier name, Location child) {
assert(!child.isa<NameLoc>() &&
"a NameLoc cannot be used as a child of another NameLoc");
return Base::get(child->getContext(), StandardAttributes::NameLocation, name,
child);
return Base::get(child->getContext(), name, child);
}
Location NameLoc::get(Identifier name, MLIRContext *context) {
@ -131,9 +128,8 @@ Location NameLoc::getChildLoc() const { return getImpl()->child; }
Location OpaqueLoc::get(uintptr_t underlyingLocation, TypeID typeID,
Location fallbackLocation) {
return Base::get(fallbackLocation->getContext(),
StandardAttributes::OpaqueLocation, underlyingLocation,
typeID, fallbackLocation);
return Base::get(fallbackLocation->getContext(), underlyingLocation, typeID,
fallbackLocation);
}
uintptr_t OpaqueLoc::getUnderlyingLocation() const {

View File

@ -87,6 +87,10 @@ namespace {
struct BuiltinDialect : public Dialect {
BuiltinDialect(MLIRContext *context)
: Dialect(/*name=*/"", context, TypeID::get<BuiltinDialect>()) {
addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
FunctionType, IndexType, IntegerType, MemRefType,
UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
TupleType, UnrankedTensorType, VectorType>();
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
@ -95,11 +99,6 @@ struct BuiltinDialect : public Dialect {
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
UnknownLoc>();
addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
FunctionType, IndexType, IntegerType, MemRefType,
UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
TupleType, UnrankedTensorType, VectorType>();
// TODO: These operations should be moved to a different dialect when they
// have been fully decoupled from the core.
addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
@ -363,56 +362,50 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
//// Types.
/// Floating-point Types.
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this, StandardTypes::BF16);
impl->f16Ty = TypeUniquer::get<Float16Type>(this, StandardTypes::F16);
impl->f32Ty = TypeUniquer::get<Float32Type>(this, StandardTypes::F32);
impl->f64Ty = TypeUniquer::get<Float64Type>(this, StandardTypes::F64);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
impl->f64Ty = TypeUniquer::get<Float64Type>(this);
/// Index Type.
impl->indexTy = TypeUniquer::get<IndexType>(this, StandardTypes::Index);
impl->indexTy = TypeUniquer::get<IndexType>(this);
/// Integer Types.
impl->int1Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 1,
IntegerType::Signless);
impl->int8Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 8,
IntegerType::Signless);
impl->int16Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
16, IntegerType::Signless);
impl->int32Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
32, IntegerType::Signless);
impl->int64Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
64, IntegerType::Signless);
impl->int128Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
128, IntegerType::Signless);
impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
impl->int16Ty =
TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
impl->int32Ty =
TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
impl->int64Ty =
TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
impl->int128Ty =
TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
/// None Type.
impl->noneType = TypeUniquer::get<NoneType>(this, StandardTypes::None);
impl->noneType = TypeUniquer::get<NoneType>(this);
//// Attributes.
//// Note: These must be registered after the types as they may generate one
//// of the above types internally.
/// Bool Attributes.
impl->falseAttr = AttributeUniquer::get<IntegerAttr>(
this, StandardAttributes::Integer, impl->int1Ty,
APInt(/*numBits=*/1, false))
this, impl->int1Ty, APInt(/*numBits=*/1, false))
.cast<BoolAttr>();
impl->trueAttr = AttributeUniquer::get<IntegerAttr>(
this, StandardAttributes::Integer, impl->int1Ty,
APInt(/*numBits=*/1, true))
this, impl->int1Ty, APInt(/*numBits=*/1, true))
.cast<BoolAttr>();
/// Unit Attribute.
impl->unitAttr =
AttributeUniquer::get<UnitAttr>(this, StandardAttributes::Unit);
impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
/// Unknown Location Attribute.
impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(
this, StandardAttributes::UnknownLocation);
impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
/// The empty dictionary attribute.
impl->emptyDictionaryAttr = AttributeUniquer::get<DictionaryAttr>(
this, StandardAttributes::Dictionary, ArrayRef<NamedAttribute>());
impl->emptyDictionaryAttr =
AttributeUniquer::get<DictionaryAttr>(this, ArrayRef<NamedAttribute>());
// Register the affine storage objects with the uniquer.
impl->affineUniquer.registerStorageType(
TypeID::get<AffineBinaryOpExprStorage>());
impl->affineUniquer.registerStorageType(
TypeID::get<AffineConstantExprStorage>());
impl->affineUniquer.registerStorageType(TypeID::get<AffineDimExprStorage>());
impl->affineUniquer
.registerParametricStorageType<AffineBinaryOpExprStorage>();
impl->affineUniquer
.registerParametricStorageType<AffineConstantExprStorage>();
impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
}
MLIRContext::~MLIRContext() {}
@ -582,7 +575,6 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
AbstractType(std::move(typeInfo));
if (!impl.registeredTypes.insert({typeID, newInfo}).second)
llvm::report_fatal_error("Dialect Type already registered.");
impl.typeUniquer.registerStorageType(typeID);
}
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
@ -592,7 +584,6 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
AbstractAttribute(std::move(attrInfo));
if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
llvm::report_fatal_error("Dialect Attribute already registered.");
impl.attributeUniquer.registerStorageType(typeID);
}
/// Get the dialect that registered the attribute with the provided typeid.
@ -718,7 +709,7 @@ IntegerType IntegerType::get(unsigned width,
MLIRContext *context) {
if (auto cached = getCachedIntegerType(width, signedness, context))
return cached;
return Base::get(context, StandardTypes::Integer, width, signedness);
return Base::get(context, width, signedness);
}
IntegerType IntegerType::getChecked(unsigned width, Location location) {
@ -731,12 +722,16 @@ IntegerType IntegerType::getChecked(unsigned width,
if (auto cached =
getCachedIntegerType(width, signedness, location->getContext()))
return cached;
return Base::getChecked(location, StandardTypes::Integer, width, signedness);
return Base::getChecked(location, width, signedness);
}
/// Get an instance of the NoneType.
NoneType NoneType::get(MLIRContext *context) {
return context->getImpl().noneType;
if (NoneType cachedInst = context->getImpl().noneType)
return cachedInst;
// Note: May happen when initializing the singleton attributes of the builtin
// dialect.
return Base::get(context);
}
//===----------------------------------------------------------------------===//

View File

@ -102,12 +102,11 @@ unsigned Type::getIntOrFloatBitWidth() {
//===----------------------------------------------------------------------===//
ComplexType ComplexType::get(Type elementType) {
return Base::get(elementType.getContext(), StandardTypes::Complex,
elementType);
return Base::get(elementType.getContext(), elementType);
}
ComplexType ComplexType::getChecked(Type elementType, Location location) {
return Base::getChecked(location, StandardTypes::Complex, elementType);
return Base::getChecked(location, elementType);
}
/// Verify the construction of an integer type.
@ -265,13 +264,12 @@ bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
//===----------------------------------------------------------------------===//
VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
elementType);
return Base::get(elementType.getContext(), shape, elementType);
}
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
Location location) {
return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
return Base::getChecked(location, shape, elementType);
}
LogicalResult VectorType::verifyConstructionInvariants(Location loc,
@ -320,15 +318,13 @@ bool TensorType::isValidElementType(Type type) {
RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
Type elementType) {
return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape,
elementType);
return Base::get(elementType.getContext(), shape, elementType);
}
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
Type elementType,
Location location) {
return Base::getChecked(location, StandardTypes::RankedTensor, shape,
elementType);
return Base::getChecked(location, shape, elementType);
}
LogicalResult RankedTensorType::verifyConstructionInvariants(
@ -349,13 +345,12 @@ ArrayRef<int64_t> RankedTensorType::getShape() const {
//===----------------------------------------------------------------------===//
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
elementType);
return Base::get(elementType.getContext(), elementType);
}
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
Location location) {
return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
return Base::getChecked(location, elementType);
}
LogicalResult
@ -444,8 +439,8 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
cleanedAffineMapComposition.push_back(map);
}
return Base::get(context, StandardTypes::MemRef, shape, elementType,
cleanedAffineMapComposition, memorySpace);
return Base::get(context, shape, elementType, cleanedAffineMapComposition,
memorySpace);
}
ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
@ -462,15 +457,13 @@ unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
unsigned memorySpace) {
return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
elementType, memorySpace);
return Base::get(elementType.getContext(), elementType, memorySpace);
}
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
unsigned memorySpace,
Location location) {
return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
memorySpace);
return Base::getChecked(location, elementType, memorySpace);
}
unsigned UnrankedMemRefType::getMemorySpace() const {
@ -642,7 +635,7 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
return Base::get(context, StandardTypes::Tuple, elementTypes);
return Base::get(context, elementTypes);
}
/// Get or create an empty tuple type.

View File

@ -19,8 +19,6 @@ using namespace mlir::detail;
// Type
//===----------------------------------------------------------------------===//
unsigned Type::getKind() const { return impl->getKind(); }
Dialect &Type::getDialect() const {
return impl->getAbstractType().getDialect();
}
@ -33,7 +31,7 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
MLIRContext *context) {
return Base::get(context, Type::Kind::Function, inputs, results);
return Base::get(context, inputs, results);
}
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
@ -54,12 +52,12 @@ ArrayRef<Type> FunctionType::getResults() const {
OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
MLIRContext *context) {
return Base::get(context, Type::Kind::Opaque, dialect, typeData);
return Base::get(context, dialect, typeData);
}
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
MLIRContext *context, Location location) {
return Base::getChecked(location, Kind::Opaque, dialect, typeData);
return Base::getChecked(location, dialect, typeData);
}
/// Returns the dialect namespace of the opaque type.

View File

@ -16,19 +16,17 @@ using namespace mlir;
using namespace mlir::detail;
namespace {
/// This class represents a uniquer for storage instances of a specific type. It
/// contains all of the necessary data to unique storage instances in a thread
/// safe way. This allows for the main uniquer to bucket each of the individual
/// sub-types removing the need to lock the main uniquer itself.
struct InstSpecificUniquer {
/// This class represents a uniquer for storage instances of a specific type
/// that has parametric storage. It contains all of the necessary data to unique
/// storage instances in a thread safe way. This allows for the main uniquer to
/// bucket each of the individual sub-types removing the need to lock the main
/// uniquer itself.
struct ParametricStorageUniquer {
using BaseStorage = StorageUniquer::BaseStorage;
using StorageAllocator = StorageUniquer::StorageAllocator;
/// A lookup key for derived instances of storage objects.
struct LookupKey {
/// The known derived kind for the storage.
unsigned kind;
/// The known hash value of the key.
unsigned hashValue;
@ -63,18 +61,14 @@ struct InstSpecificUniquer {
static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
return false;
// If the lookup kind matches the kind of the storage, then invoke the
// equality function on the lookup key.
return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
// Invoke the equality function on the lookup key.
return lhs.isEqual(rhs.storage);
}
};
/// Unique types with specific hashing or storage constraints.
/// The set containing the allocated storage instances.
using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
StorageTypeSet complexInstances;
/// Instances of this storage object.
llvm::SmallDenseMap<unsigned, BaseStorage *, 1> simpleInstances;
StorageTypeSet instances;
/// Allocator to use when constructing derived instances.
StorageAllocator allocator;
@ -91,107 +85,79 @@ struct StorageUniquerImpl {
using BaseStorage = StorageUniquer::BaseStorage;
using StorageAllocator = StorageUniquer::StorageAllocator;
/// Get or create an instance of a complex derived type.
//===--------------------------------------------------------------------===//
// Parametric Storage
//===--------------------------------------------------------------------===//
/// Get or create an instance of a parametric type.
BaseStorage *
getOrCreate(TypeID id, unsigned kind, unsigned hashValue,
getOrCreate(TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
assert(instUniquers.count(id) && "creating unregistered storage instance");
InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
InstSpecificUniquer &storageUniquer = *instUniquers[id];
assert(parametricUniquers.count(id) &&
"creating unregistered storage instance");
ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
if (!threadingIsEnabled)
return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
// Check for an existing instance in read-only mode.
{
llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
auto it = storageUniquer.complexInstances.find_as(lookupKey);
if (it != storageUniquer.complexInstances.end())
auto it = storageUniquer.instances.find_as(lookupKey);
if (it != storageUniquer.instances.end())
return it->storage;
}
// Acquire a writer-lock so that we can safely create the new type instance.
llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
}
/// Get or create an instance of a complex derived type in an thread-unsafe
/// fashion.
BaseStorage *
getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
InstSpecificUniquer::LookupKey &lookupKey,
getOrCreateUnsafe(ParametricStorageUniquer &storageUniquer,
ParametricStorageUniquer::LookupKey &lookupKey,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey);
auto existing = storageUniquer.instances.insert_as({}, lookupKey);
if (!existing.second)
return existing.first->storage;
// Otherwise, construct and initialize the derived storage for this type
// instance.
BaseStorage *storage =
initializeStorage(kind, storageUniquer.allocator, ctorFn);
BaseStorage *storage = ctorFn(storageUniquer.allocator);
*existing.first =
InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage};
ParametricStorageUniquer::HashedStorage{lookupKey.hashValue, storage};
return storage;
}
/// Get or create an instance of a simple derived type.
BaseStorage *
getOrCreate(TypeID id, unsigned kind,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
assert(instUniquers.count(id) && "creating unregistered storage instance");
InstSpecificUniquer &storageUniquer = *instUniquers[id];
if (!threadingIsEnabled)
return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
// Check for an existing instance in read-only mode.
{
llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
auto it = storageUniquer.simpleInstances.find(kind);
if (it != storageUniquer.simpleInstances.end())
return it->second;
}
// Acquire a writer-lock so that we can safely create the new type instance.
llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
}
/// Get or create an instance of a simple derived type in an thread-unsafe
/// fashion.
BaseStorage *
getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
auto &result = storageUniquer.simpleInstances[kind];
if (result)
return result;
// Otherwise, create and return a new storage instance.
return result = initializeStorage(kind, storageUniquer.allocator, ctorFn);
}
/// Erase an instance of a complex derived type.
void erase(TypeID id, unsigned kind, unsigned hashValue,
/// Erase an instance of a parametric derived type.
void erase(TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<void(BaseStorage *)> cleanupFn) {
assert(instUniquers.count(id) && "erasing unregistered storage instance");
InstSpecificUniquer &storageUniquer = *instUniquers[id];
InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
assert(parametricUniquers.count(id) &&
"erasing unregistered storage instance");
ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
// Acquire a writer-lock so that we can safely erase the type instance.
llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
auto existing = storageUniquer.complexInstances.find_as(lookupKey);
if (existing == storageUniquer.complexInstances.end())
auto existing = storageUniquer.instances.find_as(lookupKey);
if (existing == storageUniquer.instances.end())
return;
// Cleanup the storage and remove it from the map.
cleanupFn(existing->storage);
storageUniquer.complexInstances.erase(existing);
storageUniquer.instances.erase(existing);
}
/// Mutates an instance of a derived storage in a thread-safe way.
LogicalResult
mutate(TypeID id,
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
assert(instUniquers.count(id) && "mutating unregistered storage instance");
InstSpecificUniquer &storageUniquer = *instUniquers[id];
assert(parametricUniquers.count(id) &&
"mutating unregistered storage instance");
ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
if (!threadingIsEnabled)
return mutationFn(storageUniquer.allocator);
@ -199,21 +165,31 @@ struct StorageUniquerImpl {
return mutationFn(storageUniquer.allocator);
}
//===--------------------------------------------------------------------===//
// Singleton Storage
//===--------------------------------------------------------------------===//
/// Get or create an instance of a singleton storage class.
BaseStorage *getSingleton(TypeID id) {
BaseStorage *singletonInstance = singletonInstances[id];
assert(singletonInstance && "expected singleton instance to exist");
return singletonInstance;
}
//===--------------------------------------------------------------------===//
// Instance Storage
//===--------------------------------------------------------------------===//
/// Utility to create and initialize a storage instance.
BaseStorage *
initializeStorage(unsigned kind, StorageAllocator &allocator,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
BaseStorage *storage = ctorFn(allocator);
storage->kind = kind;
return storage;
}
/// Map of type ids to the storage uniquer to use for registered objects.
DenseMap<TypeID, std::unique_ptr<InstSpecificUniquer>> instUniquers;
DenseMap<TypeID, std::unique_ptr<ParametricStorageUniquer>>
parametricUniquers;
/// Map of type ids to a singleton instance when the storage class is a
/// singleton.
DenseMap<TypeID, BaseStorage *> singletonInstances;
/// Allocator used for uniquing singleton instances.
StorageAllocator singletonAllocator;
/// Flag specifying if multi-threading is enabled within the uniquer.
bool threadingIsEnabled = true;
@ -229,41 +205,47 @@ void StorageUniquer::disableMultithreading(bool disable) {
impl->threadingIsEnabled = !disable;
}
/// Register a new storage object with this uniquer using the given unique type
/// id.
void StorageUniquer::registerStorageType(TypeID id) {
impl->instUniquers.try_emplace(id, std::make_unique<InstSpecificUniquer>());
}
/// Implementation for getting/creating an instance of a derived type with
/// complex storage.
auto StorageUniquer::getImpl(
const TypeID &id, unsigned kind, unsigned hashValue,
/// parametric storage.
auto StorageUniquer::getParametricStorageTypeImpl(
TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn);
return impl->getOrCreate(id, hashValue, isEqual, ctorFn);
}
/// Implementation for getting/creating an instance of a derived type with
/// default storage.
auto StorageUniquer::getImpl(
const TypeID &id, unsigned kind,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
return impl->getOrCreate(id, kind, ctorFn);
/// Implementation for registering an instance of a derived type with
/// parametric storage.
void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) {
impl->parametricUniquers.try_emplace(
id, std::make_unique<ParametricStorageUniquer>());
}
/// Implementation for erasing an instance of a derived type with complex
/// Implementation for getting an instance of a derived type with default
/// storage.
void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind,
unsigned hashValue,
auto StorageUniquer::getSingletonImpl(TypeID id) -> BaseStorage * {
return impl->getSingleton(id);
}
/// Implementation for registering an instance of a derived type with default
/// storage.
void StorageUniquer::registerSingletonImpl(
TypeID id, function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
assert(!impl->singletonInstances.count(id) &&
"storage class already registered");
impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator));
}
/// Implementation for erasing an instance of a derived type with parametric
/// storage.
void StorageUniquer::eraseImpl(TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<void(BaseStorage *)> cleanupFn) {
impl->erase(id, kind, hashValue, isEqual, cleanupFn);
impl->erase(id, hashValue, isEqual, cleanupFn);
}
/// Implementation for mutating an instance of a derived storage.
LogicalResult StorageUniquer::mutateImpl(
const TypeID &id,
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
TypeID id, function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
return impl->mutate(id, mutationFn);
}

View File

@ -156,7 +156,7 @@ static Type parseTestType(DialectAsmParser &parser,
StringRef name;
if (parser.parseLess() || parser.parseKeyword(&name))
return Type();
auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name);
auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
// If this type already has been parsed above in the stack, expect just the
// name.

View File

@ -26,10 +26,6 @@ struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
TestTypeInterface::Trait> {
using Base::Base;
static TestType get(MLIRContext *context) {
return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE);
}
/// Provide a definition for the necessary interface methods.
void printTypeC(Location loc) const {
emitRemark(loc) << *this << " - TestC";
@ -72,9 +68,8 @@ class TestRecursiveType
public:
using Base::Base;
static TestRecursiveType create(MLIRContext *ctx, StringRef name) {
return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1,
name);
static TestRecursiveType get(MLIRContext *ctx, StringRef name) {
return Base::get(ctx, name);
}
/// Body getter and setter.

View File

@ -41,7 +41,7 @@ struct TestRecursiveTypesPass
LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
MLIRContext *ctx = &getContext();
FuncOp func = getFunction();
auto type = TestRecursiveType::create(ctx, "some_long_and_unique_name");
auto type = TestRecursiveType::get(ctx, "some_long_and_unique_name");
if (failed(type.setBody(type)))
return func.emitError("expected to be able to set the type body");
@ -56,7 +56,7 @@ LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
"not expected to be able to change function body more than once");
// Expecting to get the same type for the same name.
auto other = TestRecursiveType::create(ctx, "some_long_and_unique_name");
auto other = TestRecursiveType::get(ctx, "some_long_and_unique_name");
if (type != other)
return func.emitError("expected type name to be the uniquing key");