forked from OSchip/llvm-project
[mlir] Remove the use of "kinds" from Attributes and Types
This greatly simplifies a large portion of the underlying infrastructure, allows for lookups of singleton classes to be much more efficient and always thread-safe(no locking). As a result of this, the dialect symbol registry has been removed as it is no longer necessary. For users broken by this change, an alert was sent out(https://llvm.discourse.group/t/removing-kinds-from-attributes-and-types) that helps prevent a majority of the breakage surface area. All that should be necessary, if the advice in that alert was followed, is removing the kind passed to the ::get methods. Differential Revision: https://reviews.llvm.org/D86121
This commit is contained in:
parent
a7d0b7a786
commit
250f43d3ec
|
@ -25,17 +25,6 @@ struct RealAttributeStorage;
|
|||
struct TypeAttributeStorage;
|
||||
} // namespace detail
|
||||
|
||||
enum AttributeKind {
|
||||
FIR_ATTR = mlir::Attribute::FIRST_FIR_ATTR,
|
||||
FIR_EXACTTYPE, // instance_of, precise type relation
|
||||
FIR_SUBCLASS, // subsumed_by, is-a (subclass) relation
|
||||
FIR_POINT,
|
||||
FIR_CLOSEDCLOSED_INTERVAL,
|
||||
FIR_OPENCLOSED_INTERVAL,
|
||||
FIR_CLOSEDOPEN_INTERVAL,
|
||||
FIR_REAL_ATTR,
|
||||
};
|
||||
|
||||
class ExactTypeAttr
|
||||
: public mlir::Attribute::AttrBase<ExactTypeAttr, mlir::Attribute,
|
||||
detail::TypeAttributeStorage> {
|
||||
|
@ -47,8 +36,6 @@ public:
|
|||
static ExactTypeAttr get(mlir::Type value);
|
||||
|
||||
mlir::Type getType() const;
|
||||
|
||||
static constexpr unsigned getId() { return AttributeKind::FIR_EXACTTYPE; }
|
||||
};
|
||||
|
||||
class SubclassAttr
|
||||
|
@ -62,8 +49,6 @@ public:
|
|||
static SubclassAttr get(mlir::Type value);
|
||||
|
||||
mlir::Type getType() const;
|
||||
|
||||
static constexpr unsigned getId() { return AttributeKind::FIR_SUBCLASS; }
|
||||
};
|
||||
|
||||
// Attributes for building SELECT CASE multiway branches
|
||||
|
@ -80,9 +65,6 @@ public:
|
|||
|
||||
static constexpr llvm::StringRef getAttrName() { return "interval"; }
|
||||
static ClosedIntervalAttr get(mlir::MLIRContext *ctxt);
|
||||
static constexpr unsigned getId() {
|
||||
return AttributeKind::FIR_CLOSEDCLOSED_INTERVAL;
|
||||
}
|
||||
};
|
||||
|
||||
/// An upper bound is an open interval (including the bound value) as given as
|
||||
|
@ -97,9 +79,6 @@ public:
|
|||
|
||||
static constexpr llvm::StringRef getAttrName() { return "upper"; }
|
||||
static UpperBoundAttr get(mlir::MLIRContext *ctxt);
|
||||
static constexpr unsigned getId() {
|
||||
return AttributeKind::FIR_OPENCLOSED_INTERVAL;
|
||||
}
|
||||
};
|
||||
|
||||
/// A lower bound is an open interval (including the bound value) as given as
|
||||
|
@ -114,9 +93,6 @@ public:
|
|||
|
||||
static constexpr llvm::StringRef getAttrName() { return "lower"; }
|
||||
static LowerBoundAttr get(mlir::MLIRContext *ctxt);
|
||||
static constexpr unsigned getId() {
|
||||
return AttributeKind::FIR_CLOSEDOPEN_INTERVAL;
|
||||
}
|
||||
};
|
||||
|
||||
/// A pointer interval is a closed interval as given as an ssa-value. The
|
||||
|
@ -131,7 +107,6 @@ public:
|
|||
|
||||
static constexpr llvm::StringRef getAttrName() { return "point"; }
|
||||
static PointIntervalAttr get(mlir::MLIRContext *ctxt);
|
||||
static constexpr unsigned getId() { return AttributeKind::FIR_POINT; }
|
||||
};
|
||||
|
||||
/// A real attribute is used to workaround MLIR's default parsing of a real
|
||||
|
@ -150,8 +125,6 @@ public:
|
|||
|
||||
int getFKind() const;
|
||||
llvm::APFloat getValue() const;
|
||||
|
||||
static constexpr unsigned getId() { return AttributeKind::FIR_REAL_ATTR; }
|
||||
};
|
||||
|
||||
mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,
|
||||
|
|
|
@ -54,29 +54,6 @@ struct SequenceTypeStorage;
|
|||
struct TypeDescTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
/// Integral identifier for all the types comprising the FIR type system
|
||||
enum TypeKind {
|
||||
// The enum starts at the range reserved for this dialect.
|
||||
FIR_TYPE = mlir::Type::FIRST_FIR_TYPE,
|
||||
FIR_BOX, // (static) descriptor
|
||||
FIR_BOXCHAR, // CHARACTER pointer and length
|
||||
FIR_BOXPROC, // procedure with host association
|
||||
FIR_CHARACTER, // intrinsic type
|
||||
FIR_COMPLEX, // intrinsic type
|
||||
FIR_DERIVED, // derived
|
||||
FIR_DIMS,
|
||||
FIR_FIELD,
|
||||
FIR_HEAP,
|
||||
FIR_INT, // intrinsic type
|
||||
FIR_LEN,
|
||||
FIR_LOGICAL, // intrinsic type
|
||||
FIR_POINTER, // POINTER attr
|
||||
FIR_REAL, // intrinsic type
|
||||
FIR_REFERENCE,
|
||||
FIR_SEQUENCE, // DIMENSION attr
|
||||
FIR_TYPEDESC,
|
||||
};
|
||||
|
||||
// These isa_ routines follow the precedent of llvm::isa_or_null<>
|
||||
|
||||
/// Is `t` any of the FIR dialect types?
|
||||
|
@ -111,12 +88,6 @@ bool isa_aggregate(mlir::Type t);
|
|||
/// not a memory reference type, then returns a null `Type`.
|
||||
mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
|
||||
|
||||
/// Boilerplate mixin template
|
||||
template <typename A, unsigned Id>
|
||||
struct IntrinsicTypeMixin {
|
||||
static constexpr unsigned getId() { return Id; }
|
||||
};
|
||||
|
||||
// Intrinsic types
|
||||
|
||||
/// Model of the Fortran CHARACTER intrinsic type, including the KIND type
|
||||
|
@ -124,8 +95,7 @@ struct IntrinsicTypeMixin {
|
|||
/// is thus the type of a single character value.
|
||||
class CharacterType
|
||||
: public mlir::Type::TypeBase<CharacterType, mlir::Type,
|
||||
detail::CharacterTypeStorage>,
|
||||
public IntrinsicTypeMixin<CharacterType, TypeKind::FIR_CHARACTER> {
|
||||
detail::CharacterTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static CharacterType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||
|
@ -136,8 +106,7 @@ public:
|
|||
/// parameter. COMPLEX is a floating point type with a real and imaginary
|
||||
/// member.
|
||||
class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
|
||||
detail::CplxTypeStorage>,
|
||||
public IntrinsicTypeMixin<CplxType, TypeKind::FIR_COMPLEX> {
|
||||
detail::CplxTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static CplxType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||
|
@ -151,8 +120,7 @@ public:
|
|||
/// Model of a Fortran INTEGER intrinsic type, including the KIND type
|
||||
/// parameter.
|
||||
class IntType
|
||||
: public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage>,
|
||||
public IntrinsicTypeMixin<IntType, TypeKind::FIR_INT> {
|
||||
: public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static IntType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||
|
@ -163,8 +131,7 @@ public:
|
|||
/// parameter.
|
||||
class LogicalType
|
||||
: public mlir::Type::TypeBase<LogicalType, mlir::Type,
|
||||
detail::LogicalTypeStorage>,
|
||||
public IntrinsicTypeMixin<LogicalType, TypeKind::FIR_LOGICAL> {
|
||||
detail::LogicalTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||
|
@ -174,8 +141,7 @@ public:
|
|||
/// Model of a Fortran REAL (and DOUBLE PRECISION) intrinsic type, including the
|
||||
/// KIND type parameter.
|
||||
class RealType : public mlir::Type::TypeBase<RealType, mlir::Type,
|
||||
detail::RealTypeStorage>,
|
||||
public IntrinsicTypeMixin<RealType, TypeKind::FIR_REAL> {
|
||||
detail::RealTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static RealType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||
|
@ -400,7 +366,6 @@ public:
|
|||
static RecordType get(mlir::MLIRContext *ctxt, llvm::StringRef name);
|
||||
void finalize(llvm::ArrayRef<TypePair> lenPList,
|
||||
llvm::ArrayRef<TypePair> typeList);
|
||||
static constexpr unsigned getId() { return TypeKind::FIR_DERIVED; }
|
||||
|
||||
detail::RecordTypeStorage const *uniqueKey() const;
|
||||
|
||||
|
|
|
@ -74,13 +74,13 @@ private:
|
|||
} // namespace detail
|
||||
|
||||
ExactTypeAttr ExactTypeAttr::get(mlir::Type value) {
|
||||
return Base::get(value.getContext(), FIR_EXACTTYPE, value);
|
||||
return Base::get(value.getContext(), value);
|
||||
}
|
||||
|
||||
mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); }
|
||||
|
||||
SubclassAttr SubclassAttr::get(mlir::Type value) {
|
||||
return Base::get(value.getContext(), FIR_SUBCLASS, value);
|
||||
return Base::get(value.getContext(), value);
|
||||
}
|
||||
|
||||
mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
|
||||
|
@ -88,26 +88,26 @@ mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
|
|||
using AttributeUniquer = mlir::detail::AttributeUniquer;
|
||||
|
||||
ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
|
||||
return AttributeUniquer::get<ClosedIntervalAttr>(ctxt, getId());
|
||||
return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
|
||||
}
|
||||
|
||||
UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
|
||||
return AttributeUniquer::get<UpperBoundAttr>(ctxt, getId());
|
||||
return AttributeUniquer::get<UpperBoundAttr>(ctxt);
|
||||
}
|
||||
|
||||
LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
|
||||
return AttributeUniquer::get<LowerBoundAttr>(ctxt, getId());
|
||||
return AttributeUniquer::get<LowerBoundAttr>(ctxt);
|
||||
}
|
||||
|
||||
PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
|
||||
return AttributeUniquer::get<PointIntervalAttr>(ctxt, getId());
|
||||
return AttributeUniquer::get<PointIntervalAttr>(ctxt);
|
||||
}
|
||||
|
||||
// RealAttr
|
||||
|
||||
RealAttr RealAttr::get(mlir::MLIRContext *ctxt,
|
||||
const RealAttr::ValueType &key) {
|
||||
return Base::get(ctxt, getId(), key);
|
||||
return Base::get(ctxt, key);
|
||||
}
|
||||
|
||||
int RealAttr::getFKind() const { return getImpl()->getFKind(); }
|
||||
|
|
|
@ -824,13 +824,11 @@ bool inbounds(A v, B lb, B ub) {
|
|||
}
|
||||
|
||||
bool isa_fir_type(mlir::Type t) {
|
||||
return inbounds(t.getKind(), mlir::Type::FIRST_FIR_TYPE,
|
||||
mlir::Type::LAST_FIR_TYPE);
|
||||
return llvm::isa<FIROpsDialect>(t.getDialect());
|
||||
}
|
||||
|
||||
bool isa_std_type(mlir::Type t) {
|
||||
return inbounds(t.getKind(), mlir::Type::FIRST_STANDARD_TYPE,
|
||||
mlir::Type::LAST_STANDARD_TYPE);
|
||||
return t.getDialect().getNamespace().empty();
|
||||
}
|
||||
|
||||
bool isa_fir_or_std_type(mlir::Type t) {
|
||||
|
@ -868,7 +866,7 @@ mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
|
|||
// CHARACTER
|
||||
|
||||
CharacterType fir::CharacterType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||
return Base::get(ctxt, FIR_CHARACTER, kind);
|
||||
return Base::get(ctxt, kind);
|
||||
}
|
||||
|
||||
int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
|
||||
|
@ -876,7 +874,7 @@ int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
|
|||
// Dims
|
||||
|
||||
DimsType fir::DimsType::get(mlir::MLIRContext *ctxt, unsigned rank) {
|
||||
return Base::get(ctxt, FIR_DIMS, rank);
|
||||
return Base::get(ctxt, rank);
|
||||
}
|
||||
|
||||
unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
|
||||
|
@ -884,19 +882,19 @@ unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
|
|||
// Field
|
||||
|
||||
FieldType fir::FieldType::get(mlir::MLIRContext *ctxt) {
|
||||
return Base::get(ctxt, FIR_FIELD, 0);
|
||||
return Base::get(ctxt, 0);
|
||||
}
|
||||
|
||||
// Len
|
||||
|
||||
LenType fir::LenType::get(mlir::MLIRContext *ctxt) {
|
||||
return Base::get(ctxt, FIR_LEN, 0);
|
||||
return Base::get(ctxt, 0);
|
||||
}
|
||||
|
||||
// LOGICAL
|
||||
|
||||
LogicalType fir::LogicalType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||
return Base::get(ctxt, FIR_LOGICAL, kind);
|
||||
return Base::get(ctxt, kind);
|
||||
}
|
||||
|
||||
int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
|
||||
|
@ -904,7 +902,7 @@ int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
|
|||
// INTEGER
|
||||
|
||||
IntType fir::IntType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||
return Base::get(ctxt, FIR_INT, kind);
|
||||
return Base::get(ctxt, kind);
|
||||
}
|
||||
|
||||
int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
|
||||
|
@ -912,7 +910,7 @@ int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
|
|||
// COMPLEX
|
||||
|
||||
CplxType fir::CplxType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||
return Base::get(ctxt, FIR_COMPLEX, kind);
|
||||
return Base::get(ctxt, kind);
|
||||
}
|
||||
|
||||
mlir::Type fir::CplxType::getElementType() const {
|
||||
|
@ -924,7 +922,7 @@ KindTy fir::CplxType::getFKind() const { return getImpl()->getFKind(); }
|
|||
// REAL
|
||||
|
||||
RealType fir::RealType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||
return Base::get(ctxt, FIR_REAL, kind);
|
||||
return Base::get(ctxt, kind);
|
||||
}
|
||||
|
||||
int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
|
||||
|
@ -932,7 +930,7 @@ int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
|
|||
// Box<T>
|
||||
|
||||
BoxType fir::BoxType::get(mlir::Type elementType, mlir::AffineMapAttr map) {
|
||||
return Base::get(elementType.getContext(), FIR_BOX, elementType, map);
|
||||
return Base::get(elementType.getContext(), elementType, map);
|
||||
}
|
||||
|
||||
mlir::Type fir::BoxType::getEleTy() const {
|
||||
|
@ -953,7 +951,7 @@ fir::BoxType::verifyConstructionInvariants(mlir::Location, mlir::Type eleTy,
|
|||
// BoxChar<C>
|
||||
|
||||
BoxCharType fir::BoxCharType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||
return Base::get(ctxt, FIR_BOXCHAR, kind);
|
||||
return Base::get(ctxt, kind);
|
||||
}
|
||||
|
||||
CharacterType fir::BoxCharType::getEleTy() const {
|
||||
|
@ -963,7 +961,7 @@ CharacterType fir::BoxCharType::getEleTy() const {
|
|||
// BoxProc<T>
|
||||
|
||||
BoxProcType fir::BoxProcType::get(mlir::Type elementType) {
|
||||
return Base::get(elementType.getContext(), FIR_BOXPROC, elementType);
|
||||
return Base::get(elementType.getContext(), elementType);
|
||||
}
|
||||
|
||||
mlir::Type fir::BoxProcType::getEleTy() const {
|
||||
|
@ -984,7 +982,7 @@ fir::BoxProcType::verifyConstructionInvariants(mlir::Location loc,
|
|||
// Reference<T>
|
||||
|
||||
ReferenceType fir::ReferenceType::get(mlir::Type elementType) {
|
||||
return Base::get(elementType.getContext(), FIR_REFERENCE, elementType);
|
||||
return Base::get(elementType.getContext(), elementType);
|
||||
}
|
||||
|
||||
mlir::Type fir::ReferenceType::getEleTy() const {
|
||||
|
@ -1005,7 +1003,7 @@ fir::ReferenceType::verifyConstructionInvariants(mlir::Location loc,
|
|||
|
||||
PointerType fir::PointerType::get(mlir::Type elementType) {
|
||||
assert(singleIndirectionLevel(elementType) && "invalid element type");
|
||||
return Base::get(elementType.getContext(), FIR_POINTER, elementType);
|
||||
return Base::get(elementType.getContext(), elementType);
|
||||
}
|
||||
|
||||
mlir::Type fir::PointerType::getEleTy() const {
|
||||
|
@ -1033,7 +1031,7 @@ fir::PointerType::verifyConstructionInvariants(mlir::Location loc,
|
|||
|
||||
HeapType fir::HeapType::get(mlir::Type elementType) {
|
||||
assert(singleIndirectionLevel(elementType) && "invalid element type");
|
||||
return Base::get(elementType.getContext(), FIR_HEAP, elementType);
|
||||
return Base::get(elementType.getContext(), elementType);
|
||||
}
|
||||
|
||||
mlir::Type fir::HeapType::getEleTy() const {
|
||||
|
@ -1054,7 +1052,7 @@ fir::HeapType::verifyConstructionInvariants(mlir::Location loc,
|
|||
SequenceType fir::SequenceType::get(const Shape &shape, mlir::Type elementType,
|
||||
mlir::AffineMapAttr map) {
|
||||
auto *ctxt = elementType.getContext();
|
||||
return Base::get(ctxt, FIR_SEQUENCE, shape, elementType, map);
|
||||
return Base::get(ctxt, shape, elementType, map);
|
||||
}
|
||||
|
||||
mlir::Type fir::SequenceType::getEleTy() const {
|
||||
|
@ -1136,7 +1134,7 @@ llvm::hash_code fir::hash_value(const SequenceType::Shape &sh) {
|
|||
/// This type captures a Fortran "derived type"
|
||||
|
||||
RecordType fir::RecordType::get(mlir::MLIRContext *ctxt, llvm::StringRef name) {
|
||||
return Base::get(ctxt, FIR_DERIVED, name);
|
||||
return Base::get(ctxt, name);
|
||||
}
|
||||
|
||||
void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
|
||||
|
@ -1179,7 +1177,7 @@ mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
|
|||
|
||||
TypeDescType fir::TypeDescType::get(mlir::Type ofType) {
|
||||
assert(!ofType.isa<ReferenceType>());
|
||||
return Base::get(ofType.getContext(), FIR_TYPEDESC, ofType);
|
||||
return Base::get(ofType.getContext(), ofType);
|
||||
}
|
||||
|
||||
mlir::Type fir::TypeDescType::getOfTy() const { return getImpl()->getOfType(); }
|
||||
|
@ -1222,9 +1220,7 @@ void fir::verifyIntegralType(mlir::Type type) {
|
|||
void fir::printFirType(FIROpsDialect *, mlir::Type ty,
|
||||
mlir::DialectAsmPrinter &p) {
|
||||
auto &os = p.getStream();
|
||||
switch (ty.getKind()) {
|
||||
case fir::FIR_BOX: {
|
||||
auto type = ty.cast<BoxType>();
|
||||
if (auto type = ty.dyn_cast<BoxType>()) {
|
||||
os << "box<";
|
||||
p.printType(type.getEleTy());
|
||||
if (auto map = type.getLayoutMap()) {
|
||||
|
@ -1232,24 +1228,28 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
|
|||
p.printAttribute(map);
|
||||
}
|
||||
os << '>';
|
||||
} break;
|
||||
case fir::FIR_BOXCHAR: {
|
||||
auto type = ty.cast<BoxCharType>().getEleTy();
|
||||
os << "boxchar<" << type.cast<fir::CharacterType>().getFKind() << '>';
|
||||
} break;
|
||||
case fir::FIR_BOXPROC:
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<BoxCharType>()) {
|
||||
os << "boxchar<" << type.getEleTy().cast<fir::CharacterType>().getFKind()
|
||||
<< '>';
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<BoxProcType>()) {
|
||||
os << "boxproc<";
|
||||
p.printType(ty.cast<BoxProcType>().getEleTy());
|
||||
p.printType(type.getEleTy());
|
||||
os << '>';
|
||||
break;
|
||||
case fir::FIR_CHARACTER: // intrinsic
|
||||
os << "char<" << ty.cast<CharacterType>().getFKind() << '>';
|
||||
break;
|
||||
case fir::FIR_COMPLEX: // intrinsic
|
||||
os << "complex<" << ty.cast<CplxType>().getFKind() << '>';
|
||||
break;
|
||||
case fir::FIR_DERIVED: { // derived
|
||||
auto type = ty.cast<fir::RecordType>();
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<CharacterType>()) {
|
||||
os << "char<" << type.getFKind() << '>';
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<CplxType>()) {
|
||||
os << "complex<" << type.getFKind() << '>';
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<RecordType>()) {
|
||||
os << "type<" << type.getName();
|
||||
if (!recordTypeVisited.count(type.uniqueKey())) {
|
||||
recordTypeVisited.insert(type.uniqueKey());
|
||||
|
@ -1274,43 +1274,52 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
|
|||
recordTypeVisited.erase(type.uniqueKey());
|
||||
}
|
||||
os << '>';
|
||||
} break;
|
||||
case fir::FIR_DIMS:
|
||||
os << "dims<" << ty.cast<DimsType>().getRank() << '>';
|
||||
break;
|
||||
case fir::FIR_FIELD:
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<DimsType>()) {
|
||||
os << "dims<" << type.getRank() << '>';
|
||||
return;
|
||||
}
|
||||
if (ty.isa<FieldType>()) {
|
||||
os << "field";
|
||||
break;
|
||||
case fir::FIR_HEAP:
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<HeapType>()) {
|
||||
os << "heap<";
|
||||
p.printType(ty.cast<HeapType>().getEleTy());
|
||||
p.printType(type.getEleTy());
|
||||
os << '>';
|
||||
break;
|
||||
case fir::FIR_INT: // intrinsic
|
||||
os << "int<" << ty.cast<fir::IntType>().getFKind() << '>';
|
||||
break;
|
||||
case fir::FIR_LEN:
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<fir::IntType>()) {
|
||||
os << "int<" << type.getFKind() << '>';
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<LenType>()) {
|
||||
os << "len";
|
||||
break;
|
||||
case fir::FIR_LOGICAL: // intrinsic
|
||||
os << "logical<" << ty.cast<LogicalType>().getFKind() << '>';
|
||||
break;
|
||||
case fir::FIR_POINTER:
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<LogicalType>()) {
|
||||
os << "logical<" << type.getFKind() << '>';
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<PointerType>()) {
|
||||
os << "ptr<";
|
||||
p.printType(ty.cast<PointerType>().getEleTy());
|
||||
p.printType(type.getEleTy());
|
||||
os << '>';
|
||||
break;
|
||||
case fir::FIR_REAL: // intrinsic
|
||||
os << "real<" << ty.cast<fir::RealType>().getFKind() << '>';
|
||||
break;
|
||||
case fir::FIR_REFERENCE:
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<fir::RealType>()) {
|
||||
os << "real<" << type.getFKind() << '>';
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<ReferenceType>()) {
|
||||
os << "ref<";
|
||||
p.printType(ty.cast<ReferenceType>().getEleTy());
|
||||
p.printType(type.getEleTy());
|
||||
os << '>';
|
||||
break;
|
||||
case fir::FIR_SEQUENCE: {
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<SequenceType>()) {
|
||||
os << "array";
|
||||
auto type = ty.cast<SequenceType>();
|
||||
auto shape = type.getShape();
|
||||
if (shape.size()) {
|
||||
printBounds(os, shape);
|
||||
|
@ -1323,11 +1332,12 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
|
|||
map.print(os);
|
||||
}
|
||||
os << '>';
|
||||
} break;
|
||||
case fir::FIR_TYPEDESC:
|
||||
return;
|
||||
}
|
||||
if (auto type = ty.dyn_cast<TypeDescType>()) {
|
||||
os << "tdesc<";
|
||||
p.printType(ty.cast<TypeDescType>().getOfTy());
|
||||
p.printType(type.getOfTy());
|
||||
os << '>';
|
||||
break;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -190,11 +190,10 @@ public:
|
|||
assert(!elementTypes.empty() && "expected at least 1 element type");
|
||||
|
||||
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
|
||||
// of this type. The first two parameters are the context to unique in and
|
||||
// the kind of the type. The parameters after the type kind are forwarded to
|
||||
// the storage instance.
|
||||
// of this type. The first parameter is the context to unique in. The
|
||||
// parameters after the type kind are forwarded to the storage instance.
|
||||
mlir::MLIRContext *ctx = elementTypes.front().getContext();
|
||||
return Base::get(ctx, ToyTypes::Struct, elementTypes);
|
||||
return Base::get(ctx, elementTypes);
|
||||
}
|
||||
|
||||
/// Returns the element types of this struct type.
|
||||
|
|
|
@ -63,13 +63,6 @@ public:
|
|||
// Toy Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Create a local enumeration with all of the types that are defined by Toy.
|
||||
namespace ToyTypes {
|
||||
enum Types {
|
||||
Struct = mlir::Type::FIRST_TOY_TYPE,
|
||||
};
|
||||
} // end namespace ToyTypes
|
||||
|
||||
/// This class defines the Toy struct type. It represents a collection of
|
||||
/// element types. All derived types in MLIR must inherit from the CRTP class
|
||||
/// 'Type::TypeBase'. It takes as template parameters the concrete type
|
||||
|
|
|
@ -474,11 +474,10 @@ StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
|
|||
assert(!elementTypes.empty() && "expected at least 1 element type");
|
||||
|
||||
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
|
||||
// of this type. The first two parameters are the context to unique in and the
|
||||
// kind of the type. The parameters after the type kind are forwarded to the
|
||||
// storage instance.
|
||||
// of this type. The first parameter is the context to unique in. The
|
||||
// parameters after the type kind are forwarded to the storage instance.
|
||||
mlir::MLIRContext *ctx = elementTypes.front().getContext();
|
||||
return Base::get(ctx, ToyTypes::Struct, elementTypes);
|
||||
return Base::get(ctx, elementTypes);
|
||||
}
|
||||
|
||||
/// Returns the element types of this struct type.
|
||||
|
|
|
@ -64,34 +64,6 @@ class LLVMIntegerType;
|
|||
/// structs, the entire type is the identifier) and are thread-safe.
|
||||
class LLVMType : public Type {
|
||||
public:
|
||||
enum Kind {
|
||||
// Keep non-parametric types contiguous in the enum.
|
||||
VoidType = FIRST_LLVM_TYPE + 1,
|
||||
HalfType,
|
||||
BFloatType,
|
||||
FloatType,
|
||||
DoubleType,
|
||||
FP128Type,
|
||||
X86FP80Type,
|
||||
PPCFP128Type,
|
||||
X86MMXType,
|
||||
LabelType,
|
||||
TokenType,
|
||||
MetadataType,
|
||||
// End of non-parametric types.
|
||||
FunctionType,
|
||||
IntegerType,
|
||||
PointerType,
|
||||
FixedVectorType,
|
||||
ScalableVectorType,
|
||||
ArrayType,
|
||||
StructType,
|
||||
FIRST_NEW_LLVM_TYPE = VoidType,
|
||||
LAST_NEW_LLVM_TYPE = StructType,
|
||||
FIRST_TRIVIAL_TYPE = VoidType,
|
||||
LAST_TRIVIAL_TYPE = MetadataType
|
||||
};
|
||||
|
||||
/// Inherit base constructors.
|
||||
using Type::Type;
|
||||
|
||||
|
@ -256,27 +228,24 @@ public:
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Batch-define trivial types.
|
||||
#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, Kind) \
|
||||
#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \
|
||||
class ClassName : public Type::TypeBase<ClassName, LLVMType, TypeStorage> { \
|
||||
public: \
|
||||
using Base::Base; \
|
||||
static ClassName get(MLIRContext *context) { \
|
||||
return Base::get(context, Kind); \
|
||||
} \
|
||||
}
|
||||
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, LLVMType::VoidType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType, LLVMType::HalfType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType, LLVMType::BFloatType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType, LLVMType::FloatType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType, LLVMType::DoubleType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type, LLVMType::FP128Type);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type, LLVMType::X86FP80Type);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, LLVMType::PPCFP128Type);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, LLVMType::X86MMXType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, LLVMType::TokenType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, LLVMType::LabelType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, LLVMType::MetadataType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
|
||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
|
||||
|
||||
#undef DEFINE_TRIVIAL_LLVM_TYPE
|
||||
|
||||
|
|
|
@ -16,11 +16,6 @@ namespace mlir {
|
|||
class MLIRContext;
|
||||
|
||||
namespace linalg {
|
||||
enum LinalgTypes {
|
||||
Range = Type::FIRST_LINALG_TYPE,
|
||||
LAST_USED_LINALG_TYPE = Range,
|
||||
};
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
|
||||
|
||||
/// A RangeType represents a minimal range abstraction (min, max, step).
|
||||
|
@ -36,11 +31,6 @@ class RangeType : public Type::TypeBase<RangeType, Type, TypeStorage> {
|
|||
public:
|
||||
// Used for generic hooks in TypeBase.
|
||||
using Base::Base;
|
||||
/// Construction hook.
|
||||
static RangeType get(MLIRContext *context) {
|
||||
/// Custom, uniq'ed construction in the MLIRContext.
|
||||
return Base::get(context, LinalgTypes::Range);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace linalg
|
||||
|
|
|
@ -31,15 +31,6 @@ struct UniformQuantizedPerAxisTypeStorage;
|
|||
|
||||
} // namespace detail
|
||||
|
||||
namespace QuantizationTypes {
|
||||
enum Kind {
|
||||
Any = Type::FIRST_QUANTIZATION_TYPE,
|
||||
UniformQuantized,
|
||||
UniformQuantizedPerAxis,
|
||||
LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
|
||||
};
|
||||
} // namespace QuantizationTypes
|
||||
|
||||
/// Enumeration of bit-mapped flags related to quantized types.
|
||||
namespace QuantizationFlags {
|
||||
enum FlagValue {
|
||||
|
|
|
@ -32,15 +32,6 @@ struct TargetEnvAttributeStorage;
|
|||
struct VerCapExtAttributeStorage;
|
||||
} // namespace detail
|
||||
|
||||
/// SPIR-V dialect-specific attribute kinds.
|
||||
namespace AttrKind {
|
||||
enum Kind {
|
||||
InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI
|
||||
TargetEnv, /// Target environment
|
||||
VerCapExt, /// (version, extension, capability) triple
|
||||
};
|
||||
} // namespace AttrKind
|
||||
|
||||
/// An attribute that specifies the information regarding the interface
|
||||
/// variable: descriptor set, binding, storage class.
|
||||
class InterfaceVarABIAttr
|
||||
|
|
|
@ -65,19 +65,6 @@ struct StructTypeStorage;
|
|||
|
||||
} // namespace detail
|
||||
|
||||
namespace TypeKind {
|
||||
enum Kind {
|
||||
Array = Type::FIRST_SPIRV_TYPE,
|
||||
CooperativeMatrix,
|
||||
Image,
|
||||
Matrix,
|
||||
Pointer,
|
||||
RuntimeArray,
|
||||
Struct,
|
||||
LAST_SPIRV_TYPE = Struct,
|
||||
};
|
||||
}
|
||||
|
||||
// Base SPIR-V type for providing availability queries.
|
||||
class SPIRVType : public Type {
|
||||
public:
|
||||
|
|
|
@ -29,56 +29,28 @@ namespace shape {
|
|||
/// Alias type for extent tensors.
|
||||
RankedTensorType getExtentTensorType(MLIRContext *ctx);
|
||||
|
||||
namespace ShapeTypes {
|
||||
enum Kind {
|
||||
Component = Type::FIRST_SHAPE_TYPE,
|
||||
Element,
|
||||
Shape,
|
||||
Size,
|
||||
ValueShape,
|
||||
Witness,
|
||||
LAST_SHAPE_TYPE = Witness
|
||||
};
|
||||
} // namespace ShapeTypes
|
||||
|
||||
/// The component type corresponding to shape, element type and attribute.
|
||||
class ComponentType : public Type::TypeBase<ComponentType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static ComponentType get(MLIRContext *context) {
|
||||
return Base::get(context, ShapeTypes::Kind::Component);
|
||||
}
|
||||
};
|
||||
|
||||
/// The element type of the shaped type.
|
||||
class ElementType : public Type::TypeBase<ElementType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static ElementType get(MLIRContext *context) {
|
||||
return Base::get(context, ShapeTypes::Kind::Element);
|
||||
}
|
||||
};
|
||||
|
||||
/// The shape descriptor type represents rank and dimension sizes.
|
||||
class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static ShapeType get(MLIRContext *context) {
|
||||
return Base::get(context, ShapeTypes::Kind::Shape);
|
||||
}
|
||||
};
|
||||
|
||||
/// The type of a single dimension.
|
||||
class SizeType : public Type::TypeBase<SizeType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static SizeType get(MLIRContext *context) {
|
||||
return Base::get(context, ShapeTypes::Kind::Size);
|
||||
}
|
||||
};
|
||||
|
||||
/// The ValueShape represents a (potentially unknown) runtime value and shape.
|
||||
|
@ -86,10 +58,6 @@ class ValueShapeType
|
|||
: public Type::TypeBase<ValueShapeType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static ValueShapeType get(MLIRContext *context) {
|
||||
return Base::get(context, ShapeTypes::Kind::ValueShape);
|
||||
}
|
||||
};
|
||||
|
||||
/// The Witness represents a runtime constraint, to be used as shape related
|
||||
|
@ -97,10 +65,6 @@ public:
|
|||
class WitnessType : public Type::TypeBase<WitnessType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static WitnessType get(MLIRContext *context) {
|
||||
return Base::get(context, ShapeTypes::Kind::Witness);
|
||||
}
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -137,15 +137,23 @@ namespace detail {
|
|||
// MLIRContext. This class manages all creation and uniquing of attributes.
|
||||
class AttributeUniquer {
|
||||
public:
|
||||
/// Get an uniqued instance of attribute T.
|
||||
/// Get an uniqued instance of a parametric attribute T.
|
||||
template <typename T, typename... Args>
|
||||
static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
|
||||
static typename std::enable_if_t<
|
||||
!std::is_same<typename T::ImplType, AttributeStorage>::value, T>
|
||||
get(MLIRContext *ctx, Args &&...args) {
|
||||
return ctx->getAttributeUniquer().get<typename T::ImplType>(
|
||||
T::getTypeID(),
|
||||
[ctx](AttributeStorage *storage) {
|
||||
initializeAttributeStorage(storage, ctx, T::getTypeID());
|
||||
},
|
||||
kind, std::forward<Args>(args)...);
|
||||
T::getTypeID(), std::forward<Args>(args)...);
|
||||
}
|
||||
/// Get an uniqued instance of a singleton attribute T.
|
||||
template <typename T>
|
||||
static typename std::enable_if_t<
|
||||
std::is_same<typename T::ImplType, AttributeStorage>::value, T>
|
||||
get(MLIRContext *ctx) {
|
||||
return ctx->getAttributeUniquer().get<typename T::ImplType>(T::getTypeID());
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
|
@ -156,6 +164,26 @@ public:
|
|||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Register a parametric attribute instance T with the uniquer.
|
||||
template <typename T>
|
||||
static typename std::enable_if_t<
|
||||
!std::is_same<typename T::ImplType, AttributeStorage>::value>
|
||||
registerAttribute(MLIRContext *ctx) {
|
||||
ctx->getAttributeUniquer()
|
||||
.registerParametricStorageType<typename T::ImplType>(T::getTypeID());
|
||||
}
|
||||
/// Register a singleton attribute instance T with the uniquer.
|
||||
template <typename T>
|
||||
static typename std::enable_if_t<
|
||||
std::is_same<typename T::ImplType, AttributeStorage>::value>
|
||||
registerAttribute(MLIRContext *ctx) {
|
||||
ctx->getAttributeUniquer()
|
||||
.registerSingletonStorageType<typename T::ImplType>(
|
||||
T::getTypeID(), [ctx](AttributeStorage *storage) {
|
||||
initializeAttributeStorage(storage, ctx, T::getTypeID());
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
/// Initialize the given attribute storage instance.
|
||||
static void initializeAttributeStorage(AttributeStorage *storage,
|
||||
|
|
|
@ -54,14 +54,6 @@ struct SparseElementsAttributeStorage;
|
|||
/// passed by value.
|
||||
class Attribute {
|
||||
public:
|
||||
/// Integer identifier for all the concrete attribute kinds.
|
||||
enum Kind {
|
||||
// Reserve attribute kinds for dialect specific extensions.
|
||||
#define DEFINE_SYM_KIND_RANGE(Dialect) \
|
||||
FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff,
|
||||
#include "DialectSymbolRegistry.def"
|
||||
};
|
||||
|
||||
/// Utility class for implementing attributes.
|
||||
template <typename ConcreteType, typename BaseType, typename StorageType,
|
||||
template <typename T> class... Traits>
|
||||
|
@ -94,9 +86,6 @@ public:
|
|||
// Support dyn_cast'ing Attribute to itself.
|
||||
static bool classof(Attribute) { return true; }
|
||||
|
||||
/// Return the classification for this attribute.
|
||||
unsigned getKind() const { return impl->getKind(); }
|
||||
|
||||
/// Return a unique identifier for the concrete attribute type. This is used
|
||||
/// to support dynamic type casting.
|
||||
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
|
||||
|
@ -173,54 +162,6 @@ private:
|
|||
friend InterfaceBase;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StandardAttributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace StandardAttributes {
|
||||
enum Kind {
|
||||
AffineMap = Attribute::FIRST_STANDARD_ATTR,
|
||||
Array,
|
||||
Dictionary,
|
||||
Float,
|
||||
Integer,
|
||||
IntegerSet,
|
||||
Opaque,
|
||||
String,
|
||||
SymbolRef,
|
||||
Type,
|
||||
Unit,
|
||||
|
||||
/// Elements Attributes.
|
||||
DenseIntOrFPElements,
|
||||
DenseStringElements,
|
||||
OpaqueElements,
|
||||
SparseElements,
|
||||
FIRST_ELEMENTS_ATTR = DenseIntOrFPElements,
|
||||
LAST_ELEMENTS_ATTR = SparseElements,
|
||||
|
||||
/// Locations.
|
||||
CallSiteLocation,
|
||||
FileLineColLocation,
|
||||
FusedLocation,
|
||||
NameLocation,
|
||||
OpaqueLocation,
|
||||
UnknownLocation,
|
||||
|
||||
// Represents a location as a 'void*' pointer to a front-end's opaque
|
||||
// location information, which must live longer than the MLIR objects that
|
||||
// refer to it. OpaqueLocation's are never serialized.
|
||||
//
|
||||
// TODO: OpaqueLocation,
|
||||
|
||||
// Represents a value inlined through a function call.
|
||||
// TODO: InlinedLocation,
|
||||
|
||||
FIRST_LOCATION_ATTR = CallSiteLocation,
|
||||
LAST_LOCATION_ATTR = UnknownLocation,
|
||||
};
|
||||
} // namespace StandardAttributes
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineMapAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -154,21 +154,15 @@ protected:
|
|||
|
||||
void addOperation(AbstractOperation opInfo);
|
||||
|
||||
/// This method is used by derived classes to add their types to the set.
|
||||
/// Register a set of type classes with this dialect.
|
||||
template <typename... Args> void addTypes() {
|
||||
(void)std::initializer_list<int>{
|
||||
0, (addType(Args::getTypeID(), AbstractType::get<Args>(*this)), 0)...};
|
||||
(void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
|
||||
}
|
||||
void addType(TypeID typeID, AbstractType &&typeInfo);
|
||||
|
||||
/// This method is used by derived classes to add their attributes to the set.
|
||||
/// Register a set of attribute classes with this dialect.
|
||||
template <typename... Args> void addAttributes() {
|
||||
(void)std::initializer_list<int>{
|
||||
0,
|
||||
(addAttribute(Args::getTypeID(), AbstractAttribute::get<Args>(*this)),
|
||||
0)...};
|
||||
(void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
|
||||
}
|
||||
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
|
||||
|
||||
/// Enable support for unregistered operations.
|
||||
void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
|
||||
|
@ -189,6 +183,22 @@ private:
|
|||
Dialect(const Dialect &) = delete;
|
||||
void operator=(Dialect &) = delete;
|
||||
|
||||
/// Register an attribute instance with this dialect.
|
||||
template <typename T> void addAttribute() {
|
||||
// Add this attribute to the dialect and register it with the uniquer.
|
||||
addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
|
||||
detail::AttributeUniquer::registerAttribute<T>(context);
|
||||
}
|
||||
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
|
||||
|
||||
/// Register a type instance with this dialect.
|
||||
template <typename T> void addType() {
|
||||
// Add this type to the dialect and register it with the uniquer.
|
||||
addType(T::getTypeID(), AbstractType::get<T>(*this));
|
||||
detail::TypeUniquer::registerType<T>(context);
|
||||
}
|
||||
void addType(TypeID typeID, AbstractType &&typeInfo);
|
||||
|
||||
/// The namespace of this dialect.
|
||||
StringRef name;
|
||||
|
||||
|
|
|
@ -1,44 +0,0 @@
|
|||
//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===//
|
||||
//
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file enumerates the different dialects that define custom classes
|
||||
// within the attribute or type system.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DEFINE_SYM_KIND_RANGE(STANDARD)
|
||||
DEFINE_SYM_KIND_RANGE(TENSORFLOW_CONTROL)
|
||||
DEFINE_SYM_KIND_RANGE(TENSORFLOW_EXECUTOR)
|
||||
DEFINE_SYM_KIND_RANGE(TENSORFLOW)
|
||||
DEFINE_SYM_KIND_RANGE(LLVM)
|
||||
DEFINE_SYM_KIND_RANGE(QUANTIZATION)
|
||||
DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
|
||||
DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
|
||||
DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect
|
||||
DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect
|
||||
DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect
|
||||
DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
|
||||
DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect
|
||||
DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect
|
||||
DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect
|
||||
DEFINE_SYM_KIND_RANGE(TF_FRAMEWORK) // TF Framework dialect
|
||||
|
||||
// The following ranges are reserved for experimenting with MLIR dialects in a
|
||||
// private context without having to register them here.
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_1)
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_2)
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_3)
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_4)
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_5)
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_6)
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_7)
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_8)
|
||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_9)
|
||||
|
||||
#undef DEFINE_SYM_KIND_RANGE
|
|
@ -756,7 +756,7 @@ public:
|
|||
/// all attributes of the given kind in the form : <alias>[0-9]+. These
|
||||
/// aliases must not contain `.`.
|
||||
virtual void getAttributeKindAliases(
|
||||
SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) const {}
|
||||
SmallVectorImpl<std::pair<TypeID, StringRef>> &aliases) const {}
|
||||
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
|
||||
/// end with a numeric digit([0-9]+).
|
||||
virtual void getAttributeAliases(
|
||||
|
|
|
@ -38,33 +38,6 @@ struct TupleTypeStorage;
|
|||
|
||||
} // namespace detail
|
||||
|
||||
namespace StandardTypes {
|
||||
enum Kind {
|
||||
// Floating point.
|
||||
BF16 = Type::Kind::FIRST_STANDARD_TYPE,
|
||||
F16,
|
||||
F32,
|
||||
F64,
|
||||
FIRST_FLOATING_POINT_TYPE = BF16,
|
||||
LAST_FLOATING_POINT_TYPE = F64,
|
||||
|
||||
// Target pointer sized integer, used (e.g.) in affine mappings.
|
||||
Index,
|
||||
|
||||
// Derived types.
|
||||
Integer,
|
||||
Vector,
|
||||
RankedTensor,
|
||||
UnrankedTensor,
|
||||
MemRef,
|
||||
UnrankedMemRef,
|
||||
Complex,
|
||||
Tuple,
|
||||
None,
|
||||
};
|
||||
|
||||
} // namespace StandardTypes
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ComplexType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -82,29 +82,29 @@ public:
|
|||
return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Get or create a new ConcreteT instance within the ctx. This
|
||||
/// function is guaranteed to return a non null object and will assert if
|
||||
/// the arguments provided are invalid.
|
||||
template <typename... Args>
|
||||
static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) {
|
||||
static ConcreteT get(MLIRContext *ctx, Args... args) {
|
||||
// Ensure that the invariants are correct for construction.
|
||||
assert(succeeded(ConcreteT::verifyConstructionInvariants(
|
||||
generateUnknownStorageLocation(ctx), args...)));
|
||||
return UniquerT::template get<ConcreteT>(ctx, kind, args...);
|
||||
return UniquerT::template get<ConcreteT>(ctx, args...);
|
||||
}
|
||||
|
||||
/// Get or create a new ConcreteT instance within the ctx, defined at
|
||||
/// the given, potentially unknown, location. If the arguments provided are
|
||||
/// invalid then emit errors and return a null object.
|
||||
template <typename LocationT, typename... Args>
|
||||
static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) {
|
||||
static ConcreteT getChecked(LocationT loc, Args... args) {
|
||||
// If the construction invariants fail then we return a null attribute.
|
||||
if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
|
||||
return ConcreteT();
|
||||
return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
|
||||
return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Mutate the current storage instance. This will not change the unique key.
|
||||
/// The arguments are forwarded to 'ConcreteT::mutate'.
|
||||
template <typename... Args> LogicalResult mutate(Args &&...args) {
|
||||
|
|
|
@ -121,15 +121,23 @@ namespace detail {
|
|||
/// A utility class to get, or create, unique instances of types within an
|
||||
/// MLIRContext. This class manages all creation and uniquing of types.
|
||||
struct TypeUniquer {
|
||||
/// Get an uniqued instance of a type T.
|
||||
/// Get an uniqued instance of a parametric type T.
|
||||
template <typename T, typename... Args>
|
||||
static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
|
||||
static typename std::enable_if_t<
|
||||
!std::is_same<typename T::ImplType, TypeStorage>::value, T>
|
||||
get(MLIRContext *ctx, Args &&...args) {
|
||||
return ctx->getTypeUniquer().get<typename T::ImplType>(
|
||||
T::getTypeID(),
|
||||
[&](TypeStorage *storage) {
|
||||
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
|
||||
},
|
||||
kind, std::forward<Args>(args)...);
|
||||
T::getTypeID(), std::forward<Args>(args)...);
|
||||
}
|
||||
/// Get an uniqued instance of a singleton type T.
|
||||
template <typename T>
|
||||
static typename std::enable_if_t<
|
||||
std::is_same<typename T::ImplType, TypeStorage>::value, T>
|
||||
get(MLIRContext *ctx) {
|
||||
return ctx->getTypeUniquer().get<typename T::ImplType>(T::getTypeID());
|
||||
}
|
||||
|
||||
/// Change the mutable component of the given type instance in the provided
|
||||
|
@ -141,6 +149,25 @@ struct TypeUniquer {
|
|||
return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Register a parametric type instance T with the uniquer.
|
||||
template <typename T>
|
||||
static typename std::enable_if_t<
|
||||
!std::is_same<typename T::ImplType, TypeStorage>::value>
|
||||
registerType(MLIRContext *ctx) {
|
||||
ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
|
||||
T::getTypeID());
|
||||
}
|
||||
/// Register a singleton type instance T with the uniquer.
|
||||
template <typename T>
|
||||
static typename std::enable_if_t<
|
||||
std::is_same<typename T::ImplType, TypeStorage>::value>
|
||||
registerType(MLIRContext *ctx) {
|
||||
ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
|
||||
T::getTypeID(), [&](TypeStorage *storage) {
|
||||
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
|
|
|
@ -34,11 +34,11 @@ struct OpaqueTypeStorage;
|
|||
///
|
||||
/// Some types are "primitives" meaning they do not have any parameters, for
|
||||
/// example the Index type. Parametric types have additional information that
|
||||
/// differentiates the types of the same kind between them, for example the
|
||||
/// Integer type has bitwidth, making i8 and i16 belong to the same kind by be
|
||||
/// different instances of the IntegerType. Type parameters are part of the
|
||||
/// unique immutable key. The mutable component of the type can be modified
|
||||
/// after the type is created, but cannot affect the identity of the type.
|
||||
/// differentiates the types of the same class, for example the Integer type has
|
||||
/// bitwidth, making i8 and i16 belong to the same kind by be different
|
||||
/// instances of the IntegerType. Type parameters are part of the unique
|
||||
/// immutable key. The mutable component of the type can be modified after the
|
||||
/// type is created, but cannot affect the identity of the type.
|
||||
///
|
||||
/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
|
||||
///
|
||||
|
@ -53,20 +53,19 @@ struct OpaqueTypeStorage;
|
|||
/// * This method is expected to return failure if a type cannot be
|
||||
/// constructed with 'args', success otherwise.
|
||||
/// * 'args' must correspond with the arguments passed into the
|
||||
/// 'TypeBase::get' call after the type kind.
|
||||
/// 'TypeBase::get' call.
|
||||
///
|
||||
///
|
||||
/// Type storage objects inherit from TypeStorage and contain the following:
|
||||
/// - The type kind (for LLVM-style RTTI).
|
||||
/// - The dialect that defined the type.
|
||||
/// - Any parameters of the type.
|
||||
/// - An optional mutable component.
|
||||
/// For non-parametric types, a convenience DefaultTypeStorage is provided.
|
||||
/// Parametric storage types must derive TypeStorage and respect the following:
|
||||
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
|
||||
/// instance of the type within its kind.
|
||||
/// instance of the type.
|
||||
/// * The key type must be constructible from the values passed into the
|
||||
/// detail::TypeUniquer::get call after the type kind.
|
||||
/// detail::TypeUniquer::get call.
|
||||
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
|
||||
/// storage class must define a hashing method:
|
||||
/// 'static unsigned hashKey(const KeyTy &)'
|
||||
|
@ -84,23 +83,6 @@ struct OpaqueTypeStorage;
|
|||
// the key.
|
||||
class Type {
|
||||
public:
|
||||
/// Integer identifier for all the concrete type kinds.
|
||||
/// Note: This is not an enum class as each dialect will likely define a
|
||||
/// separate enumeration for the specific types that they define. Not being an
|
||||
/// enum class also simplifies the handling of type kinds by not requiring
|
||||
/// casts for each use.
|
||||
enum Kind {
|
||||
// Builtin types.
|
||||
Function,
|
||||
Opaque,
|
||||
LAST_BUILTIN_TYPE = Opaque,
|
||||
|
||||
// Reserve type kinds for dialect specific type system extensions.
|
||||
#define DEFINE_SYM_KIND_RANGE(Dialect) \
|
||||
FIRST_##Dialect##_TYPE, LAST_##Dialect##_TYPE = FIRST_##Dialect##_TYPE + 0xff,
|
||||
#include "DialectSymbolRegistry.def"
|
||||
};
|
||||
|
||||
/// Utility class for implementing types.
|
||||
template <typename ConcreteType, typename BaseType, typename StorageType,
|
||||
template <typename T> class... Traits>
|
||||
|
@ -136,9 +118,6 @@ public:
|
|||
/// dynamic type casting.
|
||||
TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
|
||||
|
||||
/// Return the classification for this type.
|
||||
unsigned getKind() const;
|
||||
|
||||
/// Return the LLVMContext in which this type was uniqued.
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
|
|
|
@ -11,12 +11,11 @@
|
|||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Support/TypeID.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
|
||||
namespace mlir {
|
||||
class TypeID;
|
||||
|
||||
namespace detail {
|
||||
struct StorageUniquerImpl;
|
||||
|
||||
|
@ -29,22 +28,19 @@ template <typename ImplTy, typename T>
|
|||
using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
|
||||
} // namespace detail
|
||||
|
||||
/// A utility class to get, or create instances of storage classes. These
|
||||
/// storage classes must respect the following constraints:
|
||||
/// - Derive from StorageUniquer::BaseStorage.
|
||||
/// - Provide an unsigned 'kind' value to be used as part of the unique'ing
|
||||
/// process.
|
||||
/// A utility class to get or create instances of "storage classes". These
|
||||
/// storage classes must derive from 'StorageUniquer::BaseStorage'.
|
||||
///
|
||||
/// For non-parametric storage classes, i.e. those that are solely uniqued by
|
||||
/// their kind, nothing else is needed. Instances of these classes can be
|
||||
/// created by calling `get` without trailing arguments.
|
||||
/// For non-parametric storage classes, i.e. singleton classes, nothing else is
|
||||
/// needed. Instances of these classes can be created by calling `get` without
|
||||
/// trailing arguments.
|
||||
///
|
||||
/// Otherwise, the parametric storage classes may be created with `get`,
|
||||
/// and must respect the following:
|
||||
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
|
||||
/// instance of the storage class within its kind.
|
||||
/// instance of the storage class.
|
||||
/// * The key type must be constructible from the values passed into the
|
||||
/// getComplex call after the kind.
|
||||
/// getComplex call.
|
||||
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
|
||||
/// storage class must define a hashing method:
|
||||
/// 'static unsigned hashKey(const KeyTy &)'
|
||||
|
@ -83,32 +79,11 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
|
|||
/// class.
|
||||
class StorageUniquer {
|
||||
public:
|
||||
StorageUniquer();
|
||||
~StorageUniquer();
|
||||
|
||||
/// Set the flag specifying if multi-threading is disabled within the uniquer.
|
||||
void disableMultithreading(bool disable = true);
|
||||
|
||||
/// Register a new storage object with this uniquer using the given unique
|
||||
/// type id.
|
||||
void registerStorageType(TypeID id);
|
||||
|
||||
/// This class acts as the base storage that all storage classes must derived
|
||||
/// from.
|
||||
class BaseStorage {
|
||||
public:
|
||||
/// Get the kind classification of this storage.
|
||||
unsigned getKind() const { return kind; }
|
||||
|
||||
protected:
|
||||
BaseStorage() : kind(0) {}
|
||||
|
||||
private:
|
||||
/// Allow access to the kind field.
|
||||
friend detail::StorageUniquerImpl;
|
||||
|
||||
/// Classification of the subclass, used for type checking.
|
||||
unsigned kind;
|
||||
BaseStorage() = default;
|
||||
};
|
||||
|
||||
/// This is a utility allocator used to allocate memory for instances of
|
||||
|
@ -145,19 +120,61 @@ public:
|
|||
llvm::BumpPtrAllocator allocator;
|
||||
};
|
||||
|
||||
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
|
||||
/// that can be used to initialize a newly inserted storage instance. This
|
||||
/// function is used for derived types that have complex storage or uniquing
|
||||
/// constraints.
|
||||
template <typename Storage, typename Arg, typename... Args>
|
||||
Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
|
||||
unsigned kind, Arg &&arg, Args &&...args) {
|
||||
// Construct a value of the derived key type.
|
||||
auto derivedKey =
|
||||
getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
|
||||
StorageUniquer();
|
||||
~StorageUniquer();
|
||||
|
||||
// Create a hash of the kind and the derived key.
|
||||
unsigned hashValue = getHash<Storage>(kind, derivedKey);
|
||||
/// Set the flag specifying if multi-threading is disabled within the uniquer.
|
||||
void disableMultithreading(bool disable = true);
|
||||
|
||||
/// Register a new parametric storage class, this is necessary to create
|
||||
/// instances of this class type. `id` is the type identifier that will be
|
||||
/// used to identify this type when creating instances of it via 'get'.
|
||||
template <typename Storage> void registerParametricStorageType(TypeID id) {
|
||||
registerParametricStorageTypeImpl(id);
|
||||
}
|
||||
/// Utility override when the storage type represents the type id.
|
||||
template <typename Storage> void registerParametricStorageType() {
|
||||
registerParametricStorageType<Storage>(TypeID::get<Storage>());
|
||||
}
|
||||
/// Register a new singleton storage class, this is necessary to get the
|
||||
/// singletone instance. `id` is the type identifier that will be used to
|
||||
/// access the singleton instance via 'get'. An optional initialization
|
||||
/// function may also be provided to initialize the newly created storage
|
||||
/// instance, and used when the singleton instance is created.
|
||||
template <typename Storage>
|
||||
void registerSingletonStorageType(TypeID id,
|
||||
function_ref<void(Storage *)> initFn) {
|
||||
auto ctorFn = [&](StorageAllocator &allocator) {
|
||||
auto *storage = new (allocator.allocate<Storage>()) Storage();
|
||||
if (initFn)
|
||||
initFn(storage);
|
||||
return storage;
|
||||
};
|
||||
registerSingletonImpl(id, ctorFn);
|
||||
}
|
||||
template <typename Storage> void registerSingletonStorageType(TypeID id) {
|
||||
registerSingletonStorageType<Storage>(id, llvm::None);
|
||||
}
|
||||
/// Utility override when the storage type represents the type id.
|
||||
template <typename Storage>
|
||||
void registerSingletonStorageType(
|
||||
function_ref<void(Storage *)> initFn = llvm::None) {
|
||||
registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn);
|
||||
}
|
||||
|
||||
/// Gets a uniqued instance of 'Storage'. 'id' is the type id used when
|
||||
/// registering the storage instance. 'initFn' is an optional parameter that
|
||||
/// can be used to initialize a newly inserted storage instance. This function
|
||||
/// is used for derived types that have complex storage or uniquing
|
||||
/// constraints.
|
||||
template <typename Storage, typename... Args>
|
||||
Storage *get(function_ref<void(Storage *)> initFn, TypeID id,
|
||||
Args &&...args) {
|
||||
// Construct a value of the derived key type.
|
||||
auto derivedKey = getKey<Storage>(std::forward<Args>(args)...);
|
||||
|
||||
// Create a hash of the derived key.
|
||||
unsigned hashValue = getHash<Storage>(derivedKey);
|
||||
|
||||
// Generate an equality function for the derived storage.
|
||||
auto isEqual = [&derivedKey](const BaseStorage *existing) {
|
||||
|
@ -174,29 +191,29 @@ public:
|
|||
|
||||
// Get an instance for the derived storage.
|
||||
return static_cast<Storage *>(
|
||||
getImpl(id, kind, hashValue, isEqual, ctorFn));
|
||||
getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn));
|
||||
}
|
||||
/// Utility override when the storage type represents the type id.
|
||||
template <typename Storage, typename... Args>
|
||||
Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) {
|
||||
return get<Storage>(initFn, TypeID::get<Storage>(),
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
|
||||
/// that can be used to initialize a newly inserted storage instance. This
|
||||
/// function is used for derived types that use no additional storage or
|
||||
/// uniquing outside of the kind.
|
||||
template <typename Storage>
|
||||
Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
|
||||
unsigned kind) {
|
||||
auto ctorFn = [&](StorageAllocator &allocator) {
|
||||
auto *storage = new (allocator.allocate<Storage>()) Storage();
|
||||
if (initFn)
|
||||
initFn(storage);
|
||||
return storage;
|
||||
};
|
||||
return static_cast<Storage *>(getImpl(id, kind, ctorFn));
|
||||
/// Gets a uniqued instance of 'Storage' which is a singleton storage type.
|
||||
/// 'id' is the type id used when registering the storage instance.
|
||||
template <typename Storage> Storage *get(TypeID id) {
|
||||
return static_cast<Storage *>(getSingletonImpl(id));
|
||||
}
|
||||
/// Utility override when the storage type represents the type id.
|
||||
template <typename Storage> Storage *get() {
|
||||
return get<Storage>(TypeID::get<Storage>());
|
||||
}
|
||||
|
||||
/// Changes the mutable component of 'storage' by forwarding the trailing
|
||||
/// arguments to the 'mutate' function of the derived class.
|
||||
template <typename Storage, typename... Args>
|
||||
LogicalResult mutate(const TypeID &id, Storage *storage, Args &&...args) {
|
||||
LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) {
|
||||
auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
|
||||
return static_cast<Storage &>(*storage).mutate(
|
||||
allocator, std::forward<Args>(args)...);
|
||||
|
@ -207,13 +224,13 @@ public:
|
|||
/// Erases a uniqued instance of 'Storage'. This function is used for derived
|
||||
/// types that have complex storage or uniquing constraints.
|
||||
template <typename Storage, typename Arg, typename... Args>
|
||||
void erase(const TypeID &id, unsigned kind, Arg &&arg, Args &&...args) {
|
||||
void erase(TypeID id, Arg &&arg, Args &&...args) {
|
||||
// Construct a value of the derived key type.
|
||||
auto derivedKey =
|
||||
getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
|
||||
|
||||
// Create a hash of the kind and the derived key.
|
||||
unsigned hashValue = getHash<Storage>(kind, derivedKey);
|
||||
// Create a hash of the derived key.
|
||||
unsigned hashValue = getHash<Storage>(derivedKey);
|
||||
|
||||
// Generate an equality function for the derived storage.
|
||||
auto isEqual = [&derivedKey](const BaseStorage *existing) {
|
||||
|
@ -221,32 +238,42 @@ public:
|
|||
};
|
||||
|
||||
// Attempt to erase the storage instance.
|
||||
eraseImpl(id, kind, hashValue, isEqual, [](BaseStorage *storage) {
|
||||
eraseImpl(id, hashValue, isEqual, [](BaseStorage *storage) {
|
||||
static_cast<Storage *>(storage)->cleanup();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
/// Implementation for getting/creating an instance of a derived type with
|
||||
/// complex storage.
|
||||
BaseStorage *getImpl(const TypeID &id, unsigned kind, unsigned hashValue,
|
||||
/// parametric storage.
|
||||
BaseStorage *getParametricStorageTypeImpl(
|
||||
TypeID id, unsigned hashValue,
|
||||
function_ref<bool(const BaseStorage *)> isEqual,
|
||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
|
||||
|
||||
/// Implementation for getting/creating an instance of a derived type with
|
||||
/// default storage.
|
||||
BaseStorage *getImpl(const TypeID &id, unsigned kind,
|
||||
/// Implementation for registering an instance of a derived type with
|
||||
/// parametric storage.
|
||||
void registerParametricStorageTypeImpl(TypeID id);
|
||||
|
||||
/// Implementation for getting an instance of a derived type with default
|
||||
/// storage.
|
||||
BaseStorage *getSingletonImpl(TypeID id);
|
||||
|
||||
/// Implementation for registering an instance of a derived type with default
|
||||
/// storage.
|
||||
void
|
||||
registerSingletonImpl(TypeID id,
|
||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
|
||||
|
||||
/// Implementation for erasing an instance of a derived type with complex
|
||||
/// storage.
|
||||
void eraseImpl(const TypeID &id, unsigned kind, unsigned hashValue,
|
||||
void eraseImpl(TypeID id, unsigned hashValue,
|
||||
function_ref<bool(const BaseStorage *)> isEqual,
|
||||
function_ref<void(BaseStorage *)> cleanupFn);
|
||||
|
||||
/// Implementation for mutating an instance of a derived storage.
|
||||
LogicalResult
|
||||
mutateImpl(const TypeID &id,
|
||||
mutateImpl(TypeID id,
|
||||
function_ref<LogicalResult(StorageAllocator &)> mutationFn);
|
||||
|
||||
/// The internal implementation class.
|
||||
|
@ -276,27 +303,26 @@ private:
|
|||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Key and Kind Hashing
|
||||
// Key Hashing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
|
||||
/// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
|
||||
/// Used to generate a hash for the 'ImplTy::KeyTy' of a storage instance if
|
||||
/// there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
|
||||
template <typename ImplTy, typename DerivedKey>
|
||||
static typename std::enable_if<
|
||||
llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
|
||||
::llvm::hash_code>::type
|
||||
getHash(unsigned kind, const DerivedKey &derivedKey) {
|
||||
return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
|
||||
getHash(const DerivedKey &derivedKey) {
|
||||
return ImplTy::hashKey(derivedKey);
|
||||
}
|
||||
/// If there is no 'ImplTy::hashKey' default to using the
|
||||
/// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
|
||||
/// If there is no 'ImplTy::hashKey' default to using the 'llvm::DenseMapInfo'
|
||||
/// definition for 'DerivedKey' for generating a hash.
|
||||
template <typename ImplTy, typename DerivedKey>
|
||||
static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
|
||||
ImplTy, DerivedKey>::value,
|
||||
::llvm::hash_code>::type
|
||||
getHash(unsigned kind, const DerivedKey &derivedKey) {
|
||||
return llvm::hash_combine(
|
||||
kind, DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
|
||||
getHash(const DerivedKey &derivedKey) {
|
||||
return DenseMapInfo<DerivedKey>::getHashValue(derivedKey);
|
||||
}
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -264,14 +264,13 @@ bool LLVMArrayType::isValidElementType(LLVMType type) {
|
|||
|
||||
LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
|
||||
assert(elementType && "expected non-null subtype");
|
||||
return Base::get(elementType.getContext(), LLVMType::ArrayType, elementType,
|
||||
numElements);
|
||||
return Base::get(elementType.getContext(), elementType, numElements);
|
||||
}
|
||||
|
||||
LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
|
||||
unsigned numElements) {
|
||||
assert(elementType && "expected non-null subtype");
|
||||
return Base::getChecked(loc, LLVMType::ArrayType, elementType, numElements);
|
||||
return Base::getChecked(loc, elementType, numElements);
|
||||
}
|
||||
|
||||
LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
|
||||
|
@ -301,16 +300,14 @@ LLVMFunctionType LLVMFunctionType::get(LLVMType result,
|
|||
ArrayRef<LLVMType> arguments,
|
||||
bool isVarArg) {
|
||||
assert(result && "expected non-null result");
|
||||
return Base::get(result.getContext(), LLVMType::FunctionType, result,
|
||||
arguments, isVarArg);
|
||||
return Base::get(result.getContext(), result, arguments, isVarArg);
|
||||
}
|
||||
|
||||
LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
|
||||
ArrayRef<LLVMType> arguments,
|
||||
bool isVarArg) {
|
||||
assert(result && "expected non-null result");
|
||||
return Base::getChecked(loc, LLVMType::FunctionType, result, arguments,
|
||||
isVarArg);
|
||||
return Base::getChecked(loc, result, arguments, isVarArg);
|
||||
}
|
||||
|
||||
LLVMType LLVMFunctionType::getReturnType() {
|
||||
|
@ -347,11 +344,11 @@ LogicalResult LLVMFunctionType::verifyConstructionInvariants(
|
|||
// Integer type.
|
||||
|
||||
LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
|
||||
return Base::get(ctx, LLVMType::IntegerType, bitwidth);
|
||||
return Base::get(ctx, bitwidth);
|
||||
}
|
||||
|
||||
LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) {
|
||||
return Base::getChecked(loc, LLVMType::IntegerType, bitwidth);
|
||||
return Base::getChecked(loc, bitwidth);
|
||||
}
|
||||
|
||||
unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; }
|
||||
|
@ -374,13 +371,12 @@ bool LLVMPointerType::isValidElementType(LLVMType type) {
|
|||
|
||||
LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
|
||||
assert(pointee && "expected non-null subtype");
|
||||
return Base::get(pointee.getContext(), LLVMType::PointerType, pointee,
|
||||
addressSpace);
|
||||
return Base::get(pointee.getContext(), pointee, addressSpace);
|
||||
}
|
||||
|
||||
LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
|
||||
unsigned addressSpace) {
|
||||
return Base::getChecked(loc, LLVMType::PointerType, pointee, addressSpace);
|
||||
return Base::getChecked(loc, pointee, addressSpace);
|
||||
}
|
||||
|
||||
LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
|
||||
|
@ -405,32 +401,32 @@ bool LLVMStructType::isValidElementType(LLVMType type) {
|
|||
|
||||
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
|
||||
StringRef name) {
|
||||
return Base::get(context, LLVMType::StructType, name, /*opaque=*/false);
|
||||
return Base::get(context, name, /*opaque=*/false);
|
||||
}
|
||||
|
||||
LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
|
||||
StringRef name) {
|
||||
return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/false);
|
||||
return Base::getChecked(loc, name, /*opaque=*/false);
|
||||
}
|
||||
|
||||
LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
|
||||
ArrayRef<LLVMType> types,
|
||||
bool isPacked) {
|
||||
return Base::get(context, LLVMType::StructType, types, isPacked);
|
||||
return Base::get(context, types, isPacked);
|
||||
}
|
||||
|
||||
LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
|
||||
ArrayRef<LLVMType> types,
|
||||
bool isPacked) {
|
||||
return Base::getChecked(loc, LLVMType::StructType, types, isPacked);
|
||||
return Base::getChecked(loc, types, isPacked);
|
||||
}
|
||||
|
||||
LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
|
||||
return Base::get(context, LLVMType::StructType, name, /*opaque=*/true);
|
||||
return Base::get(context, name, /*opaque=*/true);
|
||||
}
|
||||
|
||||
LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
|
||||
return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/true);
|
||||
return Base::getChecked(loc, name, /*opaque=*/true);
|
||||
}
|
||||
|
||||
LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
|
||||
|
@ -508,16 +504,14 @@ LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
|
|||
LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
|
||||
unsigned numElements) {
|
||||
assert(elementType && "expected non-null subtype");
|
||||
return Base::get(elementType.getContext(), LLVMType::FixedVectorType,
|
||||
elementType, numElements);
|
||||
return Base::get(elementType.getContext(), elementType, numElements);
|
||||
}
|
||||
|
||||
LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
|
||||
LLVMType elementType,
|
||||
unsigned numElements) {
|
||||
assert(elementType && "expected non-null subtype");
|
||||
return Base::getChecked(loc, LLVMType::FixedVectorType, elementType,
|
||||
numElements);
|
||||
return Base::getChecked(loc, elementType, numElements);
|
||||
}
|
||||
|
||||
unsigned LLVMFixedVectorType::getNumElements() {
|
||||
|
@ -527,16 +521,14 @@ unsigned LLVMFixedVectorType::getNumElements() {
|
|||
LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
|
||||
unsigned minNumElements) {
|
||||
assert(elementType && "expected non-null subtype");
|
||||
return Base::get(elementType.getContext(), LLVMType::ScalableVectorType,
|
||||
elementType, minNumElements);
|
||||
return Base::get(elementType.getContext(), elementType, minNumElements);
|
||||
}
|
||||
|
||||
LLVMScalableVectorType
|
||||
LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
|
||||
unsigned minNumElements) {
|
||||
assert(elementType && "expected non-null subtype");
|
||||
return Base::getChecked(loc, LLVMType::ScalableVectorType, elementType,
|
||||
minNumElements);
|
||||
return Base::getChecked(loc, elementType, minNumElements);
|
||||
}
|
||||
|
||||
unsigned LLVMScalableVectorType::getMinNumElements() {
|
||||
|
|
|
@ -204,8 +204,8 @@ AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
|
|||
Type expressedType,
|
||||
int64_t storageTypeMin,
|
||||
int64_t storageTypeMax) {
|
||||
return Base::get(storageType.getContext(), QuantizationTypes::Any, flags,
|
||||
storageType, expressedType, storageTypeMin, storageTypeMax);
|
||||
return Base::get(storageType.getContext(), flags, storageType, expressedType,
|
||||
storageTypeMin, storageTypeMax);
|
||||
}
|
||||
|
||||
AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
|
||||
|
@ -213,8 +213,8 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
|
|||
int64_t storageTypeMin,
|
||||
int64_t storageTypeMax,
|
||||
Location location) {
|
||||
return Base::getChecked(location, QuantizationTypes::Any, flags, storageType,
|
||||
expressedType, storageTypeMin, storageTypeMax);
|
||||
return Base::getChecked(location, flags, storageType, expressedType,
|
||||
storageTypeMin, storageTypeMax);
|
||||
}
|
||||
|
||||
LogicalResult AnyQuantizedType::verifyConstructionInvariants(
|
||||
|
@ -240,10 +240,8 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
|
|||
int64_t zeroPoint,
|
||||
int64_t storageTypeMin,
|
||||
int64_t storageTypeMax) {
|
||||
return Base::get(storageType.getContext(),
|
||||
QuantizationTypes::UniformQuantized, flags, storageType,
|
||||
expressedType, scale, zeroPoint, storageTypeMin,
|
||||
storageTypeMax);
|
||||
return Base::get(storageType.getContext(), flags, storageType, expressedType,
|
||||
scale, zeroPoint, storageTypeMin, storageTypeMax);
|
||||
}
|
||||
|
||||
UniformQuantizedType
|
||||
|
@ -251,9 +249,8 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
|
|||
Type expressedType, double scale,
|
||||
int64_t zeroPoint, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax, Location location) {
|
||||
return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags,
|
||||
storageType, expressedType, scale, zeroPoint,
|
||||
storageTypeMin, storageTypeMax);
|
||||
return Base::getChecked(location, flags, storageType, expressedType, scale,
|
||||
zeroPoint, storageTypeMin, storageTypeMax);
|
||||
}
|
||||
|
||||
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
|
||||
|
@ -295,10 +292,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
|
|||
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
|
||||
int32_t quantizedDimension, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax) {
|
||||
return Base::get(storageType.getContext(),
|
||||
QuantizationTypes::UniformQuantizedPerAxis, flags,
|
||||
storageType, expressedType, scales, zeroPoints,
|
||||
quantizedDimension, storageTypeMin, storageTypeMax);
|
||||
return Base::get(storageType.getContext(), flags, storageType, expressedType,
|
||||
scales, zeroPoints, quantizedDimension, storageTypeMin,
|
||||
storageTypeMax);
|
||||
}
|
||||
|
||||
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
|
||||
|
@ -306,9 +302,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
|
|||
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
|
||||
int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
|
||||
Location location) {
|
||||
return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis,
|
||||
flags, storageType, expressedType, scales, zeroPoints,
|
||||
quantizedDimension, storageTypeMin, storageTypeMax);
|
||||
return Base::getChecked(location, flags, storageType, expressedType, scales,
|
||||
zeroPoints, quantizedDimension, storageTypeMin,
|
||||
storageTypeMax);
|
||||
}
|
||||
|
||||
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
|
||||
|
|
|
@ -13,11 +13,11 @@ using namespace mlir;
|
|||
|
||||
SDBMDialect::SDBMDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
|
||||
uniquer.registerStorageType(TypeID::get<detail::SDBMBinaryExprStorage>());
|
||||
uniquer.registerStorageType(TypeID::get<detail::SDBMConstantExprStorage>());
|
||||
uniquer.registerStorageType(TypeID::get<detail::SDBMDiffExprStorage>());
|
||||
uniquer.registerStorageType(TypeID::get<detail::SDBMNegExprStorage>());
|
||||
uniquer.registerStorageType(TypeID::get<detail::SDBMTermExprStorage>());
|
||||
uniquer.registerParametricStorageType<detail::SDBMBinaryExprStorage>();
|
||||
uniquer.registerParametricStorageType<detail::SDBMConstantExprStorage>();
|
||||
uniquer.registerParametricStorageType<detail::SDBMDiffExprStorage>();
|
||||
uniquer.registerParametricStorageType<detail::SDBMNegExprStorage>();
|
||||
uniquer.registerParametricStorageType<detail::SDBMTermExprStorage>();
|
||||
}
|
||||
|
||||
SDBMDialect::~SDBMDialect() = default;
|
||||
|
|
|
@ -246,7 +246,6 @@ SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
|
|||
|
||||
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
|
||||
return uniquer.get<detail::SDBMBinaryExprStorage>(
|
||||
TypeID::get<detail::SDBMBinaryExprStorage>(),
|
||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
|
||||
}
|
||||
|
||||
|
@ -533,9 +532,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
|
|||
assert(rhs && "expected SDBM dimension");
|
||||
|
||||
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
|
||||
return uniquer.get<detail::SDBMDiffExprStorage>(
|
||||
TypeID::get<detail::SDBMDiffExprStorage>(),
|
||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
|
||||
return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
|
||||
}
|
||||
|
||||
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
|
||||
|
@ -575,7 +572,6 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
|
|||
|
||||
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
||||
return uniquer.get<detail::SDBMBinaryExprStorage>(
|
||||
TypeID::get<detail::SDBMBinaryExprStorage>(),
|
||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
|
||||
stripeFactor);
|
||||
}
|
||||
|
@ -611,8 +607,7 @@ SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
|
|||
|
||||
StorageUniquer &uniquer = dialect->getUniquer();
|
||||
return uniquer.get<detail::SDBMTermExprStorage>(
|
||||
TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
|
||||
static_cast<unsigned>(SDBMExprKind::DimId), position);
|
||||
assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -628,8 +623,7 @@ SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
|
|||
|
||||
StorageUniquer &uniquer = dialect->getUniquer();
|
||||
return uniquer.get<detail::SDBMTermExprStorage>(
|
||||
TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
|
||||
static_cast<unsigned>(SDBMExprKind::SymbolId), position);
|
||||
assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -644,9 +638,7 @@ SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
|
|||
};
|
||||
|
||||
StorageUniquer &uniquer = dialect->getUniquer();
|
||||
return uniquer.get<detail::SDBMConstantExprStorage>(
|
||||
TypeID::get<detail::SDBMConstantExprStorage>(), assignCtx,
|
||||
static_cast<unsigned>(SDBMExprKind::Constant), value);
|
||||
return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
|
||||
}
|
||||
|
||||
int64_t SDBMConstantExpr::getValue() const {
|
||||
|
@ -661,9 +653,7 @@ SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
|
|||
assert(var && "expected non-null SDBM direct expression");
|
||||
|
||||
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
||||
return uniquer.get<detail::SDBMNegExprStorage>(
|
||||
TypeID::get<detail::SDBMNegExprStorage>(),
|
||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
|
||||
return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
|
||||
}
|
||||
|
||||
SDBMDirectExpr SDBMNegExpr::getVar() const {
|
||||
|
|
|
@ -25,27 +25,28 @@ namespace detail {
|
|||
|
||||
// Base storage class for SDBMExpr.
|
||||
struct SDBMExprStorage : public StorageUniquer::BaseStorage {
|
||||
SDBMExprKind getKind() {
|
||||
return static_cast<SDBMExprKind>(BaseStorage::getKind());
|
||||
}
|
||||
SDBMExprKind getKind() { return kind; }
|
||||
|
||||
SDBMDialect *dialect;
|
||||
SDBMExprKind kind;
|
||||
};
|
||||
|
||||
// Storage class for SDBM sum and stripe expressions.
|
||||
struct SDBMBinaryExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = std::pair<SDBMDirectExpr, SDBMConstantExpr>;
|
||||
using KeyTy = std::tuple<unsigned, SDBMDirectExpr, SDBMConstantExpr>;
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
|
||||
return static_cast<SDBMExprKind>(std::get<0>(key)) == kind &&
|
||||
std::get<1>(key) == lhs && std::get<2>(key) == rhs;
|
||||
}
|
||||
|
||||
static SDBMBinaryExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<SDBMBinaryExprStorage>();
|
||||
result->lhs = std::get<0>(key);
|
||||
result->rhs = std::get<1>(key);
|
||||
result->lhs = std::get<1>(key);
|
||||
result->rhs = std::get<2>(key);
|
||||
result->dialect = result->lhs.getDialect();
|
||||
result->kind = static_cast<SDBMExprKind>(std::get<0>(key));
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -67,6 +68,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage {
|
|||
result->lhs = std::get<0>(key);
|
||||
result->rhs = std::get<1>(key);
|
||||
result->dialect = result->lhs.getDialect();
|
||||
result->kind = SDBMExprKind::Diff;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -84,6 +86,7 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
|
|||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<SDBMConstantExprStorage>();
|
||||
result->constant = key;
|
||||
result->kind = SDBMExprKind::Constant;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -92,14 +95,18 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
|
|||
|
||||
// Storage class for SDBM dimension and symbol expressions.
|
||||
struct SDBMTermExprStorage : public SDBMExprStorage {
|
||||
using KeyTy = unsigned;
|
||||
using KeyTy = std::pair<unsigned, unsigned>;
|
||||
|
||||
bool operator==(const KeyTy &key) const { return position == key; }
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return kind == static_cast<SDBMExprKind>(key.first) &&
|
||||
position == key.second;
|
||||
}
|
||||
|
||||
static SDBMTermExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<SDBMTermExprStorage>();
|
||||
result->position = key;
|
||||
result->kind = static_cast<SDBMExprKind>(key.first);
|
||||
result->position = key.second;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -117,6 +124,7 @@ struct SDBMNegExprStorage : public SDBMExprStorage {
|
|||
auto *result = allocator.allocate<SDBMNegExprStorage>();
|
||||
result->expr = key;
|
||||
result->dialect = key.getDialect();
|
||||
result->kind = SDBMExprKind::Neg;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
@ -120,8 +120,7 @@ spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
|
|||
IntegerAttr storageClass) {
|
||||
assert(descriptorSet && binding);
|
||||
MLIRContext *context = descriptorSet.getContext();
|
||||
return Base::get(context, spirv::AttrKind::InterfaceVarABI, descriptorSet,
|
||||
binding, storageClass);
|
||||
return Base::get(context, descriptorSet, binding, storageClass);
|
||||
}
|
||||
|
||||
StringRef spirv::InterfaceVarABIAttr::getKindName() {
|
||||
|
@ -195,8 +194,7 @@ spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
|
|||
ArrayAttr extensions) {
|
||||
assert(version && capabilities && extensions);
|
||||
MLIRContext *context = version.getContext();
|
||||
return Base::get(context, spirv::AttrKind::VerCapExt, version, capabilities,
|
||||
extensions);
|
||||
return Base::get(context, version, capabilities, extensions);
|
||||
}
|
||||
|
||||
StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
|
||||
|
@ -272,7 +270,7 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
|
|||
DictionaryAttr limits) {
|
||||
assert(triple && limits && "expected valid triple and limits");
|
||||
MLIRContext *context = triple.getContext();
|
||||
return Base::get(context, spirv::AttrKind::TargetEnv, triple, limits);
|
||||
return Base::get(context, triple, limits);
|
||||
}
|
||||
|
||||
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
|
||||
|
|
|
@ -124,15 +124,14 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
|
|||
|
||||
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
|
||||
assert(elementCount && "ArrayType needs at least one element");
|
||||
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
|
||||
elementCount, /*stride=*/0);
|
||||
return Base::get(elementType.getContext(), elementType, elementCount,
|
||||
/*stride=*/0);
|
||||
}
|
||||
|
||||
ArrayType ArrayType::get(Type elementType, unsigned elementCount,
|
||||
unsigned stride) {
|
||||
assert(elementCount && "ArrayType needs at least one element");
|
||||
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
|
||||
elementCount, stride);
|
||||
return Base::get(elementType.getContext(), elementType, elementCount, stride);
|
||||
}
|
||||
|
||||
unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
|
||||
|
@ -285,8 +284,7 @@ struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
|
|||
CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
|
||||
Scope scope, unsigned rows,
|
||||
unsigned columns) {
|
||||
return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix,
|
||||
elementType, scope, rows, columns);
|
||||
return Base::get(elementType.getContext(), elementType, scope, rows, columns);
|
||||
}
|
||||
|
||||
Type CooperativeMatrixNVType::getElementType() const {
|
||||
|
@ -389,7 +387,7 @@ ImageType
|
|||
ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
|
||||
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
|
||||
value) {
|
||||
return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value);
|
||||
return Base::get(std::get<0>(value).getContext(), value);
|
||||
}
|
||||
|
||||
Type ImageType::getElementType() const { return getImpl()->elementType; }
|
||||
|
@ -453,8 +451,7 @@ struct spirv::detail::PointerTypeStorage : public TypeStorage {
|
|||
};
|
||||
|
||||
PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
|
||||
return Base::get(pointeeType.getContext(), TypeKind::Pointer, pointeeType,
|
||||
storageClass);
|
||||
return Base::get(pointeeType.getContext(), pointeeType, storageClass);
|
||||
}
|
||||
|
||||
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
|
||||
|
@ -511,13 +508,11 @@ struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
|
|||
};
|
||||
|
||||
RuntimeArrayType RuntimeArrayType::get(Type elementType) {
|
||||
return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
|
||||
elementType, /*stride=*/0);
|
||||
return Base::get(elementType.getContext(), elementType, /*stride=*/0);
|
||||
}
|
||||
|
||||
RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
|
||||
return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
|
||||
elementType, stride);
|
||||
return Base::get(elementType.getContext(), elementType, stride);
|
||||
}
|
||||
|
||||
Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
|
||||
|
@ -846,12 +841,12 @@ StructType::get(ArrayRef<Type> memberTypes,
|
|||
SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
|
||||
memberDecorations.begin(), memberDecorations.end());
|
||||
llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
|
||||
return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
|
||||
memberTypes, offsetInfo, sortedDecorations);
|
||||
return Base::get(memberTypes.vec().front().getContext(), memberTypes,
|
||||
offsetInfo, sortedDecorations);
|
||||
}
|
||||
|
||||
StructType StructType::getEmpty(MLIRContext *context) {
|
||||
return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
|
||||
return Base::get(context, ArrayRef<Type>(),
|
||||
ArrayRef<StructType::OffsetInfo>(),
|
||||
ArrayRef<StructType::MemberDecorationInfo>());
|
||||
}
|
||||
|
@ -946,13 +941,12 @@ struct spirv::detail::MatrixTypeStorage : public TypeStorage {
|
|||
};
|
||||
|
||||
MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
|
||||
return Base::get(columnType.getContext(), TypeKind::Matrix, columnType,
|
||||
columnCount);
|
||||
return Base::get(columnType.getContext(), columnType, columnCount);
|
||||
}
|
||||
|
||||
MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
|
||||
Location location) {
|
||||
return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount);
|
||||
return Base::getChecked(location, columnType, columnCount);
|
||||
}
|
||||
|
||||
LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
|
||||
|
|
|
@ -20,9 +20,7 @@ using namespace mlir::detail;
|
|||
|
||||
MLIRContext *AffineExpr::getContext() const { return expr->context; }
|
||||
|
||||
AffineExprKind AffineExpr::getKind() const {
|
||||
return static_cast<AffineExprKind>(expr->getKind());
|
||||
}
|
||||
AffineExprKind AffineExpr::getKind() const { return expr->kind; }
|
||||
|
||||
/// Walk all of the AffineExprs in this subgraph in postorder.
|
||||
void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
|
||||
|
@ -449,8 +447,7 @@ static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
|
|||
|
||||
StorageUniquer &uniquer = context->getAffineUniquer();
|
||||
return uniquer.get<AffineDimExprStorage>(
|
||||
TypeID::get<AffineDimExprStorage>(), assignCtx,
|
||||
static_cast<unsigned>(kind), position);
|
||||
assignCtx, static_cast<unsigned>(kind), position);
|
||||
}
|
||||
|
||||
AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
|
||||
|
@ -484,9 +481,7 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
|
|||
};
|
||||
|
||||
StorageUniquer &uniquer = context->getAffineUniquer();
|
||||
return uniquer.get<AffineConstantExprStorage>(
|
||||
TypeID::get<AffineConstantExprStorage>(), assignCtx,
|
||||
static_cast<unsigned>(AffineExprKind::Constant), constant);
|
||||
return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
|
||||
}
|
||||
|
||||
/// Simplify add expression. Return nullptr if it can't be simplified.
|
||||
|
@ -594,7 +589,6 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
|
|||
|
||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
|
||||
}
|
||||
|
||||
|
@ -655,7 +649,6 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
|
|||
|
||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
|
||||
}
|
||||
|
||||
|
@ -722,7 +715,6 @@ AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
|
|||
|
||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
|
||||
other);
|
||||
}
|
||||
|
@ -766,7 +758,6 @@ AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
|
|||
|
||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
|
||||
other);
|
||||
}
|
||||
|
@ -814,7 +805,6 @@ AffineExpr AffineExpr::operator%(AffineExpr other) const {
|
|||
|
||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
|
||||
}
|
||||
|
||||
|
|
|
@ -27,21 +27,24 @@ namespace detail {
|
|||
/// Base storage class appearing in an affine expression.
|
||||
struct AffineExprStorage : public StorageUniquer::BaseStorage {
|
||||
MLIRContext *context;
|
||||
AffineExprKind kind;
|
||||
};
|
||||
|
||||
/// A binary operation appearing in an affine expression.
|
||||
struct AffineBinaryOpExprStorage : public AffineExprStorage {
|
||||
using KeyTy = std::pair<AffineExpr, AffineExpr>;
|
||||
using KeyTy = std::tuple<unsigned, AffineExpr, AffineExpr>;
|
||||
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key.first == lhs && key.second == rhs;
|
||||
return static_cast<AffineExprKind>(std::get<0>(key)) == kind &&
|
||||
std::get<1>(key) == lhs && std::get<2>(key) == rhs;
|
||||
}
|
||||
|
||||
static AffineBinaryOpExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<AffineBinaryOpExprStorage>();
|
||||
result->lhs = key.first;
|
||||
result->rhs = key.second;
|
||||
result->kind = static_cast<AffineExprKind>(std::get<0>(key));
|
||||
result->lhs = std::get<1>(key);
|
||||
result->rhs = std::get<2>(key);
|
||||
result->context = result->lhs.getContext();
|
||||
return result;
|
||||
}
|
||||
|
@ -52,14 +55,18 @@ struct AffineBinaryOpExprStorage : public AffineExprStorage {
|
|||
|
||||
/// A dimensional or symbolic identifier appearing in an affine expression.
|
||||
struct AffineDimExprStorage : public AffineExprStorage {
|
||||
using KeyTy = unsigned;
|
||||
using KeyTy = std::pair<unsigned, unsigned>;
|
||||
|
||||
bool operator==(const KeyTy &key) const { return position == key; }
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return kind == static_cast<AffineExprKind>(key.first) &&
|
||||
position == key.second;
|
||||
}
|
||||
|
||||
static AffineDimExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<AffineDimExprStorage>();
|
||||
result->position = key;
|
||||
result->kind = static_cast<AffineExprKind>(key.first);
|
||||
result->position = key.second;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -76,6 +83,7 @@ struct AffineConstantExprStorage : public AffineExprStorage {
|
|||
static AffineConstantExprStorage *
|
||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||
auto *result = allocator.allocate<AffineConstantExprStorage>();
|
||||
result->kind = AffineExprKind::Constant;
|
||||
result->constant = key;
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -271,7 +271,7 @@ private:
|
|||
/// Mapping between attribute kind and a pair comprised of a base alias name
|
||||
/// and a unique list of attributes belonging to this kind sorted by location
|
||||
/// seen in the module.
|
||||
llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
|
||||
llvm::MapVector<TypeID, std::pair<StringRef, std::vector<Attribute>>>
|
||||
attrKindToAlias;
|
||||
|
||||
/// Set of types known to be used within the module.
|
||||
|
@ -301,13 +301,13 @@ void AliasState::initialize(
|
|||
llvm::StringSet<> usedAliases;
|
||||
|
||||
// Collect the set of aliases from each dialect.
|
||||
SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
|
||||
SmallVector<std::pair<TypeID, StringRef>, 8> attributeKindAliases;
|
||||
SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
|
||||
SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
|
||||
|
||||
// AffineMap/Integer set have specific kind aliases.
|
||||
attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
|
||||
attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
|
||||
attributeKindAliases.emplace_back(AffineMapAttr::getTypeID(), "map");
|
||||
attributeKindAliases.emplace_back(IntegerSetAttr::getTypeID(), "set");
|
||||
|
||||
for (auto &interface : interfaces) {
|
||||
interface.getAttributeKindAliases(attributeKindAliases);
|
||||
|
@ -317,7 +317,7 @@ void AliasState::initialize(
|
|||
|
||||
// Setup the attribute kind aliases.
|
||||
StringRef alias;
|
||||
unsigned attrKind;
|
||||
TypeID attrKind;
|
||||
for (auto &attrAliasPair : attributeKindAliases) {
|
||||
std::tie(attrKind, alias) = attrAliasPair;
|
||||
assert(!alias.empty() && "expected non-empty alias string");
|
||||
|
@ -420,7 +420,7 @@ void AliasState::recordAttributeReference(Attribute attr) {
|
|||
return;
|
||||
|
||||
// If this attribute kind has an alias, then record one for this attribute.
|
||||
auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
|
||||
auto alias = attrKindToAlias.find(attr.getTypeID());
|
||||
if (alias == attrKindToAlias.end())
|
||||
return;
|
||||
std::pair<StringRef, int> attrAlias(alias->second.first,
|
||||
|
|
|
@ -57,7 +57,7 @@ Dialect &Attribute::getDialect() const {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AffineMapAttr AffineMapAttr::get(AffineMap value) {
|
||||
return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
|
||||
return Base::get(value.getContext(), value);
|
||||
}
|
||||
|
||||
AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
|
||||
|
@ -67,7 +67,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
|
||||
return Base::get(context, StandardAttributes::Array, value);
|
||||
return Base::get(context, value);
|
||||
}
|
||||
|
||||
ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
|
||||
|
@ -156,7 +156,7 @@ DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
|
|||
if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
|
||||
value = storage;
|
||||
|
||||
return Base::get(context, StandardAttributes::Dictionary, value);
|
||||
return Base::get(context, value);
|
||||
}
|
||||
/// Construct a dictionary with an array of values that is known to already be
|
||||
/// sorted by name and uniqued.
|
||||
|
@ -175,7 +175,7 @@ DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
|
|||
return l.first == r.first;
|
||||
}) == value.end() &&
|
||||
"DictionaryAttr element names must be unique");
|
||||
return Base::get(context, StandardAttributes::Dictionary, value);
|
||||
return Base::get(context, value);
|
||||
}
|
||||
|
||||
ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
|
||||
|
@ -219,19 +219,19 @@ size_t DictionaryAttr::size() const { return getValue().size(); }
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FloatAttr FloatAttr::get(Type type, double value) {
|
||||
return Base::get(type.getContext(), StandardAttributes::Float, type, value);
|
||||
return Base::get(type.getContext(), type, value);
|
||||
}
|
||||
|
||||
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
|
||||
return Base::getChecked(loc, StandardAttributes::Float, type, value);
|
||||
return Base::getChecked(loc, type, value);
|
||||
}
|
||||
|
||||
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
|
||||
return Base::get(type.getContext(), StandardAttributes::Float, type, value);
|
||||
return Base::get(type.getContext(), type, value);
|
||||
}
|
||||
|
||||
FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
|
||||
return Base::getChecked(loc, StandardAttributes::Float, type, value);
|
||||
return Base::getChecked(loc, type, value);
|
||||
}
|
||||
|
||||
APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
|
||||
|
@ -279,14 +279,13 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
|
||||
return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
|
||||
.cast<FlatSymbolRefAttr>();
|
||||
return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
|
||||
}
|
||||
|
||||
SymbolRefAttr SymbolRefAttr::get(StringRef value,
|
||||
ArrayRef<FlatSymbolRefAttr> nestedReferences,
|
||||
MLIRContext *ctx) {
|
||||
return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
|
||||
return Base::get(ctx, value, nestedReferences);
|
||||
}
|
||||
|
||||
StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
|
||||
|
@ -307,7 +306,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
|
|||
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
|
||||
if (type.isSignlessInteger(1))
|
||||
return BoolAttr::get(value.getBoolValue(), type.getContext());
|
||||
return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
|
||||
return Base::get(type.getContext(), type, value);
|
||||
}
|
||||
|
||||
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
|
||||
|
@ -380,8 +379,7 @@ bool BoolAttr::classof(Attribute attr) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
|
||||
return Base::get(value.getConstraint(0).getContext(),
|
||||
StandardAttributes::IntegerSet, value);
|
||||
return Base::get(value.getConstraint(0).getContext(), value);
|
||||
}
|
||||
|
||||
IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
|
||||
|
@ -392,14 +390,12 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
|
|||
|
||||
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
|
||||
MLIRContext *context) {
|
||||
return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
|
||||
type);
|
||||
return Base::get(context, dialect, attrData, type);
|
||||
}
|
||||
|
||||
OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
|
||||
Type type, Location location) {
|
||||
return Base::getChecked(location, StandardAttributes::Opaque, dialect,
|
||||
attrData, type);
|
||||
return Base::getChecked(location, dialect, attrData, type);
|
||||
}
|
||||
|
||||
/// Returns the dialect namespace of the opaque attribute.
|
||||
|
@ -430,7 +426,7 @@ StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
|
|||
|
||||
/// Get an instance of a StringAttr with the given string and Type.
|
||||
StringAttr StringAttr::get(StringRef bytes, Type type) {
|
||||
return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
|
||||
return Base::get(type.getContext(), bytes, type);
|
||||
}
|
||||
|
||||
StringRef StringAttr::getValue() const { return getImpl()->value; }
|
||||
|
@ -440,7 +436,7 @@ StringRef StringAttr::getValue() const { return getImpl()->value; }
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TypeAttr TypeAttr::get(Type value) {
|
||||
return Base::get(value.getContext(), StandardAttributes::Type, value);
|
||||
return Base::get(value.getContext(), value);
|
||||
}
|
||||
|
||||
Type TypeAttr::getValue() const { return getImpl()->value; }
|
||||
|
@ -1036,8 +1032,7 @@ DenseElementsAttr DenseElementsAttr::mapValues(
|
|||
|
||||
DenseStringElementsAttr
|
||||
DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
|
||||
return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
|
||||
type, values, (values.size() == 1));
|
||||
return Base::get(type.getContext(), type, values, (values.size() == 1));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1088,8 +1083,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
|
|||
assert((type.isa<RankedTensorType, VectorType>()) &&
|
||||
"type must be ranked tensor or vector");
|
||||
assert(type.hasStaticShape() && "type must have static shape");
|
||||
return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
|
||||
type, data, isSplat);
|
||||
return Base::get(type.getContext(), type, data, isSplat);
|
||||
}
|
||||
|
||||
/// Overload of the raw 'get' method that asserts that the given type is of
|
||||
|
@ -1210,8 +1204,7 @@ OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
|
|||
StringRef bytes) {
|
||||
assert(TensorType::isValidElementType(type.getElementType()) &&
|
||||
"Input element type should be a valid tensor element type");
|
||||
return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
|
||||
dialect, bytes);
|
||||
return Base::get(type.getContext(), type, dialect, bytes);
|
||||
}
|
||||
|
||||
StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
|
||||
|
@ -1248,7 +1241,7 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
|
|||
assert((type.isa<RankedTensorType, VectorType>()) &&
|
||||
"type must be ranked tensor or vector");
|
||||
assert(type.hasStaticShape() && "type must have static shape");
|
||||
return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
|
||||
return Base::get(type.getContext(), type,
|
||||
indices.cast<DenseIntElementsAttr>(), values);
|
||||
}
|
||||
|
||||
|
|
|
@ -28,8 +28,7 @@ bool LocationAttr::classof(Attribute attr) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Location CallSiteLoc::get(Location callee, Location caller) {
|
||||
return Base::get(callee->getContext(), StandardAttributes::CallSiteLocation,
|
||||
callee, caller);
|
||||
return Base::get(callee->getContext(), callee, caller);
|
||||
}
|
||||
|
||||
Location CallSiteLoc::get(Location name, ArrayRef<Location> frames) {
|
||||
|
@ -50,8 +49,7 @@ Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
|
|||
|
||||
Location FileLineColLoc::get(Identifier filename, unsigned line,
|
||||
unsigned column, MLIRContext *context) {
|
||||
return Base::get(context, StandardAttributes::FileLineColLocation, filename,
|
||||
line, column);
|
||||
return Base::get(context, filename, line, column);
|
||||
}
|
||||
|
||||
Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
|
||||
|
@ -95,7 +93,7 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
|
|||
return UnknownLoc::get(context);
|
||||
if (locs.size() == 1)
|
||||
return locs.front();
|
||||
return Base::get(context, StandardAttributes::FusedLocation, locs, metadata);
|
||||
return Base::get(context, locs, metadata);
|
||||
}
|
||||
|
||||
ArrayRef<Location> FusedLoc::getLocations() const {
|
||||
|
@ -111,8 +109,7 @@ Attribute FusedLoc::getMetadata() const { return getImpl()->metadata; }
|
|||
Location NameLoc::get(Identifier name, Location child) {
|
||||
assert(!child.isa<NameLoc>() &&
|
||||
"a NameLoc cannot be used as a child of another NameLoc");
|
||||
return Base::get(child->getContext(), StandardAttributes::NameLocation, name,
|
||||
child);
|
||||
return Base::get(child->getContext(), name, child);
|
||||
}
|
||||
|
||||
Location NameLoc::get(Identifier name, MLIRContext *context) {
|
||||
|
@ -131,9 +128,8 @@ Location NameLoc::getChildLoc() const { return getImpl()->child; }
|
|||
|
||||
Location OpaqueLoc::get(uintptr_t underlyingLocation, TypeID typeID,
|
||||
Location fallbackLocation) {
|
||||
return Base::get(fallbackLocation->getContext(),
|
||||
StandardAttributes::OpaqueLocation, underlyingLocation,
|
||||
typeID, fallbackLocation);
|
||||
return Base::get(fallbackLocation->getContext(), underlyingLocation, typeID,
|
||||
fallbackLocation);
|
||||
}
|
||||
|
||||
uintptr_t OpaqueLoc::getUnderlyingLocation() const {
|
||||
|
|
|
@ -87,6 +87,10 @@ namespace {
|
|||
struct BuiltinDialect : public Dialect {
|
||||
BuiltinDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"", context, TypeID::get<BuiltinDialect>()) {
|
||||
addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
|
||||
FunctionType, IndexType, IntegerType, MemRefType,
|
||||
UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
|
||||
TupleType, UnrankedTensorType, VectorType>();
|
||||
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
|
||||
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
|
||||
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
|
||||
|
@ -95,11 +99,6 @@ struct BuiltinDialect : public Dialect {
|
|||
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
|
||||
UnknownLoc>();
|
||||
|
||||
addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
|
||||
FunctionType, IndexType, IntegerType, MemRefType,
|
||||
UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
|
||||
TupleType, UnrankedTensorType, VectorType>();
|
||||
|
||||
// TODO: These operations should be moved to a different dialect when they
|
||||
// have been fully decoupled from the core.
|
||||
addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
|
||||
|
@ -363,56 +362,50 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
|
|||
|
||||
//// Types.
|
||||
/// Floating-point Types.
|
||||
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this, StandardTypes::BF16);
|
||||
impl->f16Ty = TypeUniquer::get<Float16Type>(this, StandardTypes::F16);
|
||||
impl->f32Ty = TypeUniquer::get<Float32Type>(this, StandardTypes::F32);
|
||||
impl->f64Ty = TypeUniquer::get<Float64Type>(this, StandardTypes::F64);
|
||||
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
|
||||
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
|
||||
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
|
||||
impl->f64Ty = TypeUniquer::get<Float64Type>(this);
|
||||
/// Index Type.
|
||||
impl->indexTy = TypeUniquer::get<IndexType>(this, StandardTypes::Index);
|
||||
impl->indexTy = TypeUniquer::get<IndexType>(this);
|
||||
/// Integer Types.
|
||||
impl->int1Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 1,
|
||||
IntegerType::Signless);
|
||||
impl->int8Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 8,
|
||||
IntegerType::Signless);
|
||||
impl->int16Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
|
||||
16, IntegerType::Signless);
|
||||
impl->int32Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
|
||||
32, IntegerType::Signless);
|
||||
impl->int64Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
|
||||
64, IntegerType::Signless);
|
||||
impl->int128Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
|
||||
128, IntegerType::Signless);
|
||||
impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
|
||||
impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
|
||||
impl->int16Ty =
|
||||
TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
|
||||
impl->int32Ty =
|
||||
TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
|
||||
impl->int64Ty =
|
||||
TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
|
||||
impl->int128Ty =
|
||||
TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
|
||||
/// None Type.
|
||||
impl->noneType = TypeUniquer::get<NoneType>(this, StandardTypes::None);
|
||||
impl->noneType = TypeUniquer::get<NoneType>(this);
|
||||
|
||||
//// Attributes.
|
||||
//// Note: These must be registered after the types as they may generate one
|
||||
//// of the above types internally.
|
||||
/// Bool Attributes.
|
||||
impl->falseAttr = AttributeUniquer::get<IntegerAttr>(
|
||||
this, StandardAttributes::Integer, impl->int1Ty,
|
||||
APInt(/*numBits=*/1, false))
|
||||
this, impl->int1Ty, APInt(/*numBits=*/1, false))
|
||||
.cast<BoolAttr>();
|
||||
impl->trueAttr = AttributeUniquer::get<IntegerAttr>(
|
||||
this, StandardAttributes::Integer, impl->int1Ty,
|
||||
APInt(/*numBits=*/1, true))
|
||||
this, impl->int1Ty, APInt(/*numBits=*/1, true))
|
||||
.cast<BoolAttr>();
|
||||
/// Unit Attribute.
|
||||
impl->unitAttr =
|
||||
AttributeUniquer::get<UnitAttr>(this, StandardAttributes::Unit);
|
||||
impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
|
||||
/// Unknown Location Attribute.
|
||||
impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(
|
||||
this, StandardAttributes::UnknownLocation);
|
||||
impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
|
||||
/// The empty dictionary attribute.
|
||||
impl->emptyDictionaryAttr = AttributeUniquer::get<DictionaryAttr>(
|
||||
this, StandardAttributes::Dictionary, ArrayRef<NamedAttribute>());
|
||||
impl->emptyDictionaryAttr =
|
||||
AttributeUniquer::get<DictionaryAttr>(this, ArrayRef<NamedAttribute>());
|
||||
|
||||
// Register the affine storage objects with the uniquer.
|
||||
impl->affineUniquer.registerStorageType(
|
||||
TypeID::get<AffineBinaryOpExprStorage>());
|
||||
impl->affineUniquer.registerStorageType(
|
||||
TypeID::get<AffineConstantExprStorage>());
|
||||
impl->affineUniquer.registerStorageType(TypeID::get<AffineDimExprStorage>());
|
||||
impl->affineUniquer
|
||||
.registerParametricStorageType<AffineBinaryOpExprStorage>();
|
||||
impl->affineUniquer
|
||||
.registerParametricStorageType<AffineConstantExprStorage>();
|
||||
impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
|
||||
}
|
||||
|
||||
MLIRContext::~MLIRContext() {}
|
||||
|
@ -582,7 +575,6 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
|
|||
AbstractType(std::move(typeInfo));
|
||||
if (!impl.registeredTypes.insert({typeID, newInfo}).second)
|
||||
llvm::report_fatal_error("Dialect Type already registered.");
|
||||
impl.typeUniquer.registerStorageType(typeID);
|
||||
}
|
||||
|
||||
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
|
||||
|
@ -592,7 +584,6 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
|
|||
AbstractAttribute(std::move(attrInfo));
|
||||
if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
|
||||
llvm::report_fatal_error("Dialect Attribute already registered.");
|
||||
impl.attributeUniquer.registerStorageType(typeID);
|
||||
}
|
||||
|
||||
/// Get the dialect that registered the attribute with the provided typeid.
|
||||
|
@ -718,7 +709,7 @@ IntegerType IntegerType::get(unsigned width,
|
|||
MLIRContext *context) {
|
||||
if (auto cached = getCachedIntegerType(width, signedness, context))
|
||||
return cached;
|
||||
return Base::get(context, StandardTypes::Integer, width, signedness);
|
||||
return Base::get(context, width, signedness);
|
||||
}
|
||||
|
||||
IntegerType IntegerType::getChecked(unsigned width, Location location) {
|
||||
|
@ -731,12 +722,16 @@ IntegerType IntegerType::getChecked(unsigned width,
|
|||
if (auto cached =
|
||||
getCachedIntegerType(width, signedness, location->getContext()))
|
||||
return cached;
|
||||
return Base::getChecked(location, StandardTypes::Integer, width, signedness);
|
||||
return Base::getChecked(location, width, signedness);
|
||||
}
|
||||
|
||||
/// Get an instance of the NoneType.
|
||||
NoneType NoneType::get(MLIRContext *context) {
|
||||
return context->getImpl().noneType;
|
||||
if (NoneType cachedInst = context->getImpl().noneType)
|
||||
return cachedInst;
|
||||
// Note: May happen when initializing the singleton attributes of the builtin
|
||||
// dialect.
|
||||
return Base::get(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -102,12 +102,11 @@ unsigned Type::getIntOrFloatBitWidth() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ComplexType ComplexType::get(Type elementType) {
|
||||
return Base::get(elementType.getContext(), StandardTypes::Complex,
|
||||
elementType);
|
||||
return Base::get(elementType.getContext(), elementType);
|
||||
}
|
||||
|
||||
ComplexType ComplexType::getChecked(Type elementType, Location location) {
|
||||
return Base::getChecked(location, StandardTypes::Complex, elementType);
|
||||
return Base::getChecked(location, elementType);
|
||||
}
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
|
@ -265,13 +264,12 @@ bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
|
||||
return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
|
||||
elementType);
|
||||
return Base::get(elementType.getContext(), shape, elementType);
|
||||
}
|
||||
|
||||
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
|
||||
Location location) {
|
||||
return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
|
||||
return Base::getChecked(location, shape, elementType);
|
||||
}
|
||||
|
||||
LogicalResult VectorType::verifyConstructionInvariants(Location loc,
|
||||
|
@ -320,15 +318,13 @@ bool TensorType::isValidElementType(Type type) {
|
|||
|
||||
RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
|
||||
Type elementType) {
|
||||
return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape,
|
||||
elementType);
|
||||
return Base::get(elementType.getContext(), shape, elementType);
|
||||
}
|
||||
|
||||
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
|
||||
Type elementType,
|
||||
Location location) {
|
||||
return Base::getChecked(location, StandardTypes::RankedTensor, shape,
|
||||
elementType);
|
||||
return Base::getChecked(location, shape, elementType);
|
||||
}
|
||||
|
||||
LogicalResult RankedTensorType::verifyConstructionInvariants(
|
||||
|
@ -349,13 +345,12 @@ ArrayRef<int64_t> RankedTensorType::getShape() const {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
|
||||
return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
|
||||
elementType);
|
||||
return Base::get(elementType.getContext(), elementType);
|
||||
}
|
||||
|
||||
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
|
||||
Location location) {
|
||||
return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
|
||||
return Base::getChecked(location, elementType);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
|
@ -444,8 +439,8 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
|
|||
cleanedAffineMapComposition.push_back(map);
|
||||
}
|
||||
|
||||
return Base::get(context, StandardTypes::MemRef, shape, elementType,
|
||||
cleanedAffineMapComposition, memorySpace);
|
||||
return Base::get(context, shape, elementType, cleanedAffineMapComposition,
|
||||
memorySpace);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
|
||||
|
@ -462,15 +457,13 @@ unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
|
|||
|
||||
UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
|
||||
unsigned memorySpace) {
|
||||
return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
|
||||
elementType, memorySpace);
|
||||
return Base::get(elementType.getContext(), elementType, memorySpace);
|
||||
}
|
||||
|
||||
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
|
||||
unsigned memorySpace,
|
||||
Location location) {
|
||||
return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
|
||||
memorySpace);
|
||||
return Base::getChecked(location, elementType, memorySpace);
|
||||
}
|
||||
|
||||
unsigned UnrankedMemRefType::getMemorySpace() const {
|
||||
|
@ -642,7 +635,7 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
|
|||
/// Get or create a new TupleType with the provided element types. Assumes the
|
||||
/// arguments define a well-formed type.
|
||||
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
|
||||
return Base::get(context, StandardTypes::Tuple, elementTypes);
|
||||
return Base::get(context, elementTypes);
|
||||
}
|
||||
|
||||
/// Get or create an empty tuple type.
|
||||
|
|
|
@ -19,8 +19,6 @@ using namespace mlir::detail;
|
|||
// Type
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
unsigned Type::getKind() const { return impl->getKind(); }
|
||||
|
||||
Dialect &Type::getDialect() const {
|
||||
return impl->getAbstractType().getDialect();
|
||||
}
|
||||
|
@ -33,7 +31,7 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
|
|||
|
||||
FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
|
||||
MLIRContext *context) {
|
||||
return Base::get(context, Type::Kind::Function, inputs, results);
|
||||
return Base::get(context, inputs, results);
|
||||
}
|
||||
|
||||
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
|
||||
|
@ -54,12 +52,12 @@ ArrayRef<Type> FunctionType::getResults() const {
|
|||
|
||||
OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
|
||||
MLIRContext *context) {
|
||||
return Base::get(context, Type::Kind::Opaque, dialect, typeData);
|
||||
return Base::get(context, dialect, typeData);
|
||||
}
|
||||
|
||||
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
|
||||
MLIRContext *context, Location location) {
|
||||
return Base::getChecked(location, Kind::Opaque, dialect, typeData);
|
||||
return Base::getChecked(location, dialect, typeData);
|
||||
}
|
||||
|
||||
/// Returns the dialect namespace of the opaque type.
|
||||
|
|
|
@ -16,19 +16,17 @@ using namespace mlir;
|
|||
using namespace mlir::detail;
|
||||
|
||||
namespace {
|
||||
/// This class represents a uniquer for storage instances of a specific type. It
|
||||
/// contains all of the necessary data to unique storage instances in a thread
|
||||
/// safe way. This allows for the main uniquer to bucket each of the individual
|
||||
/// sub-types removing the need to lock the main uniquer itself.
|
||||
struct InstSpecificUniquer {
|
||||
/// This class represents a uniquer for storage instances of a specific type
|
||||
/// that has parametric storage. It contains all of the necessary data to unique
|
||||
/// storage instances in a thread safe way. This allows for the main uniquer to
|
||||
/// bucket each of the individual sub-types removing the need to lock the main
|
||||
/// uniquer itself.
|
||||
struct ParametricStorageUniquer {
|
||||
using BaseStorage = StorageUniquer::BaseStorage;
|
||||
using StorageAllocator = StorageUniquer::StorageAllocator;
|
||||
|
||||
/// A lookup key for derived instances of storage objects.
|
||||
struct LookupKey {
|
||||
/// The known derived kind for the storage.
|
||||
unsigned kind;
|
||||
|
||||
/// The known hash value of the key.
|
||||
unsigned hashValue;
|
||||
|
||||
|
@ -63,18 +61,14 @@ struct InstSpecificUniquer {
|
|||
static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
|
||||
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
|
||||
return false;
|
||||
// If the lookup kind matches the kind of the storage, then invoke the
|
||||
// equality function on the lookup key.
|
||||
return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
|
||||
// Invoke the equality function on the lookup key.
|
||||
return lhs.isEqual(rhs.storage);
|
||||
}
|
||||
};
|
||||
|
||||
/// Unique types with specific hashing or storage constraints.
|
||||
/// The set containing the allocated storage instances.
|
||||
using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
|
||||
StorageTypeSet complexInstances;
|
||||
|
||||
/// Instances of this storage object.
|
||||
llvm::SmallDenseMap<unsigned, BaseStorage *, 1> simpleInstances;
|
||||
StorageTypeSet instances;
|
||||
|
||||
/// Allocator to use when constructing derived instances.
|
||||
StorageAllocator allocator;
|
||||
|
@ -91,107 +85,79 @@ struct StorageUniquerImpl {
|
|||
using BaseStorage = StorageUniquer::BaseStorage;
|
||||
using StorageAllocator = StorageUniquer::StorageAllocator;
|
||||
|
||||
/// Get or create an instance of a complex derived type.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Parametric Storage
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Get or create an instance of a parametric type.
|
||||
BaseStorage *
|
||||
getOrCreate(TypeID id, unsigned kind, unsigned hashValue,
|
||||
getOrCreate(TypeID id, unsigned hashValue,
|
||||
function_ref<bool(const BaseStorage *)> isEqual,
|
||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
||||
assert(instUniquers.count(id) && "creating unregistered storage instance");
|
||||
InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
|
||||
InstSpecificUniquer &storageUniquer = *instUniquers[id];
|
||||
assert(parametricUniquers.count(id) &&
|
||||
"creating unregistered storage instance");
|
||||
ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
|
||||
ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
|
||||
if (!threadingIsEnabled)
|
||||
return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
|
||||
return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
|
||||
|
||||
// Check for an existing instance in read-only mode.
|
||||
{
|
||||
llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
|
||||
auto it = storageUniquer.complexInstances.find_as(lookupKey);
|
||||
if (it != storageUniquer.complexInstances.end())
|
||||
auto it = storageUniquer.instances.find_as(lookupKey);
|
||||
if (it != storageUniquer.instances.end())
|
||||
return it->storage;
|
||||
}
|
||||
|
||||
// Acquire a writer-lock so that we can safely create the new type instance.
|
||||
llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
|
||||
return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
|
||||
return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
|
||||
}
|
||||
/// Get or create an instance of a complex derived type in an thread-unsafe
|
||||
/// fashion.
|
||||
BaseStorage *
|
||||
getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
|
||||
InstSpecificUniquer::LookupKey &lookupKey,
|
||||
getOrCreateUnsafe(ParametricStorageUniquer &storageUniquer,
|
||||
ParametricStorageUniquer::LookupKey &lookupKey,
|
||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
||||
auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey);
|
||||
auto existing = storageUniquer.instances.insert_as({}, lookupKey);
|
||||
if (!existing.second)
|
||||
return existing.first->storage;
|
||||
|
||||
// Otherwise, construct and initialize the derived storage for this type
|
||||
// instance.
|
||||
BaseStorage *storage =
|
||||
initializeStorage(kind, storageUniquer.allocator, ctorFn);
|
||||
BaseStorage *storage = ctorFn(storageUniquer.allocator);
|
||||
*existing.first =
|
||||
InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage};
|
||||
ParametricStorageUniquer::HashedStorage{lookupKey.hashValue, storage};
|
||||
return storage;
|
||||
}
|
||||
|
||||
/// Get or create an instance of a simple derived type.
|
||||
BaseStorage *
|
||||
getOrCreate(TypeID id, unsigned kind,
|
||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
||||
assert(instUniquers.count(id) && "creating unregistered storage instance");
|
||||
InstSpecificUniquer &storageUniquer = *instUniquers[id];
|
||||
if (!threadingIsEnabled)
|
||||
return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
|
||||
|
||||
// Check for an existing instance in read-only mode.
|
||||
{
|
||||
llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
|
||||
auto it = storageUniquer.simpleInstances.find(kind);
|
||||
if (it != storageUniquer.simpleInstances.end())
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Acquire a writer-lock so that we can safely create the new type instance.
|
||||
llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
|
||||
return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
|
||||
}
|
||||
/// Get or create an instance of a simple derived type in an thread-unsafe
|
||||
/// fashion.
|
||||
BaseStorage *
|
||||
getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
|
||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
||||
auto &result = storageUniquer.simpleInstances[kind];
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
// Otherwise, create and return a new storage instance.
|
||||
return result = initializeStorage(kind, storageUniquer.allocator, ctorFn);
|
||||
}
|
||||
|
||||
/// Erase an instance of a complex derived type.
|
||||
void erase(TypeID id, unsigned kind, unsigned hashValue,
|
||||
/// Erase an instance of a parametric derived type.
|
||||
void erase(TypeID id, unsigned hashValue,
|
||||
function_ref<bool(const BaseStorage *)> isEqual,
|
||||
function_ref<void(BaseStorage *)> cleanupFn) {
|
||||
assert(instUniquers.count(id) && "erasing unregistered storage instance");
|
||||
InstSpecificUniquer &storageUniquer = *instUniquers[id];
|
||||
InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
|
||||
assert(parametricUniquers.count(id) &&
|
||||
"erasing unregistered storage instance");
|
||||
ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
|
||||
ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
|
||||
|
||||
// Acquire a writer-lock so that we can safely erase the type instance.
|
||||
llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
|
||||
auto existing = storageUniquer.complexInstances.find_as(lookupKey);
|
||||
if (existing == storageUniquer.complexInstances.end())
|
||||
auto existing = storageUniquer.instances.find_as(lookupKey);
|
||||
if (existing == storageUniquer.instances.end())
|
||||
return;
|
||||
|
||||
// Cleanup the storage and remove it from the map.
|
||||
cleanupFn(existing->storage);
|
||||
storageUniquer.complexInstances.erase(existing);
|
||||
storageUniquer.instances.erase(existing);
|
||||
}
|
||||
|
||||
/// Mutates an instance of a derived storage in a thread-safe way.
|
||||
LogicalResult
|
||||
mutate(TypeID id,
|
||||
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
|
||||
assert(instUniquers.count(id) && "mutating unregistered storage instance");
|
||||
InstSpecificUniquer &storageUniquer = *instUniquers[id];
|
||||
assert(parametricUniquers.count(id) &&
|
||||
"mutating unregistered storage instance");
|
||||
ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
|
||||
if (!threadingIsEnabled)
|
||||
return mutationFn(storageUniquer.allocator);
|
||||
|
||||
|
@ -199,21 +165,31 @@ struct StorageUniquerImpl {
|
|||
return mutationFn(storageUniquer.allocator);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Singleton Storage
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Get or create an instance of a singleton storage class.
|
||||
BaseStorage *getSingleton(TypeID id) {
|
||||
BaseStorage *singletonInstance = singletonInstances[id];
|
||||
assert(singletonInstance && "expected singleton instance to exist");
|
||||
return singletonInstance;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Instance Storage
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Utility to create and initialize a storage instance.
|
||||
BaseStorage *
|
||||
initializeStorage(unsigned kind, StorageAllocator &allocator,
|
||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
||||
BaseStorage *storage = ctorFn(allocator);
|
||||
storage->kind = kind;
|
||||
return storage;
|
||||
}
|
||||
|
||||
/// Map of type ids to the storage uniquer to use for registered objects.
|
||||
DenseMap<TypeID, std::unique_ptr<InstSpecificUniquer>> instUniquers;
|
||||
DenseMap<TypeID, std::unique_ptr<ParametricStorageUniquer>>
|
||||
parametricUniquers;
|
||||
|
||||
/// Map of type ids to a singleton instance when the storage class is a
|
||||
/// singleton.
|
||||
DenseMap<TypeID, BaseStorage *> singletonInstances;
|
||||
|
||||
/// Allocator used for uniquing singleton instances.
|
||||
StorageAllocator singletonAllocator;
|
||||
|
||||
/// Flag specifying if multi-threading is enabled within the uniquer.
|
||||
bool threadingIsEnabled = true;
|
||||
|
@ -229,41 +205,47 @@ void StorageUniquer::disableMultithreading(bool disable) {
|
|||
impl->threadingIsEnabled = !disable;
|
||||
}
|
||||
|
||||
/// Register a new storage object with this uniquer using the given unique type
|
||||
/// id.
|
||||
void StorageUniquer::registerStorageType(TypeID id) {
|
||||
impl->instUniquers.try_emplace(id, std::make_unique<InstSpecificUniquer>());
|
||||
}
|
||||
|
||||
/// Implementation for getting/creating an instance of a derived type with
|
||||
/// complex storage.
|
||||
auto StorageUniquer::getImpl(
|
||||
const TypeID &id, unsigned kind, unsigned hashValue,
|
||||
/// parametric storage.
|
||||
auto StorageUniquer::getParametricStorageTypeImpl(
|
||||
TypeID id, unsigned hashValue,
|
||||
function_ref<bool(const BaseStorage *)> isEqual,
|
||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
|
||||
return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn);
|
||||
return impl->getOrCreate(id, hashValue, isEqual, ctorFn);
|
||||
}
|
||||
|
||||
/// Implementation for getting/creating an instance of a derived type with
|
||||
/// default storage.
|
||||
auto StorageUniquer::getImpl(
|
||||
const TypeID &id, unsigned kind,
|
||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
|
||||
return impl->getOrCreate(id, kind, ctorFn);
|
||||
/// Implementation for registering an instance of a derived type with
|
||||
/// parametric storage.
|
||||
void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) {
|
||||
impl->parametricUniquers.try_emplace(
|
||||
id, std::make_unique<ParametricStorageUniquer>());
|
||||
}
|
||||
|
||||
/// Implementation for erasing an instance of a derived type with complex
|
||||
/// Implementation for getting an instance of a derived type with default
|
||||
/// storage.
|
||||
void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind,
|
||||
unsigned hashValue,
|
||||
auto StorageUniquer::getSingletonImpl(TypeID id) -> BaseStorage * {
|
||||
return impl->getSingleton(id);
|
||||
}
|
||||
|
||||
/// Implementation for registering an instance of a derived type with default
|
||||
/// storage.
|
||||
void StorageUniquer::registerSingletonImpl(
|
||||
TypeID id, function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
||||
assert(!impl->singletonInstances.count(id) &&
|
||||
"storage class already registered");
|
||||
impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator));
|
||||
}
|
||||
|
||||
/// Implementation for erasing an instance of a derived type with parametric
|
||||
/// storage.
|
||||
void StorageUniquer::eraseImpl(TypeID id, unsigned hashValue,
|
||||
function_ref<bool(const BaseStorage *)> isEqual,
|
||||
function_ref<void(BaseStorage *)> cleanupFn) {
|
||||
impl->erase(id, kind, hashValue, isEqual, cleanupFn);
|
||||
impl->erase(id, hashValue, isEqual, cleanupFn);
|
||||
}
|
||||
|
||||
/// Implementation for mutating an instance of a derived storage.
|
||||
LogicalResult StorageUniquer::mutateImpl(
|
||||
const TypeID &id,
|
||||
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
|
||||
TypeID id, function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
|
||||
return impl->mutate(id, mutationFn);
|
||||
}
|
||||
|
|
|
@ -156,7 +156,7 @@ static Type parseTestType(DialectAsmParser &parser,
|
|||
StringRef name;
|
||||
if (parser.parseLess() || parser.parseKeyword(&name))
|
||||
return Type();
|
||||
auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name);
|
||||
auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
|
||||
|
||||
// If this type already has been parsed above in the stack, expect just the
|
||||
// name.
|
||||
|
|
|
@ -26,10 +26,6 @@ struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
|
|||
TestTypeInterface::Trait> {
|
||||
using Base::Base;
|
||||
|
||||
static TestType get(MLIRContext *context) {
|
||||
return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE);
|
||||
}
|
||||
|
||||
/// Provide a definition for the necessary interface methods.
|
||||
void printTypeC(Location loc) const {
|
||||
emitRemark(loc) << *this << " - TestC";
|
||||
|
@ -72,9 +68,8 @@ class TestRecursiveType
|
|||
public:
|
||||
using Base::Base;
|
||||
|
||||
static TestRecursiveType create(MLIRContext *ctx, StringRef name) {
|
||||
return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1,
|
||||
name);
|
||||
static TestRecursiveType get(MLIRContext *ctx, StringRef name) {
|
||||
return Base::get(ctx, name);
|
||||
}
|
||||
|
||||
/// Body getter and setter.
|
||||
|
|
|
@ -41,7 +41,7 @@ struct TestRecursiveTypesPass
|
|||
LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
|
||||
MLIRContext *ctx = &getContext();
|
||||
FuncOp func = getFunction();
|
||||
auto type = TestRecursiveType::create(ctx, "some_long_and_unique_name");
|
||||
auto type = TestRecursiveType::get(ctx, "some_long_and_unique_name");
|
||||
if (failed(type.setBody(type)))
|
||||
return func.emitError("expected to be able to set the type body");
|
||||
|
||||
|
@ -56,7 +56,7 @@ LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
|
|||
"not expected to be able to change function body more than once");
|
||||
|
||||
// Expecting to get the same type for the same name.
|
||||
auto other = TestRecursiveType::create(ctx, "some_long_and_unique_name");
|
||||
auto other = TestRecursiveType::get(ctx, "some_long_and_unique_name");
|
||||
if (type != other)
|
||||
return func.emitError("expected type name to be the uniquing key");
|
||||
|
||||
|
|
Loading…
Reference in New Issue