forked from OSchip/llvm-project
[mlir] Remove the use of "kinds" from Attributes and Types
This greatly simplifies a large portion of the underlying infrastructure, allows for lookups of singleton classes to be much more efficient and always thread-safe(no locking). As a result of this, the dialect symbol registry has been removed as it is no longer necessary. For users broken by this change, an alert was sent out(https://llvm.discourse.group/t/removing-kinds-from-attributes-and-types) that helps prevent a majority of the breakage surface area. All that should be necessary, if the advice in that alert was followed, is removing the kind passed to the ::get methods. Differential Revision: https://reviews.llvm.org/D86121
This commit is contained in:
parent
a7d0b7a786
commit
250f43d3ec
|
@ -25,17 +25,6 @@ struct RealAttributeStorage;
|
||||||
struct TypeAttributeStorage;
|
struct TypeAttributeStorage;
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
enum AttributeKind {
|
|
||||||
FIR_ATTR = mlir::Attribute::FIRST_FIR_ATTR,
|
|
||||||
FIR_EXACTTYPE, // instance_of, precise type relation
|
|
||||||
FIR_SUBCLASS, // subsumed_by, is-a (subclass) relation
|
|
||||||
FIR_POINT,
|
|
||||||
FIR_CLOSEDCLOSED_INTERVAL,
|
|
||||||
FIR_OPENCLOSED_INTERVAL,
|
|
||||||
FIR_CLOSEDOPEN_INTERVAL,
|
|
||||||
FIR_REAL_ATTR,
|
|
||||||
};
|
|
||||||
|
|
||||||
class ExactTypeAttr
|
class ExactTypeAttr
|
||||||
: public mlir::Attribute::AttrBase<ExactTypeAttr, mlir::Attribute,
|
: public mlir::Attribute::AttrBase<ExactTypeAttr, mlir::Attribute,
|
||||||
detail::TypeAttributeStorage> {
|
detail::TypeAttributeStorage> {
|
||||||
|
@ -47,8 +36,6 @@ public:
|
||||||
static ExactTypeAttr get(mlir::Type value);
|
static ExactTypeAttr get(mlir::Type value);
|
||||||
|
|
||||||
mlir::Type getType() const;
|
mlir::Type getType() const;
|
||||||
|
|
||||||
static constexpr unsigned getId() { return AttributeKind::FIR_EXACTTYPE; }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class SubclassAttr
|
class SubclassAttr
|
||||||
|
@ -62,8 +49,6 @@ public:
|
||||||
static SubclassAttr get(mlir::Type value);
|
static SubclassAttr get(mlir::Type value);
|
||||||
|
|
||||||
mlir::Type getType() const;
|
mlir::Type getType() const;
|
||||||
|
|
||||||
static constexpr unsigned getId() { return AttributeKind::FIR_SUBCLASS; }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Attributes for building SELECT CASE multiway branches
|
// Attributes for building SELECT CASE multiway branches
|
||||||
|
@ -80,9 +65,6 @@ public:
|
||||||
|
|
||||||
static constexpr llvm::StringRef getAttrName() { return "interval"; }
|
static constexpr llvm::StringRef getAttrName() { return "interval"; }
|
||||||
static ClosedIntervalAttr get(mlir::MLIRContext *ctxt);
|
static ClosedIntervalAttr get(mlir::MLIRContext *ctxt);
|
||||||
static constexpr unsigned getId() {
|
|
||||||
return AttributeKind::FIR_CLOSEDCLOSED_INTERVAL;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An upper bound is an open interval (including the bound value) as given as
|
/// An upper bound is an open interval (including the bound value) as given as
|
||||||
|
@ -97,9 +79,6 @@ public:
|
||||||
|
|
||||||
static constexpr llvm::StringRef getAttrName() { return "upper"; }
|
static constexpr llvm::StringRef getAttrName() { return "upper"; }
|
||||||
static UpperBoundAttr get(mlir::MLIRContext *ctxt);
|
static UpperBoundAttr get(mlir::MLIRContext *ctxt);
|
||||||
static constexpr unsigned getId() {
|
|
||||||
return AttributeKind::FIR_OPENCLOSED_INTERVAL;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A lower bound is an open interval (including the bound value) as given as
|
/// A lower bound is an open interval (including the bound value) as given as
|
||||||
|
@ -114,9 +93,6 @@ public:
|
||||||
|
|
||||||
static constexpr llvm::StringRef getAttrName() { return "lower"; }
|
static constexpr llvm::StringRef getAttrName() { return "lower"; }
|
||||||
static LowerBoundAttr get(mlir::MLIRContext *ctxt);
|
static LowerBoundAttr get(mlir::MLIRContext *ctxt);
|
||||||
static constexpr unsigned getId() {
|
|
||||||
return AttributeKind::FIR_CLOSEDOPEN_INTERVAL;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A pointer interval is a closed interval as given as an ssa-value. The
|
/// A pointer interval is a closed interval as given as an ssa-value. The
|
||||||
|
@ -131,7 +107,6 @@ public:
|
||||||
|
|
||||||
static constexpr llvm::StringRef getAttrName() { return "point"; }
|
static constexpr llvm::StringRef getAttrName() { return "point"; }
|
||||||
static PointIntervalAttr get(mlir::MLIRContext *ctxt);
|
static PointIntervalAttr get(mlir::MLIRContext *ctxt);
|
||||||
static constexpr unsigned getId() { return AttributeKind::FIR_POINT; }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A real attribute is used to workaround MLIR's default parsing of a real
|
/// A real attribute is used to workaround MLIR's default parsing of a real
|
||||||
|
@ -150,8 +125,6 @@ public:
|
||||||
|
|
||||||
int getFKind() const;
|
int getFKind() const;
|
||||||
llvm::APFloat getValue() const;
|
llvm::APFloat getValue() const;
|
||||||
|
|
||||||
static constexpr unsigned getId() { return AttributeKind::FIR_REAL_ATTR; }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,
|
mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,
|
||||||
|
|
|
@ -54,29 +54,6 @@ struct SequenceTypeStorage;
|
||||||
struct TypeDescTypeStorage;
|
struct TypeDescTypeStorage;
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
/// Integral identifier for all the types comprising the FIR type system
|
|
||||||
enum TypeKind {
|
|
||||||
// The enum starts at the range reserved for this dialect.
|
|
||||||
FIR_TYPE = mlir::Type::FIRST_FIR_TYPE,
|
|
||||||
FIR_BOX, // (static) descriptor
|
|
||||||
FIR_BOXCHAR, // CHARACTER pointer and length
|
|
||||||
FIR_BOXPROC, // procedure with host association
|
|
||||||
FIR_CHARACTER, // intrinsic type
|
|
||||||
FIR_COMPLEX, // intrinsic type
|
|
||||||
FIR_DERIVED, // derived
|
|
||||||
FIR_DIMS,
|
|
||||||
FIR_FIELD,
|
|
||||||
FIR_HEAP,
|
|
||||||
FIR_INT, // intrinsic type
|
|
||||||
FIR_LEN,
|
|
||||||
FIR_LOGICAL, // intrinsic type
|
|
||||||
FIR_POINTER, // POINTER attr
|
|
||||||
FIR_REAL, // intrinsic type
|
|
||||||
FIR_REFERENCE,
|
|
||||||
FIR_SEQUENCE, // DIMENSION attr
|
|
||||||
FIR_TYPEDESC,
|
|
||||||
};
|
|
||||||
|
|
||||||
// These isa_ routines follow the precedent of llvm::isa_or_null<>
|
// These isa_ routines follow the precedent of llvm::isa_or_null<>
|
||||||
|
|
||||||
/// Is `t` any of the FIR dialect types?
|
/// Is `t` any of the FIR dialect types?
|
||||||
|
@ -111,12 +88,6 @@ bool isa_aggregate(mlir::Type t);
|
||||||
/// not a memory reference type, then returns a null `Type`.
|
/// not a memory reference type, then returns a null `Type`.
|
||||||
mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
|
mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
|
||||||
|
|
||||||
/// Boilerplate mixin template
|
|
||||||
template <typename A, unsigned Id>
|
|
||||||
struct IntrinsicTypeMixin {
|
|
||||||
static constexpr unsigned getId() { return Id; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// Intrinsic types
|
// Intrinsic types
|
||||||
|
|
||||||
/// Model of the Fortran CHARACTER intrinsic type, including the KIND type
|
/// Model of the Fortran CHARACTER intrinsic type, including the KIND type
|
||||||
|
@ -124,8 +95,7 @@ struct IntrinsicTypeMixin {
|
||||||
/// is thus the type of a single character value.
|
/// is thus the type of a single character value.
|
||||||
class CharacterType
|
class CharacterType
|
||||||
: public mlir::Type::TypeBase<CharacterType, mlir::Type,
|
: public mlir::Type::TypeBase<CharacterType, mlir::Type,
|
||||||
detail::CharacterTypeStorage>,
|
detail::CharacterTypeStorage> {
|
||||||
public IntrinsicTypeMixin<CharacterType, TypeKind::FIR_CHARACTER> {
|
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
static CharacterType get(mlir::MLIRContext *ctxt, KindTy kind);
|
static CharacterType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||||
|
@ -136,8 +106,7 @@ public:
|
||||||
/// parameter. COMPLEX is a floating point type with a real and imaginary
|
/// parameter. COMPLEX is a floating point type with a real and imaginary
|
||||||
/// member.
|
/// member.
|
||||||
class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
|
class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
|
||||||
detail::CplxTypeStorage>,
|
detail::CplxTypeStorage> {
|
||||||
public IntrinsicTypeMixin<CplxType, TypeKind::FIR_COMPLEX> {
|
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
static CplxType get(mlir::MLIRContext *ctxt, KindTy kind);
|
static CplxType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||||
|
@ -151,8 +120,7 @@ public:
|
||||||
/// Model of a Fortran INTEGER intrinsic type, including the KIND type
|
/// Model of a Fortran INTEGER intrinsic type, including the KIND type
|
||||||
/// parameter.
|
/// parameter.
|
||||||
class IntType
|
class IntType
|
||||||
: public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage>,
|
: public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage> {
|
||||||
public IntrinsicTypeMixin<IntType, TypeKind::FIR_INT> {
|
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
static IntType get(mlir::MLIRContext *ctxt, KindTy kind);
|
static IntType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||||
|
@ -163,8 +131,7 @@ public:
|
||||||
/// parameter.
|
/// parameter.
|
||||||
class LogicalType
|
class LogicalType
|
||||||
: public mlir::Type::TypeBase<LogicalType, mlir::Type,
|
: public mlir::Type::TypeBase<LogicalType, mlir::Type,
|
||||||
detail::LogicalTypeStorage>,
|
detail::LogicalTypeStorage> {
|
||||||
public IntrinsicTypeMixin<LogicalType, TypeKind::FIR_LOGICAL> {
|
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind);
|
static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||||
|
@ -174,8 +141,7 @@ public:
|
||||||
/// Model of a Fortran REAL (and DOUBLE PRECISION) intrinsic type, including the
|
/// Model of a Fortran REAL (and DOUBLE PRECISION) intrinsic type, including the
|
||||||
/// KIND type parameter.
|
/// KIND type parameter.
|
||||||
class RealType : public mlir::Type::TypeBase<RealType, mlir::Type,
|
class RealType : public mlir::Type::TypeBase<RealType, mlir::Type,
|
||||||
detail::RealTypeStorage>,
|
detail::RealTypeStorage> {
|
||||||
public IntrinsicTypeMixin<RealType, TypeKind::FIR_REAL> {
|
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
static RealType get(mlir::MLIRContext *ctxt, KindTy kind);
|
static RealType get(mlir::MLIRContext *ctxt, KindTy kind);
|
||||||
|
@ -400,7 +366,6 @@ public:
|
||||||
static RecordType get(mlir::MLIRContext *ctxt, llvm::StringRef name);
|
static RecordType get(mlir::MLIRContext *ctxt, llvm::StringRef name);
|
||||||
void finalize(llvm::ArrayRef<TypePair> lenPList,
|
void finalize(llvm::ArrayRef<TypePair> lenPList,
|
||||||
llvm::ArrayRef<TypePair> typeList);
|
llvm::ArrayRef<TypePair> typeList);
|
||||||
static constexpr unsigned getId() { return TypeKind::FIR_DERIVED; }
|
|
||||||
|
|
||||||
detail::RecordTypeStorage const *uniqueKey() const;
|
detail::RecordTypeStorage const *uniqueKey() const;
|
||||||
|
|
||||||
|
|
|
@ -74,13 +74,13 @@ private:
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
ExactTypeAttr ExactTypeAttr::get(mlir::Type value) {
|
ExactTypeAttr ExactTypeAttr::get(mlir::Type value) {
|
||||||
return Base::get(value.getContext(), FIR_EXACTTYPE, value);
|
return Base::get(value.getContext(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); }
|
mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); }
|
||||||
|
|
||||||
SubclassAttr SubclassAttr::get(mlir::Type value) {
|
SubclassAttr SubclassAttr::get(mlir::Type value) {
|
||||||
return Base::get(value.getContext(), FIR_SUBCLASS, value);
|
return Base::get(value.getContext(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
|
mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
|
||||||
|
@ -88,26 +88,26 @@ mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
|
||||||
using AttributeUniquer = mlir::detail::AttributeUniquer;
|
using AttributeUniquer = mlir::detail::AttributeUniquer;
|
||||||
|
|
||||||
ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
|
ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
|
||||||
return AttributeUniquer::get<ClosedIntervalAttr>(ctxt, getId());
|
return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
|
||||||
}
|
}
|
||||||
|
|
||||||
UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
|
UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
|
||||||
return AttributeUniquer::get<UpperBoundAttr>(ctxt, getId());
|
return AttributeUniquer::get<UpperBoundAttr>(ctxt);
|
||||||
}
|
}
|
||||||
|
|
||||||
LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
|
LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
|
||||||
return AttributeUniquer::get<LowerBoundAttr>(ctxt, getId());
|
return AttributeUniquer::get<LowerBoundAttr>(ctxt);
|
||||||
}
|
}
|
||||||
|
|
||||||
PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
|
PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
|
||||||
return AttributeUniquer::get<PointIntervalAttr>(ctxt, getId());
|
return AttributeUniquer::get<PointIntervalAttr>(ctxt);
|
||||||
}
|
}
|
||||||
|
|
||||||
// RealAttr
|
// RealAttr
|
||||||
|
|
||||||
RealAttr RealAttr::get(mlir::MLIRContext *ctxt,
|
RealAttr RealAttr::get(mlir::MLIRContext *ctxt,
|
||||||
const RealAttr::ValueType &key) {
|
const RealAttr::ValueType &key) {
|
||||||
return Base::get(ctxt, getId(), key);
|
return Base::get(ctxt, key);
|
||||||
}
|
}
|
||||||
|
|
||||||
int RealAttr::getFKind() const { return getImpl()->getFKind(); }
|
int RealAttr::getFKind() const { return getImpl()->getFKind(); }
|
||||||
|
|
|
@ -824,13 +824,11 @@ bool inbounds(A v, B lb, B ub) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isa_fir_type(mlir::Type t) {
|
bool isa_fir_type(mlir::Type t) {
|
||||||
return inbounds(t.getKind(), mlir::Type::FIRST_FIR_TYPE,
|
return llvm::isa<FIROpsDialect>(t.getDialect());
|
||||||
mlir::Type::LAST_FIR_TYPE);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isa_std_type(mlir::Type t) {
|
bool isa_std_type(mlir::Type t) {
|
||||||
return inbounds(t.getKind(), mlir::Type::FIRST_STANDARD_TYPE,
|
return t.getDialect().getNamespace().empty();
|
||||||
mlir::Type::LAST_STANDARD_TYPE);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isa_fir_or_std_type(mlir::Type t) {
|
bool isa_fir_or_std_type(mlir::Type t) {
|
||||||
|
@ -868,7 +866,7 @@ mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
|
||||||
// CHARACTER
|
// CHARACTER
|
||||||
|
|
||||||
CharacterType fir::CharacterType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
CharacterType fir::CharacterType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||||
return Base::get(ctxt, FIR_CHARACTER, kind);
|
return Base::get(ctxt, kind);
|
||||||
}
|
}
|
||||||
|
|
||||||
int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
|
int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
|
||||||
|
@ -876,7 +874,7 @@ int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
|
||||||
// Dims
|
// Dims
|
||||||
|
|
||||||
DimsType fir::DimsType::get(mlir::MLIRContext *ctxt, unsigned rank) {
|
DimsType fir::DimsType::get(mlir::MLIRContext *ctxt, unsigned rank) {
|
||||||
return Base::get(ctxt, FIR_DIMS, rank);
|
return Base::get(ctxt, rank);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
|
unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
|
||||||
|
@ -884,19 +882,19 @@ unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
|
||||||
// Field
|
// Field
|
||||||
|
|
||||||
FieldType fir::FieldType::get(mlir::MLIRContext *ctxt) {
|
FieldType fir::FieldType::get(mlir::MLIRContext *ctxt) {
|
||||||
return Base::get(ctxt, FIR_FIELD, 0);
|
return Base::get(ctxt, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len
|
// Len
|
||||||
|
|
||||||
LenType fir::LenType::get(mlir::MLIRContext *ctxt) {
|
LenType fir::LenType::get(mlir::MLIRContext *ctxt) {
|
||||||
return Base::get(ctxt, FIR_LEN, 0);
|
return Base::get(ctxt, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// LOGICAL
|
// LOGICAL
|
||||||
|
|
||||||
LogicalType fir::LogicalType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
LogicalType fir::LogicalType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||||
return Base::get(ctxt, FIR_LOGICAL, kind);
|
return Base::get(ctxt, kind);
|
||||||
}
|
}
|
||||||
|
|
||||||
int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
|
int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
|
||||||
|
@ -904,7 +902,7 @@ int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
|
||||||
// INTEGER
|
// INTEGER
|
||||||
|
|
||||||
IntType fir::IntType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
IntType fir::IntType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||||
return Base::get(ctxt, FIR_INT, kind);
|
return Base::get(ctxt, kind);
|
||||||
}
|
}
|
||||||
|
|
||||||
int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
|
int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
|
||||||
|
@ -912,7 +910,7 @@ int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
|
||||||
// COMPLEX
|
// COMPLEX
|
||||||
|
|
||||||
CplxType fir::CplxType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
CplxType fir::CplxType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||||
return Base::get(ctxt, FIR_COMPLEX, kind);
|
return Base::get(ctxt, kind);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type fir::CplxType::getElementType() const {
|
mlir::Type fir::CplxType::getElementType() const {
|
||||||
|
@ -924,7 +922,7 @@ KindTy fir::CplxType::getFKind() const { return getImpl()->getFKind(); }
|
||||||
// REAL
|
// REAL
|
||||||
|
|
||||||
RealType fir::RealType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
RealType fir::RealType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||||
return Base::get(ctxt, FIR_REAL, kind);
|
return Base::get(ctxt, kind);
|
||||||
}
|
}
|
||||||
|
|
||||||
int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
|
int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
|
||||||
|
@ -932,7 +930,7 @@ int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
|
||||||
// Box<T>
|
// Box<T>
|
||||||
|
|
||||||
BoxType fir::BoxType::get(mlir::Type elementType, mlir::AffineMapAttr map) {
|
BoxType fir::BoxType::get(mlir::Type elementType, mlir::AffineMapAttr map) {
|
||||||
return Base::get(elementType.getContext(), FIR_BOX, elementType, map);
|
return Base::get(elementType.getContext(), elementType, map);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type fir::BoxType::getEleTy() const {
|
mlir::Type fir::BoxType::getEleTy() const {
|
||||||
|
@ -953,7 +951,7 @@ fir::BoxType::verifyConstructionInvariants(mlir::Location, mlir::Type eleTy,
|
||||||
// BoxChar<C>
|
// BoxChar<C>
|
||||||
|
|
||||||
BoxCharType fir::BoxCharType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
BoxCharType fir::BoxCharType::get(mlir::MLIRContext *ctxt, KindTy kind) {
|
||||||
return Base::get(ctxt, FIR_BOXCHAR, kind);
|
return Base::get(ctxt, kind);
|
||||||
}
|
}
|
||||||
|
|
||||||
CharacterType fir::BoxCharType::getEleTy() const {
|
CharacterType fir::BoxCharType::getEleTy() const {
|
||||||
|
@ -963,7 +961,7 @@ CharacterType fir::BoxCharType::getEleTy() const {
|
||||||
// BoxProc<T>
|
// BoxProc<T>
|
||||||
|
|
||||||
BoxProcType fir::BoxProcType::get(mlir::Type elementType) {
|
BoxProcType fir::BoxProcType::get(mlir::Type elementType) {
|
||||||
return Base::get(elementType.getContext(), FIR_BOXPROC, elementType);
|
return Base::get(elementType.getContext(), elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type fir::BoxProcType::getEleTy() const {
|
mlir::Type fir::BoxProcType::getEleTy() const {
|
||||||
|
@ -984,7 +982,7 @@ fir::BoxProcType::verifyConstructionInvariants(mlir::Location loc,
|
||||||
// Reference<T>
|
// Reference<T>
|
||||||
|
|
||||||
ReferenceType fir::ReferenceType::get(mlir::Type elementType) {
|
ReferenceType fir::ReferenceType::get(mlir::Type elementType) {
|
||||||
return Base::get(elementType.getContext(), FIR_REFERENCE, elementType);
|
return Base::get(elementType.getContext(), elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type fir::ReferenceType::getEleTy() const {
|
mlir::Type fir::ReferenceType::getEleTy() const {
|
||||||
|
@ -1005,7 +1003,7 @@ fir::ReferenceType::verifyConstructionInvariants(mlir::Location loc,
|
||||||
|
|
||||||
PointerType fir::PointerType::get(mlir::Type elementType) {
|
PointerType fir::PointerType::get(mlir::Type elementType) {
|
||||||
assert(singleIndirectionLevel(elementType) && "invalid element type");
|
assert(singleIndirectionLevel(elementType) && "invalid element type");
|
||||||
return Base::get(elementType.getContext(), FIR_POINTER, elementType);
|
return Base::get(elementType.getContext(), elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type fir::PointerType::getEleTy() const {
|
mlir::Type fir::PointerType::getEleTy() const {
|
||||||
|
@ -1033,7 +1031,7 @@ fir::PointerType::verifyConstructionInvariants(mlir::Location loc,
|
||||||
|
|
||||||
HeapType fir::HeapType::get(mlir::Type elementType) {
|
HeapType fir::HeapType::get(mlir::Type elementType) {
|
||||||
assert(singleIndirectionLevel(elementType) && "invalid element type");
|
assert(singleIndirectionLevel(elementType) && "invalid element type");
|
||||||
return Base::get(elementType.getContext(), FIR_HEAP, elementType);
|
return Base::get(elementType.getContext(), elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type fir::HeapType::getEleTy() const {
|
mlir::Type fir::HeapType::getEleTy() const {
|
||||||
|
@ -1054,7 +1052,7 @@ fir::HeapType::verifyConstructionInvariants(mlir::Location loc,
|
||||||
SequenceType fir::SequenceType::get(const Shape &shape, mlir::Type elementType,
|
SequenceType fir::SequenceType::get(const Shape &shape, mlir::Type elementType,
|
||||||
mlir::AffineMapAttr map) {
|
mlir::AffineMapAttr map) {
|
||||||
auto *ctxt = elementType.getContext();
|
auto *ctxt = elementType.getContext();
|
||||||
return Base::get(ctxt, FIR_SEQUENCE, shape, elementType, map);
|
return Base::get(ctxt, shape, elementType, map);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type fir::SequenceType::getEleTy() const {
|
mlir::Type fir::SequenceType::getEleTy() const {
|
||||||
|
@ -1136,7 +1134,7 @@ llvm::hash_code fir::hash_value(const SequenceType::Shape &sh) {
|
||||||
/// This type captures a Fortran "derived type"
|
/// This type captures a Fortran "derived type"
|
||||||
|
|
||||||
RecordType fir::RecordType::get(mlir::MLIRContext *ctxt, llvm::StringRef name) {
|
RecordType fir::RecordType::get(mlir::MLIRContext *ctxt, llvm::StringRef name) {
|
||||||
return Base::get(ctxt, FIR_DERIVED, name);
|
return Base::get(ctxt, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
|
void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
|
||||||
|
@ -1179,7 +1177,7 @@ mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
|
||||||
|
|
||||||
TypeDescType fir::TypeDescType::get(mlir::Type ofType) {
|
TypeDescType fir::TypeDescType::get(mlir::Type ofType) {
|
||||||
assert(!ofType.isa<ReferenceType>());
|
assert(!ofType.isa<ReferenceType>());
|
||||||
return Base::get(ofType.getContext(), FIR_TYPEDESC, ofType);
|
return Base::get(ofType.getContext(), ofType);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type fir::TypeDescType::getOfTy() const { return getImpl()->getOfType(); }
|
mlir::Type fir::TypeDescType::getOfTy() const { return getImpl()->getOfType(); }
|
||||||
|
@ -1222,9 +1220,7 @@ void fir::verifyIntegralType(mlir::Type type) {
|
||||||
void fir::printFirType(FIROpsDialect *, mlir::Type ty,
|
void fir::printFirType(FIROpsDialect *, mlir::Type ty,
|
||||||
mlir::DialectAsmPrinter &p) {
|
mlir::DialectAsmPrinter &p) {
|
||||||
auto &os = p.getStream();
|
auto &os = p.getStream();
|
||||||
switch (ty.getKind()) {
|
if (auto type = ty.dyn_cast<BoxType>()) {
|
||||||
case fir::FIR_BOX: {
|
|
||||||
auto type = ty.cast<BoxType>();
|
|
||||||
os << "box<";
|
os << "box<";
|
||||||
p.printType(type.getEleTy());
|
p.printType(type.getEleTy());
|
||||||
if (auto map = type.getLayoutMap()) {
|
if (auto map = type.getLayoutMap()) {
|
||||||
|
@ -1232,24 +1228,28 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
|
||||||
p.printAttribute(map);
|
p.printAttribute(map);
|
||||||
}
|
}
|
||||||
os << '>';
|
os << '>';
|
||||||
} break;
|
return;
|
||||||
case fir::FIR_BOXCHAR: {
|
}
|
||||||
auto type = ty.cast<BoxCharType>().getEleTy();
|
if (auto type = ty.dyn_cast<BoxCharType>()) {
|
||||||
os << "boxchar<" << type.cast<fir::CharacterType>().getFKind() << '>';
|
os << "boxchar<" << type.getEleTy().cast<fir::CharacterType>().getFKind()
|
||||||
} break;
|
<< '>';
|
||||||
case fir::FIR_BOXPROC:
|
return;
|
||||||
|
}
|
||||||
|
if (auto type = ty.dyn_cast<BoxProcType>()) {
|
||||||
os << "boxproc<";
|
os << "boxproc<";
|
||||||
p.printType(ty.cast<BoxProcType>().getEleTy());
|
p.printType(type.getEleTy());
|
||||||
os << '>';
|
os << '>';
|
||||||
break;
|
return;
|
||||||
case fir::FIR_CHARACTER: // intrinsic
|
}
|
||||||
os << "char<" << ty.cast<CharacterType>().getFKind() << '>';
|
if (auto type = ty.dyn_cast<CharacterType>()) {
|
||||||
break;
|
os << "char<" << type.getFKind() << '>';
|
||||||
case fir::FIR_COMPLEX: // intrinsic
|
return;
|
||||||
os << "complex<" << ty.cast<CplxType>().getFKind() << '>';
|
}
|
||||||
break;
|
if (auto type = ty.dyn_cast<CplxType>()) {
|
||||||
case fir::FIR_DERIVED: { // derived
|
os << "complex<" << type.getFKind() << '>';
|
||||||
auto type = ty.cast<fir::RecordType>();
|
return;
|
||||||
|
}
|
||||||
|
if (auto type = ty.dyn_cast<RecordType>()) {
|
||||||
os << "type<" << type.getName();
|
os << "type<" << type.getName();
|
||||||
if (!recordTypeVisited.count(type.uniqueKey())) {
|
if (!recordTypeVisited.count(type.uniqueKey())) {
|
||||||
recordTypeVisited.insert(type.uniqueKey());
|
recordTypeVisited.insert(type.uniqueKey());
|
||||||
|
@ -1274,43 +1274,52 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
|
||||||
recordTypeVisited.erase(type.uniqueKey());
|
recordTypeVisited.erase(type.uniqueKey());
|
||||||
}
|
}
|
||||||
os << '>';
|
os << '>';
|
||||||
} break;
|
return;
|
||||||
case fir::FIR_DIMS:
|
}
|
||||||
os << "dims<" << ty.cast<DimsType>().getRank() << '>';
|
if (auto type = ty.dyn_cast<DimsType>()) {
|
||||||
break;
|
os << "dims<" << type.getRank() << '>';
|
||||||
case fir::FIR_FIELD:
|
return;
|
||||||
|
}
|
||||||
|
if (ty.isa<FieldType>()) {
|
||||||
os << "field";
|
os << "field";
|
||||||
break;
|
return;
|
||||||
case fir::FIR_HEAP:
|
}
|
||||||
|
if (auto type = ty.dyn_cast<HeapType>()) {
|
||||||
os << "heap<";
|
os << "heap<";
|
||||||
p.printType(ty.cast<HeapType>().getEleTy());
|
p.printType(type.getEleTy());
|
||||||
os << '>';
|
os << '>';
|
||||||
break;
|
return;
|
||||||
case fir::FIR_INT: // intrinsic
|
}
|
||||||
os << "int<" << ty.cast<fir::IntType>().getFKind() << '>';
|
if (auto type = ty.dyn_cast<fir::IntType>()) {
|
||||||
break;
|
os << "int<" << type.getFKind() << '>';
|
||||||
case fir::FIR_LEN:
|
return;
|
||||||
|
}
|
||||||
|
if (auto type = ty.dyn_cast<LenType>()) {
|
||||||
os << "len";
|
os << "len";
|
||||||
break;
|
return;
|
||||||
case fir::FIR_LOGICAL: // intrinsic
|
}
|
||||||
os << "logical<" << ty.cast<LogicalType>().getFKind() << '>';
|
if (auto type = ty.dyn_cast<LogicalType>()) {
|
||||||
break;
|
os << "logical<" << type.getFKind() << '>';
|
||||||
case fir::FIR_POINTER:
|
return;
|
||||||
|
}
|
||||||
|
if (auto type = ty.dyn_cast<PointerType>()) {
|
||||||
os << "ptr<";
|
os << "ptr<";
|
||||||
p.printType(ty.cast<PointerType>().getEleTy());
|
p.printType(type.getEleTy());
|
||||||
os << '>';
|
os << '>';
|
||||||
break;
|
return;
|
||||||
case fir::FIR_REAL: // intrinsic
|
}
|
||||||
os << "real<" << ty.cast<fir::RealType>().getFKind() << '>';
|
if (auto type = ty.dyn_cast<fir::RealType>()) {
|
||||||
break;
|
os << "real<" << type.getFKind() << '>';
|
||||||
case fir::FIR_REFERENCE:
|
return;
|
||||||
|
}
|
||||||
|
if (auto type = ty.dyn_cast<ReferenceType>()) {
|
||||||
os << "ref<";
|
os << "ref<";
|
||||||
p.printType(ty.cast<ReferenceType>().getEleTy());
|
p.printType(type.getEleTy());
|
||||||
os << '>';
|
os << '>';
|
||||||
break;
|
return;
|
||||||
case fir::FIR_SEQUENCE: {
|
}
|
||||||
|
if (auto type = ty.dyn_cast<SequenceType>()) {
|
||||||
os << "array";
|
os << "array";
|
||||||
auto type = ty.cast<SequenceType>();
|
|
||||||
auto shape = type.getShape();
|
auto shape = type.getShape();
|
||||||
if (shape.size()) {
|
if (shape.size()) {
|
||||||
printBounds(os, shape);
|
printBounds(os, shape);
|
||||||
|
@ -1323,11 +1332,12 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
|
||||||
map.print(os);
|
map.print(os);
|
||||||
}
|
}
|
||||||
os << '>';
|
os << '>';
|
||||||
} break;
|
return;
|
||||||
case fir::FIR_TYPEDESC:
|
}
|
||||||
|
if (auto type = ty.dyn_cast<TypeDescType>()) {
|
||||||
os << "tdesc<";
|
os << "tdesc<";
|
||||||
p.printType(ty.cast<TypeDescType>().getOfTy());
|
p.printType(type.getOfTy());
|
||||||
os << '>';
|
os << '>';
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -190,11 +190,10 @@ public:
|
||||||
assert(!elementTypes.empty() && "expected at least 1 element type");
|
assert(!elementTypes.empty() && "expected at least 1 element type");
|
||||||
|
|
||||||
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
|
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
|
||||||
// of this type. The first two parameters are the context to unique in and
|
// of this type. The first parameter is the context to unique in. The
|
||||||
// the kind of the type. The parameters after the type kind are forwarded to
|
// parameters after the type kind are forwarded to the storage instance.
|
||||||
// the storage instance.
|
|
||||||
mlir::MLIRContext *ctx = elementTypes.front().getContext();
|
mlir::MLIRContext *ctx = elementTypes.front().getContext();
|
||||||
return Base::get(ctx, ToyTypes::Struct, elementTypes);
|
return Base::get(ctx, elementTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the element types of this struct type.
|
/// Returns the element types of this struct type.
|
||||||
|
|
|
@ -63,13 +63,6 @@ public:
|
||||||
// Toy Types
|
// Toy Types
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Create a local enumeration with all of the types that are defined by Toy.
|
|
||||||
namespace ToyTypes {
|
|
||||||
enum Types {
|
|
||||||
Struct = mlir::Type::FIRST_TOY_TYPE,
|
|
||||||
};
|
|
||||||
} // end namespace ToyTypes
|
|
||||||
|
|
||||||
/// This class defines the Toy struct type. It represents a collection of
|
/// This class defines the Toy struct type. It represents a collection of
|
||||||
/// element types. All derived types in MLIR must inherit from the CRTP class
|
/// element types. All derived types in MLIR must inherit from the CRTP class
|
||||||
/// 'Type::TypeBase'. It takes as template parameters the concrete type
|
/// 'Type::TypeBase'. It takes as template parameters the concrete type
|
||||||
|
|
|
@ -474,11 +474,10 @@ StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
|
||||||
assert(!elementTypes.empty() && "expected at least 1 element type");
|
assert(!elementTypes.empty() && "expected at least 1 element type");
|
||||||
|
|
||||||
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
|
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
|
||||||
// of this type. The first two parameters are the context to unique in and the
|
// of this type. The first parameter is the context to unique in. The
|
||||||
// kind of the type. The parameters after the type kind are forwarded to the
|
// parameters after the type kind are forwarded to the storage instance.
|
||||||
// storage instance.
|
|
||||||
mlir::MLIRContext *ctx = elementTypes.front().getContext();
|
mlir::MLIRContext *ctx = elementTypes.front().getContext();
|
||||||
return Base::get(ctx, ToyTypes::Struct, elementTypes);
|
return Base::get(ctx, elementTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the element types of this struct type.
|
/// Returns the element types of this struct type.
|
||||||
|
|
|
@ -64,34 +64,6 @@ class LLVMIntegerType;
|
||||||
/// structs, the entire type is the identifier) and are thread-safe.
|
/// structs, the entire type is the identifier) and are thread-safe.
|
||||||
class LLVMType : public Type {
|
class LLVMType : public Type {
|
||||||
public:
|
public:
|
||||||
enum Kind {
|
|
||||||
// Keep non-parametric types contiguous in the enum.
|
|
||||||
VoidType = FIRST_LLVM_TYPE + 1,
|
|
||||||
HalfType,
|
|
||||||
BFloatType,
|
|
||||||
FloatType,
|
|
||||||
DoubleType,
|
|
||||||
FP128Type,
|
|
||||||
X86FP80Type,
|
|
||||||
PPCFP128Type,
|
|
||||||
X86MMXType,
|
|
||||||
LabelType,
|
|
||||||
TokenType,
|
|
||||||
MetadataType,
|
|
||||||
// End of non-parametric types.
|
|
||||||
FunctionType,
|
|
||||||
IntegerType,
|
|
||||||
PointerType,
|
|
||||||
FixedVectorType,
|
|
||||||
ScalableVectorType,
|
|
||||||
ArrayType,
|
|
||||||
StructType,
|
|
||||||
FIRST_NEW_LLVM_TYPE = VoidType,
|
|
||||||
LAST_NEW_LLVM_TYPE = StructType,
|
|
||||||
FIRST_TRIVIAL_TYPE = VoidType,
|
|
||||||
LAST_TRIVIAL_TYPE = MetadataType
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Inherit base constructors.
|
/// Inherit base constructors.
|
||||||
using Type::Type;
|
using Type::Type;
|
||||||
|
|
||||||
|
@ -256,27 +228,24 @@ public:
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Batch-define trivial types.
|
// Batch-define trivial types.
|
||||||
#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, Kind) \
|
#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \
|
||||||
class ClassName : public Type::TypeBase<ClassName, LLVMType, TypeStorage> { \
|
class ClassName : public Type::TypeBase<ClassName, LLVMType, TypeStorage> { \
|
||||||
public: \
|
public: \
|
||||||
using Base::Base; \
|
using Base::Base; \
|
||||||
static ClassName get(MLIRContext *context) { \
|
|
||||||
return Base::get(context, Kind); \
|
|
||||||
} \
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, LLVMType::VoidType);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType, LLVMType::HalfType);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType, LLVMType::BFloatType);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType, LLVMType::FloatType);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType, LLVMType::DoubleType);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type, LLVMType::FP128Type);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type, LLVMType::X86FP80Type);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, LLVMType::PPCFP128Type);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, LLVMType::X86MMXType);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, LLVMType::TokenType);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, LLVMType::LabelType);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
|
||||||
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, LLVMType::MetadataType);
|
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
|
||||||
|
|
||||||
#undef DEFINE_TRIVIAL_LLVM_TYPE
|
#undef DEFINE_TRIVIAL_LLVM_TYPE
|
||||||
|
|
||||||
|
|
|
@ -16,11 +16,6 @@ namespace mlir {
|
||||||
class MLIRContext;
|
class MLIRContext;
|
||||||
|
|
||||||
namespace linalg {
|
namespace linalg {
|
||||||
enum LinalgTypes {
|
|
||||||
Range = Type::FIRST_LINALG_TYPE,
|
|
||||||
LAST_USED_LINALG_TYPE = Range,
|
|
||||||
};
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
|
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
|
||||||
|
|
||||||
/// A RangeType represents a minimal range abstraction (min, max, step).
|
/// A RangeType represents a minimal range abstraction (min, max, step).
|
||||||
|
@ -36,11 +31,6 @@ class RangeType : public Type::TypeBase<RangeType, Type, TypeStorage> {
|
||||||
public:
|
public:
|
||||||
// Used for generic hooks in TypeBase.
|
// Used for generic hooks in TypeBase.
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
/// Construction hook.
|
|
||||||
static RangeType get(MLIRContext *context) {
|
|
||||||
/// Custom, uniq'ed construction in the MLIRContext.
|
|
||||||
return Base::get(context, LinalgTypes::Range);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace linalg
|
} // namespace linalg
|
||||||
|
|
|
@ -31,15 +31,6 @@ struct UniformQuantizedPerAxisTypeStorage;
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
namespace QuantizationTypes {
|
|
||||||
enum Kind {
|
|
||||||
Any = Type::FIRST_QUANTIZATION_TYPE,
|
|
||||||
UniformQuantized,
|
|
||||||
UniformQuantizedPerAxis,
|
|
||||||
LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
|
|
||||||
};
|
|
||||||
} // namespace QuantizationTypes
|
|
||||||
|
|
||||||
/// Enumeration of bit-mapped flags related to quantized types.
|
/// Enumeration of bit-mapped flags related to quantized types.
|
||||||
namespace QuantizationFlags {
|
namespace QuantizationFlags {
|
||||||
enum FlagValue {
|
enum FlagValue {
|
||||||
|
|
|
@ -32,15 +32,6 @@ struct TargetEnvAttributeStorage;
|
||||||
struct VerCapExtAttributeStorage;
|
struct VerCapExtAttributeStorage;
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
/// SPIR-V dialect-specific attribute kinds.
|
|
||||||
namespace AttrKind {
|
|
||||||
enum Kind {
|
|
||||||
InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI
|
|
||||||
TargetEnv, /// Target environment
|
|
||||||
VerCapExt, /// (version, extension, capability) triple
|
|
||||||
};
|
|
||||||
} // namespace AttrKind
|
|
||||||
|
|
||||||
/// An attribute that specifies the information regarding the interface
|
/// An attribute that specifies the information regarding the interface
|
||||||
/// variable: descriptor set, binding, storage class.
|
/// variable: descriptor set, binding, storage class.
|
||||||
class InterfaceVarABIAttr
|
class InterfaceVarABIAttr
|
||||||
|
|
|
@ -65,19 +65,6 @@ struct StructTypeStorage;
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
namespace TypeKind {
|
|
||||||
enum Kind {
|
|
||||||
Array = Type::FIRST_SPIRV_TYPE,
|
|
||||||
CooperativeMatrix,
|
|
||||||
Image,
|
|
||||||
Matrix,
|
|
||||||
Pointer,
|
|
||||||
RuntimeArray,
|
|
||||||
Struct,
|
|
||||||
LAST_SPIRV_TYPE = Struct,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Base SPIR-V type for providing availability queries.
|
// Base SPIR-V type for providing availability queries.
|
||||||
class SPIRVType : public Type {
|
class SPIRVType : public Type {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -29,56 +29,28 @@ namespace shape {
|
||||||
/// Alias type for extent tensors.
|
/// Alias type for extent tensors.
|
||||||
RankedTensorType getExtentTensorType(MLIRContext *ctx);
|
RankedTensorType getExtentTensorType(MLIRContext *ctx);
|
||||||
|
|
||||||
namespace ShapeTypes {
|
|
||||||
enum Kind {
|
|
||||||
Component = Type::FIRST_SHAPE_TYPE,
|
|
||||||
Element,
|
|
||||||
Shape,
|
|
||||||
Size,
|
|
||||||
ValueShape,
|
|
||||||
Witness,
|
|
||||||
LAST_SHAPE_TYPE = Witness
|
|
||||||
};
|
|
||||||
} // namespace ShapeTypes
|
|
||||||
|
|
||||||
/// The component type corresponding to shape, element type and attribute.
|
/// The component type corresponding to shape, element type and attribute.
|
||||||
class ComponentType : public Type::TypeBase<ComponentType, Type, TypeStorage> {
|
class ComponentType : public Type::TypeBase<ComponentType, Type, TypeStorage> {
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
static ComponentType get(MLIRContext *context) {
|
|
||||||
return Base::get(context, ShapeTypes::Kind::Component);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The element type of the shaped type.
|
/// The element type of the shaped type.
|
||||||
class ElementType : public Type::TypeBase<ElementType, Type, TypeStorage> {
|
class ElementType : public Type::TypeBase<ElementType, Type, TypeStorage> {
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
static ElementType get(MLIRContext *context) {
|
|
||||||
return Base::get(context, ShapeTypes::Kind::Element);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The shape descriptor type represents rank and dimension sizes.
|
/// The shape descriptor type represents rank and dimension sizes.
|
||||||
class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
|
class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
static ShapeType get(MLIRContext *context) {
|
|
||||||
return Base::get(context, ShapeTypes::Kind::Shape);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The type of a single dimension.
|
/// The type of a single dimension.
|
||||||
class SizeType : public Type::TypeBase<SizeType, Type, TypeStorage> {
|
class SizeType : public Type::TypeBase<SizeType, Type, TypeStorage> {
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
static SizeType get(MLIRContext *context) {
|
|
||||||
return Base::get(context, ShapeTypes::Kind::Size);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The ValueShape represents a (potentially unknown) runtime value and shape.
|
/// The ValueShape represents a (potentially unknown) runtime value and shape.
|
||||||
|
@ -86,10 +58,6 @@ class ValueShapeType
|
||||||
: public Type::TypeBase<ValueShapeType, Type, TypeStorage> {
|
: public Type::TypeBase<ValueShapeType, Type, TypeStorage> {
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
static ValueShapeType get(MLIRContext *context) {
|
|
||||||
return Base::get(context, ShapeTypes::Kind::ValueShape);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The Witness represents a runtime constraint, to be used as shape related
|
/// The Witness represents a runtime constraint, to be used as shape related
|
||||||
|
@ -97,10 +65,6 @@ public:
|
||||||
class WitnessType : public Type::TypeBase<WitnessType, Type, TypeStorage> {
|
class WitnessType : public Type::TypeBase<WitnessType, Type, TypeStorage> {
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
static WitnessType get(MLIRContext *context) {
|
|
||||||
return Base::get(context, ShapeTypes::Kind::Witness);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
|
|
|
@ -137,15 +137,23 @@ namespace detail {
|
||||||
// MLIRContext. This class manages all creation and uniquing of attributes.
|
// MLIRContext. This class manages all creation and uniquing of attributes.
|
||||||
class AttributeUniquer {
|
class AttributeUniquer {
|
||||||
public:
|
public:
|
||||||
/// Get an uniqued instance of attribute T.
|
/// Get an uniqued instance of a parametric attribute T.
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
|
static typename std::enable_if_t<
|
||||||
|
!std::is_same<typename T::ImplType, AttributeStorage>::value, T>
|
||||||
|
get(MLIRContext *ctx, Args &&...args) {
|
||||||
return ctx->getAttributeUniquer().get<typename T::ImplType>(
|
return ctx->getAttributeUniquer().get<typename T::ImplType>(
|
||||||
T::getTypeID(),
|
|
||||||
[ctx](AttributeStorage *storage) {
|
[ctx](AttributeStorage *storage) {
|
||||||
initializeAttributeStorage(storage, ctx, T::getTypeID());
|
initializeAttributeStorage(storage, ctx, T::getTypeID());
|
||||||
},
|
},
|
||||||
kind, std::forward<Args>(args)...);
|
T::getTypeID(), std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
/// Get an uniqued instance of a singleton attribute T.
|
||||||
|
template <typename T>
|
||||||
|
static typename std::enable_if_t<
|
||||||
|
std::is_same<typename T::ImplType, AttributeStorage>::value, T>
|
||||||
|
get(MLIRContext *ctx) {
|
||||||
|
return ctx->getAttributeUniquer().get<typename T::ImplType>(T::getTypeID());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
|
@ -156,6 +164,26 @@ public:
|
||||||
std::forward<Args>(args)...);
|
std::forward<Args>(args)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Register a parametric attribute instance T with the uniquer.
|
||||||
|
template <typename T>
|
||||||
|
static typename std::enable_if_t<
|
||||||
|
!std::is_same<typename T::ImplType, AttributeStorage>::value>
|
||||||
|
registerAttribute(MLIRContext *ctx) {
|
||||||
|
ctx->getAttributeUniquer()
|
||||||
|
.registerParametricStorageType<typename T::ImplType>(T::getTypeID());
|
||||||
|
}
|
||||||
|
/// Register a singleton attribute instance T with the uniquer.
|
||||||
|
template <typename T>
|
||||||
|
static typename std::enable_if_t<
|
||||||
|
std::is_same<typename T::ImplType, AttributeStorage>::value>
|
||||||
|
registerAttribute(MLIRContext *ctx) {
|
||||||
|
ctx->getAttributeUniquer()
|
||||||
|
.registerSingletonStorageType<typename T::ImplType>(
|
||||||
|
T::getTypeID(), [ctx](AttributeStorage *storage) {
|
||||||
|
initializeAttributeStorage(storage, ctx, T::getTypeID());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Initialize the given attribute storage instance.
|
/// Initialize the given attribute storage instance.
|
||||||
static void initializeAttributeStorage(AttributeStorage *storage,
|
static void initializeAttributeStorage(AttributeStorage *storage,
|
||||||
|
|
|
@ -54,14 +54,6 @@ struct SparseElementsAttributeStorage;
|
||||||
/// passed by value.
|
/// passed by value.
|
||||||
class Attribute {
|
class Attribute {
|
||||||
public:
|
public:
|
||||||
/// Integer identifier for all the concrete attribute kinds.
|
|
||||||
enum Kind {
|
|
||||||
// Reserve attribute kinds for dialect specific extensions.
|
|
||||||
#define DEFINE_SYM_KIND_RANGE(Dialect) \
|
|
||||||
FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff,
|
|
||||||
#include "DialectSymbolRegistry.def"
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Utility class for implementing attributes.
|
/// Utility class for implementing attributes.
|
||||||
template <typename ConcreteType, typename BaseType, typename StorageType,
|
template <typename ConcreteType, typename BaseType, typename StorageType,
|
||||||
template <typename T> class... Traits>
|
template <typename T> class... Traits>
|
||||||
|
@ -94,9 +86,6 @@ public:
|
||||||
// Support dyn_cast'ing Attribute to itself.
|
// Support dyn_cast'ing Attribute to itself.
|
||||||
static bool classof(Attribute) { return true; }
|
static bool classof(Attribute) { return true; }
|
||||||
|
|
||||||
/// Return the classification for this attribute.
|
|
||||||
unsigned getKind() const { return impl->getKind(); }
|
|
||||||
|
|
||||||
/// Return a unique identifier for the concrete attribute type. This is used
|
/// Return a unique identifier for the concrete attribute type. This is used
|
||||||
/// to support dynamic type casting.
|
/// to support dynamic type casting.
|
||||||
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
|
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
|
||||||
|
@ -173,54 +162,6 @@ private:
|
||||||
friend InterfaceBase;
|
friend InterfaceBase;
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// StandardAttributes
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace StandardAttributes {
|
|
||||||
enum Kind {
|
|
||||||
AffineMap = Attribute::FIRST_STANDARD_ATTR,
|
|
||||||
Array,
|
|
||||||
Dictionary,
|
|
||||||
Float,
|
|
||||||
Integer,
|
|
||||||
IntegerSet,
|
|
||||||
Opaque,
|
|
||||||
String,
|
|
||||||
SymbolRef,
|
|
||||||
Type,
|
|
||||||
Unit,
|
|
||||||
|
|
||||||
/// Elements Attributes.
|
|
||||||
DenseIntOrFPElements,
|
|
||||||
DenseStringElements,
|
|
||||||
OpaqueElements,
|
|
||||||
SparseElements,
|
|
||||||
FIRST_ELEMENTS_ATTR = DenseIntOrFPElements,
|
|
||||||
LAST_ELEMENTS_ATTR = SparseElements,
|
|
||||||
|
|
||||||
/// Locations.
|
|
||||||
CallSiteLocation,
|
|
||||||
FileLineColLocation,
|
|
||||||
FusedLocation,
|
|
||||||
NameLocation,
|
|
||||||
OpaqueLocation,
|
|
||||||
UnknownLocation,
|
|
||||||
|
|
||||||
// Represents a location as a 'void*' pointer to a front-end's opaque
|
|
||||||
// location information, which must live longer than the MLIR objects that
|
|
||||||
// refer to it. OpaqueLocation's are never serialized.
|
|
||||||
//
|
|
||||||
// TODO: OpaqueLocation,
|
|
||||||
|
|
||||||
// Represents a value inlined through a function call.
|
|
||||||
// TODO: InlinedLocation,
|
|
||||||
|
|
||||||
FIRST_LOCATION_ATTR = CallSiteLocation,
|
|
||||||
LAST_LOCATION_ATTR = UnknownLocation,
|
|
||||||
};
|
|
||||||
} // namespace StandardAttributes
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AffineMapAttr
|
// AffineMapAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -154,21 +154,15 @@ protected:
|
||||||
|
|
||||||
void addOperation(AbstractOperation opInfo);
|
void addOperation(AbstractOperation opInfo);
|
||||||
|
|
||||||
/// This method is used by derived classes to add their types to the set.
|
/// Register a set of type classes with this dialect.
|
||||||
template <typename... Args> void addTypes() {
|
template <typename... Args> void addTypes() {
|
||||||
(void)std::initializer_list<int>{
|
(void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
|
||||||
0, (addType(Args::getTypeID(), AbstractType::get<Args>(*this)), 0)...};
|
|
||||||
}
|
}
|
||||||
void addType(TypeID typeID, AbstractType &&typeInfo);
|
|
||||||
|
|
||||||
/// This method is used by derived classes to add their attributes to the set.
|
/// Register a set of attribute classes with this dialect.
|
||||||
template <typename... Args> void addAttributes() {
|
template <typename... Args> void addAttributes() {
|
||||||
(void)std::initializer_list<int>{
|
(void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
|
||||||
0,
|
|
||||||
(addAttribute(Args::getTypeID(), AbstractAttribute::get<Args>(*this)),
|
|
||||||
0)...};
|
|
||||||
}
|
}
|
||||||
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
|
|
||||||
|
|
||||||
/// Enable support for unregistered operations.
|
/// Enable support for unregistered operations.
|
||||||
void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
|
void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
|
||||||
|
@ -189,6 +183,22 @@ private:
|
||||||
Dialect(const Dialect &) = delete;
|
Dialect(const Dialect &) = delete;
|
||||||
void operator=(Dialect &) = delete;
|
void operator=(Dialect &) = delete;
|
||||||
|
|
||||||
|
/// Register an attribute instance with this dialect.
|
||||||
|
template <typename T> void addAttribute() {
|
||||||
|
// Add this attribute to the dialect and register it with the uniquer.
|
||||||
|
addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
|
||||||
|
detail::AttributeUniquer::registerAttribute<T>(context);
|
||||||
|
}
|
||||||
|
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
|
||||||
|
|
||||||
|
/// Register a type instance with this dialect.
|
||||||
|
template <typename T> void addType() {
|
||||||
|
// Add this type to the dialect and register it with the uniquer.
|
||||||
|
addType(T::getTypeID(), AbstractType::get<T>(*this));
|
||||||
|
detail::TypeUniquer::registerType<T>(context);
|
||||||
|
}
|
||||||
|
void addType(TypeID typeID, AbstractType &&typeInfo);
|
||||||
|
|
||||||
/// The namespace of this dialect.
|
/// The namespace of this dialect.
|
||||||
StringRef name;
|
StringRef name;
|
||||||
|
|
||||||
|
|
|
@ -1,44 +0,0 @@
|
||||||
//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===//
|
|
||||||
//
|
|
||||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// This file enumerates the different dialects that define custom classes
|
|
||||||
// within the attribute or type system.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
DEFINE_SYM_KIND_RANGE(STANDARD)
|
|
||||||
DEFINE_SYM_KIND_RANGE(TENSORFLOW_CONTROL)
|
|
||||||
DEFINE_SYM_KIND_RANGE(TENSORFLOW_EXECUTOR)
|
|
||||||
DEFINE_SYM_KIND_RANGE(TENSORFLOW)
|
|
||||||
DEFINE_SYM_KIND_RANGE(LLVM)
|
|
||||||
DEFINE_SYM_KIND_RANGE(QUANTIZATION)
|
|
||||||
DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
|
|
||||||
DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
|
|
||||||
DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect
|
|
||||||
DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect
|
|
||||||
DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect
|
|
||||||
DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
|
|
||||||
DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect
|
|
||||||
DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect
|
|
||||||
DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect
|
|
||||||
DEFINE_SYM_KIND_RANGE(TF_FRAMEWORK) // TF Framework dialect
|
|
||||||
|
|
||||||
// The following ranges are reserved for experimenting with MLIR dialects in a
|
|
||||||
// private context without having to register them here.
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_1)
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_2)
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_3)
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_4)
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_5)
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_6)
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_7)
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_8)
|
|
||||||
DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_9)
|
|
||||||
|
|
||||||
#undef DEFINE_SYM_KIND_RANGE
|
|
|
@ -756,7 +756,7 @@ public:
|
||||||
/// all attributes of the given kind in the form : <alias>[0-9]+. These
|
/// all attributes of the given kind in the form : <alias>[0-9]+. These
|
||||||
/// aliases must not contain `.`.
|
/// aliases must not contain `.`.
|
||||||
virtual void getAttributeKindAliases(
|
virtual void getAttributeKindAliases(
|
||||||
SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) const {}
|
SmallVectorImpl<std::pair<TypeID, StringRef>> &aliases) const {}
|
||||||
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
|
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
|
||||||
/// end with a numeric digit([0-9]+).
|
/// end with a numeric digit([0-9]+).
|
||||||
virtual void getAttributeAliases(
|
virtual void getAttributeAliases(
|
||||||
|
|
|
@ -38,33 +38,6 @@ struct TupleTypeStorage;
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
namespace StandardTypes {
|
|
||||||
enum Kind {
|
|
||||||
// Floating point.
|
|
||||||
BF16 = Type::Kind::FIRST_STANDARD_TYPE,
|
|
||||||
F16,
|
|
||||||
F32,
|
|
||||||
F64,
|
|
||||||
FIRST_FLOATING_POINT_TYPE = BF16,
|
|
||||||
LAST_FLOATING_POINT_TYPE = F64,
|
|
||||||
|
|
||||||
// Target pointer sized integer, used (e.g.) in affine mappings.
|
|
||||||
Index,
|
|
||||||
|
|
||||||
// Derived types.
|
|
||||||
Integer,
|
|
||||||
Vector,
|
|
||||||
RankedTensor,
|
|
||||||
UnrankedTensor,
|
|
||||||
MemRef,
|
|
||||||
UnrankedMemRef,
|
|
||||||
Complex,
|
|
||||||
Tuple,
|
|
||||||
None,
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace StandardTypes
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ComplexType
|
// ComplexType
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -82,29 +82,29 @@ public:
|
||||||
return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
|
return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
|
||||||
/// Get or create a new ConcreteT instance within the ctx. This
|
/// Get or create a new ConcreteT instance within the ctx. This
|
||||||
/// function is guaranteed to return a non null object and will assert if
|
/// function is guaranteed to return a non null object and will assert if
|
||||||
/// the arguments provided are invalid.
|
/// the arguments provided are invalid.
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) {
|
static ConcreteT get(MLIRContext *ctx, Args... args) {
|
||||||
// Ensure that the invariants are correct for construction.
|
// Ensure that the invariants are correct for construction.
|
||||||
assert(succeeded(ConcreteT::verifyConstructionInvariants(
|
assert(succeeded(ConcreteT::verifyConstructionInvariants(
|
||||||
generateUnknownStorageLocation(ctx), args...)));
|
generateUnknownStorageLocation(ctx), args...)));
|
||||||
return UniquerT::template get<ConcreteT>(ctx, kind, args...);
|
return UniquerT::template get<ConcreteT>(ctx, args...);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or create a new ConcreteT instance within the ctx, defined at
|
/// Get or create a new ConcreteT instance within the ctx, defined at
|
||||||
/// the given, potentially unknown, location. If the arguments provided are
|
/// the given, potentially unknown, location. If the arguments provided are
|
||||||
/// invalid then emit errors and return a null object.
|
/// invalid then emit errors and return a null object.
|
||||||
template <typename LocationT, typename... Args>
|
template <typename LocationT, typename... Args>
|
||||||
static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) {
|
static ConcreteT getChecked(LocationT loc, Args... args) {
|
||||||
// If the construction invariants fail then we return a null attribute.
|
// If the construction invariants fail then we return a null attribute.
|
||||||
if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
|
if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
|
||||||
return ConcreteT();
|
return ConcreteT();
|
||||||
return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
|
return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
/// Mutate the current storage instance. This will not change the unique key.
|
/// Mutate the current storage instance. This will not change the unique key.
|
||||||
/// The arguments are forwarded to 'ConcreteT::mutate'.
|
/// The arguments are forwarded to 'ConcreteT::mutate'.
|
||||||
template <typename... Args> LogicalResult mutate(Args &&...args) {
|
template <typename... Args> LogicalResult mutate(Args &&...args) {
|
||||||
|
|
|
@ -121,15 +121,23 @@ namespace detail {
|
||||||
/// A utility class to get, or create, unique instances of types within an
|
/// A utility class to get, or create, unique instances of types within an
|
||||||
/// MLIRContext. This class manages all creation and uniquing of types.
|
/// MLIRContext. This class manages all creation and uniquing of types.
|
||||||
struct TypeUniquer {
|
struct TypeUniquer {
|
||||||
/// Get an uniqued instance of a type T.
|
/// Get an uniqued instance of a parametric type T.
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
|
static typename std::enable_if_t<
|
||||||
|
!std::is_same<typename T::ImplType, TypeStorage>::value, T>
|
||||||
|
get(MLIRContext *ctx, Args &&...args) {
|
||||||
return ctx->getTypeUniquer().get<typename T::ImplType>(
|
return ctx->getTypeUniquer().get<typename T::ImplType>(
|
||||||
T::getTypeID(),
|
|
||||||
[&](TypeStorage *storage) {
|
[&](TypeStorage *storage) {
|
||||||
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
|
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
|
||||||
},
|
},
|
||||||
kind, std::forward<Args>(args)...);
|
T::getTypeID(), std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
/// Get an uniqued instance of a singleton type T.
|
||||||
|
template <typename T>
|
||||||
|
static typename std::enable_if_t<
|
||||||
|
std::is_same<typename T::ImplType, TypeStorage>::value, T>
|
||||||
|
get(MLIRContext *ctx) {
|
||||||
|
return ctx->getTypeUniquer().get<typename T::ImplType>(T::getTypeID());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Change the mutable component of the given type instance in the provided
|
/// Change the mutable component of the given type instance in the provided
|
||||||
|
@ -141,6 +149,25 @@ struct TypeUniquer {
|
||||||
return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
|
return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
|
||||||
std::forward<Args>(args)...);
|
std::forward<Args>(args)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Register a parametric type instance T with the uniquer.
|
||||||
|
template <typename T>
|
||||||
|
static typename std::enable_if_t<
|
||||||
|
!std::is_same<typename T::ImplType, TypeStorage>::value>
|
||||||
|
registerType(MLIRContext *ctx) {
|
||||||
|
ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
|
||||||
|
T::getTypeID());
|
||||||
|
}
|
||||||
|
/// Register a singleton type instance T with the uniquer.
|
||||||
|
template <typename T>
|
||||||
|
static typename std::enable_if_t<
|
||||||
|
std::is_same<typename T::ImplType, TypeStorage>::value>
|
||||||
|
registerType(MLIRContext *ctx) {
|
||||||
|
ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
|
||||||
|
T::getTypeID(), [&](TypeStorage *storage) {
|
||||||
|
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
|
||||||
|
});
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
|
|
|
@ -34,11 +34,11 @@ struct OpaqueTypeStorage;
|
||||||
///
|
///
|
||||||
/// Some types are "primitives" meaning they do not have any parameters, for
|
/// Some types are "primitives" meaning they do not have any parameters, for
|
||||||
/// example the Index type. Parametric types have additional information that
|
/// example the Index type. Parametric types have additional information that
|
||||||
/// differentiates the types of the same kind between them, for example the
|
/// differentiates the types of the same class, for example the Integer type has
|
||||||
/// Integer type has bitwidth, making i8 and i16 belong to the same kind by be
|
/// bitwidth, making i8 and i16 belong to the same kind by be different
|
||||||
/// different instances of the IntegerType. Type parameters are part of the
|
/// instances of the IntegerType. Type parameters are part of the unique
|
||||||
/// unique immutable key. The mutable component of the type can be modified
|
/// immutable key. The mutable component of the type can be modified after the
|
||||||
/// after the type is created, but cannot affect the identity of the type.
|
/// type is created, but cannot affect the identity of the type.
|
||||||
///
|
///
|
||||||
/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
|
/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
|
||||||
///
|
///
|
||||||
|
@ -53,20 +53,19 @@ struct OpaqueTypeStorage;
|
||||||
/// * This method is expected to return failure if a type cannot be
|
/// * This method is expected to return failure if a type cannot be
|
||||||
/// constructed with 'args', success otherwise.
|
/// constructed with 'args', success otherwise.
|
||||||
/// * 'args' must correspond with the arguments passed into the
|
/// * 'args' must correspond with the arguments passed into the
|
||||||
/// 'TypeBase::get' call after the type kind.
|
/// 'TypeBase::get' call.
|
||||||
///
|
///
|
||||||
///
|
///
|
||||||
/// Type storage objects inherit from TypeStorage and contain the following:
|
/// Type storage objects inherit from TypeStorage and contain the following:
|
||||||
/// - The type kind (for LLVM-style RTTI).
|
|
||||||
/// - The dialect that defined the type.
|
/// - The dialect that defined the type.
|
||||||
/// - Any parameters of the type.
|
/// - Any parameters of the type.
|
||||||
/// - An optional mutable component.
|
/// - An optional mutable component.
|
||||||
/// For non-parametric types, a convenience DefaultTypeStorage is provided.
|
/// For non-parametric types, a convenience DefaultTypeStorage is provided.
|
||||||
/// Parametric storage types must derive TypeStorage and respect the following:
|
/// Parametric storage types must derive TypeStorage and respect the following:
|
||||||
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
|
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
|
||||||
/// instance of the type within its kind.
|
/// instance of the type.
|
||||||
/// * The key type must be constructible from the values passed into the
|
/// * The key type must be constructible from the values passed into the
|
||||||
/// detail::TypeUniquer::get call after the type kind.
|
/// detail::TypeUniquer::get call.
|
||||||
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
|
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
|
||||||
/// storage class must define a hashing method:
|
/// storage class must define a hashing method:
|
||||||
/// 'static unsigned hashKey(const KeyTy &)'
|
/// 'static unsigned hashKey(const KeyTy &)'
|
||||||
|
@ -84,23 +83,6 @@ struct OpaqueTypeStorage;
|
||||||
// the key.
|
// the key.
|
||||||
class Type {
|
class Type {
|
||||||
public:
|
public:
|
||||||
/// Integer identifier for all the concrete type kinds.
|
|
||||||
/// Note: This is not an enum class as each dialect will likely define a
|
|
||||||
/// separate enumeration for the specific types that they define. Not being an
|
|
||||||
/// enum class also simplifies the handling of type kinds by not requiring
|
|
||||||
/// casts for each use.
|
|
||||||
enum Kind {
|
|
||||||
// Builtin types.
|
|
||||||
Function,
|
|
||||||
Opaque,
|
|
||||||
LAST_BUILTIN_TYPE = Opaque,
|
|
||||||
|
|
||||||
// Reserve type kinds for dialect specific type system extensions.
|
|
||||||
#define DEFINE_SYM_KIND_RANGE(Dialect) \
|
|
||||||
FIRST_##Dialect##_TYPE, LAST_##Dialect##_TYPE = FIRST_##Dialect##_TYPE + 0xff,
|
|
||||||
#include "DialectSymbolRegistry.def"
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Utility class for implementing types.
|
/// Utility class for implementing types.
|
||||||
template <typename ConcreteType, typename BaseType, typename StorageType,
|
template <typename ConcreteType, typename BaseType, typename StorageType,
|
||||||
template <typename T> class... Traits>
|
template <typename T> class... Traits>
|
||||||
|
@ -136,9 +118,6 @@ public:
|
||||||
/// dynamic type casting.
|
/// dynamic type casting.
|
||||||
TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
|
TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
|
||||||
|
|
||||||
/// Return the classification for this type.
|
|
||||||
unsigned getKind() const;
|
|
||||||
|
|
||||||
/// Return the LLVMContext in which this type was uniqued.
|
/// Return the LLVMContext in which this type was uniqued.
|
||||||
MLIRContext *getContext() const;
|
MLIRContext *getContext() const;
|
||||||
|
|
||||||
|
|
|
@ -11,12 +11,11 @@
|
||||||
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "mlir/Support/TypeID.h"
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/Support/Allocator.h"
|
#include "llvm/Support/Allocator.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class TypeID;
|
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
struct StorageUniquerImpl;
|
struct StorageUniquerImpl;
|
||||||
|
|
||||||
|
@ -29,22 +28,19 @@ template <typename ImplTy, typename T>
|
||||||
using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
|
using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
/// A utility class to get, or create instances of storage classes. These
|
/// A utility class to get or create instances of "storage classes". These
|
||||||
/// storage classes must respect the following constraints:
|
/// storage classes must derive from 'StorageUniquer::BaseStorage'.
|
||||||
/// - Derive from StorageUniquer::BaseStorage.
|
|
||||||
/// - Provide an unsigned 'kind' value to be used as part of the unique'ing
|
|
||||||
/// process.
|
|
||||||
///
|
///
|
||||||
/// For non-parametric storage classes, i.e. those that are solely uniqued by
|
/// For non-parametric storage classes, i.e. singleton classes, nothing else is
|
||||||
/// their kind, nothing else is needed. Instances of these classes can be
|
/// needed. Instances of these classes can be created by calling `get` without
|
||||||
/// created by calling `get` without trailing arguments.
|
/// trailing arguments.
|
||||||
///
|
///
|
||||||
/// Otherwise, the parametric storage classes may be created with `get`,
|
/// Otherwise, the parametric storage classes may be created with `get`,
|
||||||
/// and must respect the following:
|
/// and must respect the following:
|
||||||
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
|
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
|
||||||
/// instance of the storage class within its kind.
|
/// instance of the storage class.
|
||||||
/// * The key type must be constructible from the values passed into the
|
/// * The key type must be constructible from the values passed into the
|
||||||
/// getComplex call after the kind.
|
/// getComplex call.
|
||||||
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
|
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
|
||||||
/// storage class must define a hashing method:
|
/// storage class must define a hashing method:
|
||||||
/// 'static unsigned hashKey(const KeyTy &)'
|
/// 'static unsigned hashKey(const KeyTy &)'
|
||||||
|
@ -83,32 +79,11 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
|
||||||
/// class.
|
/// class.
|
||||||
class StorageUniquer {
|
class StorageUniquer {
|
||||||
public:
|
public:
|
||||||
StorageUniquer();
|
|
||||||
~StorageUniquer();
|
|
||||||
|
|
||||||
/// Set the flag specifying if multi-threading is disabled within the uniquer.
|
|
||||||
void disableMultithreading(bool disable = true);
|
|
||||||
|
|
||||||
/// Register a new storage object with this uniquer using the given unique
|
|
||||||
/// type id.
|
|
||||||
void registerStorageType(TypeID id);
|
|
||||||
|
|
||||||
/// This class acts as the base storage that all storage classes must derived
|
/// This class acts as the base storage that all storage classes must derived
|
||||||
/// from.
|
/// from.
|
||||||
class BaseStorage {
|
class BaseStorage {
|
||||||
public:
|
|
||||||
/// Get the kind classification of this storage.
|
|
||||||
unsigned getKind() const { return kind; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
BaseStorage() : kind(0) {}
|
BaseStorage() = default;
|
||||||
|
|
||||||
private:
|
|
||||||
/// Allow access to the kind field.
|
|
||||||
friend detail::StorageUniquerImpl;
|
|
||||||
|
|
||||||
/// Classification of the subclass, used for type checking.
|
|
||||||
unsigned kind;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This is a utility allocator used to allocate memory for instances of
|
/// This is a utility allocator used to allocate memory for instances of
|
||||||
|
@ -145,19 +120,61 @@ public:
|
||||||
llvm::BumpPtrAllocator allocator;
|
llvm::BumpPtrAllocator allocator;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
|
StorageUniquer();
|
||||||
/// that can be used to initialize a newly inserted storage instance. This
|
~StorageUniquer();
|
||||||
/// function is used for derived types that have complex storage or uniquing
|
|
||||||
/// constraints.
|
|
||||||
template <typename Storage, typename Arg, typename... Args>
|
|
||||||
Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
|
|
||||||
unsigned kind, Arg &&arg, Args &&...args) {
|
|
||||||
// Construct a value of the derived key type.
|
|
||||||
auto derivedKey =
|
|
||||||
getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
|
|
||||||
|
|
||||||
// Create a hash of the kind and the derived key.
|
/// Set the flag specifying if multi-threading is disabled within the uniquer.
|
||||||
unsigned hashValue = getHash<Storage>(kind, derivedKey);
|
void disableMultithreading(bool disable = true);
|
||||||
|
|
||||||
|
/// Register a new parametric storage class, this is necessary to create
|
||||||
|
/// instances of this class type. `id` is the type identifier that will be
|
||||||
|
/// used to identify this type when creating instances of it via 'get'.
|
||||||
|
template <typename Storage> void registerParametricStorageType(TypeID id) {
|
||||||
|
registerParametricStorageTypeImpl(id);
|
||||||
|
}
|
||||||
|
/// Utility override when the storage type represents the type id.
|
||||||
|
template <typename Storage> void registerParametricStorageType() {
|
||||||
|
registerParametricStorageType<Storage>(TypeID::get<Storage>());
|
||||||
|
}
|
||||||
|
/// Register a new singleton storage class, this is necessary to get the
|
||||||
|
/// singletone instance. `id` is the type identifier that will be used to
|
||||||
|
/// access the singleton instance via 'get'. An optional initialization
|
||||||
|
/// function may also be provided to initialize the newly created storage
|
||||||
|
/// instance, and used when the singleton instance is created.
|
||||||
|
template <typename Storage>
|
||||||
|
void registerSingletonStorageType(TypeID id,
|
||||||
|
function_ref<void(Storage *)> initFn) {
|
||||||
|
auto ctorFn = [&](StorageAllocator &allocator) {
|
||||||
|
auto *storage = new (allocator.allocate<Storage>()) Storage();
|
||||||
|
if (initFn)
|
||||||
|
initFn(storage);
|
||||||
|
return storage;
|
||||||
|
};
|
||||||
|
registerSingletonImpl(id, ctorFn);
|
||||||
|
}
|
||||||
|
template <typename Storage> void registerSingletonStorageType(TypeID id) {
|
||||||
|
registerSingletonStorageType<Storage>(id, llvm::None);
|
||||||
|
}
|
||||||
|
/// Utility override when the storage type represents the type id.
|
||||||
|
template <typename Storage>
|
||||||
|
void registerSingletonStorageType(
|
||||||
|
function_ref<void(Storage *)> initFn = llvm::None) {
|
||||||
|
registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets a uniqued instance of 'Storage'. 'id' is the type id used when
|
||||||
|
/// registering the storage instance. 'initFn' is an optional parameter that
|
||||||
|
/// can be used to initialize a newly inserted storage instance. This function
|
||||||
|
/// is used for derived types that have complex storage or uniquing
|
||||||
|
/// constraints.
|
||||||
|
template <typename Storage, typename... Args>
|
||||||
|
Storage *get(function_ref<void(Storage *)> initFn, TypeID id,
|
||||||
|
Args &&...args) {
|
||||||
|
// Construct a value of the derived key type.
|
||||||
|
auto derivedKey = getKey<Storage>(std::forward<Args>(args)...);
|
||||||
|
|
||||||
|
// Create a hash of the derived key.
|
||||||
|
unsigned hashValue = getHash<Storage>(derivedKey);
|
||||||
|
|
||||||
// Generate an equality function for the derived storage.
|
// Generate an equality function for the derived storage.
|
||||||
auto isEqual = [&derivedKey](const BaseStorage *existing) {
|
auto isEqual = [&derivedKey](const BaseStorage *existing) {
|
||||||
|
@ -174,29 +191,29 @@ public:
|
||||||
|
|
||||||
// Get an instance for the derived storage.
|
// Get an instance for the derived storage.
|
||||||
return static_cast<Storage *>(
|
return static_cast<Storage *>(
|
||||||
getImpl(id, kind, hashValue, isEqual, ctorFn));
|
getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn));
|
||||||
|
}
|
||||||
|
/// Utility override when the storage type represents the type id.
|
||||||
|
template <typename Storage, typename... Args>
|
||||||
|
Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) {
|
||||||
|
return get<Storage>(initFn, TypeID::get<Storage>(),
|
||||||
|
std::forward<Args>(args)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
|
/// Gets a uniqued instance of 'Storage' which is a singleton storage type.
|
||||||
/// that can be used to initialize a newly inserted storage instance. This
|
/// 'id' is the type id used when registering the storage instance.
|
||||||
/// function is used for derived types that use no additional storage or
|
template <typename Storage> Storage *get(TypeID id) {
|
||||||
/// uniquing outside of the kind.
|
return static_cast<Storage *>(getSingletonImpl(id));
|
||||||
template <typename Storage>
|
}
|
||||||
Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
|
/// Utility override when the storage type represents the type id.
|
||||||
unsigned kind) {
|
template <typename Storage> Storage *get() {
|
||||||
auto ctorFn = [&](StorageAllocator &allocator) {
|
return get<Storage>(TypeID::get<Storage>());
|
||||||
auto *storage = new (allocator.allocate<Storage>()) Storage();
|
|
||||||
if (initFn)
|
|
||||||
initFn(storage);
|
|
||||||
return storage;
|
|
||||||
};
|
|
||||||
return static_cast<Storage *>(getImpl(id, kind, ctorFn));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Changes the mutable component of 'storage' by forwarding the trailing
|
/// Changes the mutable component of 'storage' by forwarding the trailing
|
||||||
/// arguments to the 'mutate' function of the derived class.
|
/// arguments to the 'mutate' function of the derived class.
|
||||||
template <typename Storage, typename... Args>
|
template <typename Storage, typename... Args>
|
||||||
LogicalResult mutate(const TypeID &id, Storage *storage, Args &&...args) {
|
LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) {
|
||||||
auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
|
auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
|
||||||
return static_cast<Storage &>(*storage).mutate(
|
return static_cast<Storage &>(*storage).mutate(
|
||||||
allocator, std::forward<Args>(args)...);
|
allocator, std::forward<Args>(args)...);
|
||||||
|
@ -207,13 +224,13 @@ public:
|
||||||
/// Erases a uniqued instance of 'Storage'. This function is used for derived
|
/// Erases a uniqued instance of 'Storage'. This function is used for derived
|
||||||
/// types that have complex storage or uniquing constraints.
|
/// types that have complex storage or uniquing constraints.
|
||||||
template <typename Storage, typename Arg, typename... Args>
|
template <typename Storage, typename Arg, typename... Args>
|
||||||
void erase(const TypeID &id, unsigned kind, Arg &&arg, Args &&...args) {
|
void erase(TypeID id, Arg &&arg, Args &&...args) {
|
||||||
// Construct a value of the derived key type.
|
// Construct a value of the derived key type.
|
||||||
auto derivedKey =
|
auto derivedKey =
|
||||||
getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
|
getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
|
||||||
|
|
||||||
// Create a hash of the kind and the derived key.
|
// Create a hash of the derived key.
|
||||||
unsigned hashValue = getHash<Storage>(kind, derivedKey);
|
unsigned hashValue = getHash<Storage>(derivedKey);
|
||||||
|
|
||||||
// Generate an equality function for the derived storage.
|
// Generate an equality function for the derived storage.
|
||||||
auto isEqual = [&derivedKey](const BaseStorage *existing) {
|
auto isEqual = [&derivedKey](const BaseStorage *existing) {
|
||||||
|
@ -221,32 +238,42 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
// Attempt to erase the storage instance.
|
// Attempt to erase the storage instance.
|
||||||
eraseImpl(id, kind, hashValue, isEqual, [](BaseStorage *storage) {
|
eraseImpl(id, hashValue, isEqual, [](BaseStorage *storage) {
|
||||||
static_cast<Storage *>(storage)->cleanup();
|
static_cast<Storage *>(storage)->cleanup();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Implementation for getting/creating an instance of a derived type with
|
/// Implementation for getting/creating an instance of a derived type with
|
||||||
/// complex storage.
|
/// parametric storage.
|
||||||
BaseStorage *getImpl(const TypeID &id, unsigned kind, unsigned hashValue,
|
BaseStorage *getParametricStorageTypeImpl(
|
||||||
function_ref<bool(const BaseStorage *)> isEqual,
|
TypeID id, unsigned hashValue,
|
||||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
|
function_ref<bool(const BaseStorage *)> isEqual,
|
||||||
|
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
|
||||||
|
|
||||||
/// Implementation for getting/creating an instance of a derived type with
|
/// Implementation for registering an instance of a derived type with
|
||||||
/// default storage.
|
/// parametric storage.
|
||||||
BaseStorage *getImpl(const TypeID &id, unsigned kind,
|
void registerParametricStorageTypeImpl(TypeID id);
|
||||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
|
|
||||||
|
/// Implementation for getting an instance of a derived type with default
|
||||||
|
/// storage.
|
||||||
|
BaseStorage *getSingletonImpl(TypeID id);
|
||||||
|
|
||||||
|
/// Implementation for registering an instance of a derived type with default
|
||||||
|
/// storage.
|
||||||
|
void
|
||||||
|
registerSingletonImpl(TypeID id,
|
||||||
|
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
|
||||||
|
|
||||||
/// Implementation for erasing an instance of a derived type with complex
|
/// Implementation for erasing an instance of a derived type with complex
|
||||||
/// storage.
|
/// storage.
|
||||||
void eraseImpl(const TypeID &id, unsigned kind, unsigned hashValue,
|
void eraseImpl(TypeID id, unsigned hashValue,
|
||||||
function_ref<bool(const BaseStorage *)> isEqual,
|
function_ref<bool(const BaseStorage *)> isEqual,
|
||||||
function_ref<void(BaseStorage *)> cleanupFn);
|
function_ref<void(BaseStorage *)> cleanupFn);
|
||||||
|
|
||||||
/// Implementation for mutating an instance of a derived storage.
|
/// Implementation for mutating an instance of a derived storage.
|
||||||
LogicalResult
|
LogicalResult
|
||||||
mutateImpl(const TypeID &id,
|
mutateImpl(TypeID id,
|
||||||
function_ref<LogicalResult(StorageAllocator &)> mutationFn);
|
function_ref<LogicalResult(StorageAllocator &)> mutationFn);
|
||||||
|
|
||||||
/// The internal implementation class.
|
/// The internal implementation class.
|
||||||
|
@ -276,27 +303,26 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Key and Kind Hashing
|
// Key Hashing
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
|
/// Used to generate a hash for the 'ImplTy::KeyTy' of a storage instance if
|
||||||
/// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
|
/// there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
|
||||||
template <typename ImplTy, typename DerivedKey>
|
template <typename ImplTy, typename DerivedKey>
|
||||||
static typename std::enable_if<
|
static typename std::enable_if<
|
||||||
llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
|
llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
|
||||||
::llvm::hash_code>::type
|
::llvm::hash_code>::type
|
||||||
getHash(unsigned kind, const DerivedKey &derivedKey) {
|
getHash(const DerivedKey &derivedKey) {
|
||||||
return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
|
return ImplTy::hashKey(derivedKey);
|
||||||
}
|
}
|
||||||
/// If there is no 'ImplTy::hashKey' default to using the
|
/// If there is no 'ImplTy::hashKey' default to using the 'llvm::DenseMapInfo'
|
||||||
/// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
|
/// definition for 'DerivedKey' for generating a hash.
|
||||||
template <typename ImplTy, typename DerivedKey>
|
template <typename ImplTy, typename DerivedKey>
|
||||||
static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
|
static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
|
||||||
ImplTy, DerivedKey>::value,
|
ImplTy, DerivedKey>::value,
|
||||||
::llvm::hash_code>::type
|
::llvm::hash_code>::type
|
||||||
getHash(unsigned kind, const DerivedKey &derivedKey) {
|
getHash(const DerivedKey &derivedKey) {
|
||||||
return llvm::hash_combine(
|
return DenseMapInfo<DerivedKey>::getHashValue(derivedKey);
|
||||||
kind, DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
|
@ -264,14 +264,13 @@ bool LLVMArrayType::isValidElementType(LLVMType type) {
|
||||||
|
|
||||||
LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
|
LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
|
||||||
assert(elementType && "expected non-null subtype");
|
assert(elementType && "expected non-null subtype");
|
||||||
return Base::get(elementType.getContext(), LLVMType::ArrayType, elementType,
|
return Base::get(elementType.getContext(), elementType, numElements);
|
||||||
numElements);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
|
LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
|
||||||
unsigned numElements) {
|
unsigned numElements) {
|
||||||
assert(elementType && "expected non-null subtype");
|
assert(elementType && "expected non-null subtype");
|
||||||
return Base::getChecked(loc, LLVMType::ArrayType, elementType, numElements);
|
return Base::getChecked(loc, elementType, numElements);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
|
LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
|
||||||
|
@ -301,16 +300,14 @@ LLVMFunctionType LLVMFunctionType::get(LLVMType result,
|
||||||
ArrayRef<LLVMType> arguments,
|
ArrayRef<LLVMType> arguments,
|
||||||
bool isVarArg) {
|
bool isVarArg) {
|
||||||
assert(result && "expected non-null result");
|
assert(result && "expected non-null result");
|
||||||
return Base::get(result.getContext(), LLVMType::FunctionType, result,
|
return Base::get(result.getContext(), result, arguments, isVarArg);
|
||||||
arguments, isVarArg);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
|
LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
|
||||||
ArrayRef<LLVMType> arguments,
|
ArrayRef<LLVMType> arguments,
|
||||||
bool isVarArg) {
|
bool isVarArg) {
|
||||||
assert(result && "expected non-null result");
|
assert(result && "expected non-null result");
|
||||||
return Base::getChecked(loc, LLVMType::FunctionType, result, arguments,
|
return Base::getChecked(loc, result, arguments, isVarArg);
|
||||||
isVarArg);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMType LLVMFunctionType::getReturnType() {
|
LLVMType LLVMFunctionType::getReturnType() {
|
||||||
|
@ -347,11 +344,11 @@ LogicalResult LLVMFunctionType::verifyConstructionInvariants(
|
||||||
// Integer type.
|
// Integer type.
|
||||||
|
|
||||||
LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
|
LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
|
||||||
return Base::get(ctx, LLVMType::IntegerType, bitwidth);
|
return Base::get(ctx, bitwidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) {
|
LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) {
|
||||||
return Base::getChecked(loc, LLVMType::IntegerType, bitwidth);
|
return Base::getChecked(loc, bitwidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; }
|
unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; }
|
||||||
|
@ -374,13 +371,12 @@ bool LLVMPointerType::isValidElementType(LLVMType type) {
|
||||||
|
|
||||||
LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
|
LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
|
||||||
assert(pointee && "expected non-null subtype");
|
assert(pointee && "expected non-null subtype");
|
||||||
return Base::get(pointee.getContext(), LLVMType::PointerType, pointee,
|
return Base::get(pointee.getContext(), pointee, addressSpace);
|
||||||
addressSpace);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
|
LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
|
||||||
unsigned addressSpace) {
|
unsigned addressSpace) {
|
||||||
return Base::getChecked(loc, LLVMType::PointerType, pointee, addressSpace);
|
return Base::getChecked(loc, pointee, addressSpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
|
LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
|
||||||
|
@ -405,32 +401,32 @@ bool LLVMStructType::isValidElementType(LLVMType type) {
|
||||||
|
|
||||||
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
|
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
|
||||||
StringRef name) {
|
StringRef name) {
|
||||||
return Base::get(context, LLVMType::StructType, name, /*opaque=*/false);
|
return Base::get(context, name, /*opaque=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
|
LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
|
||||||
StringRef name) {
|
StringRef name) {
|
||||||
return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/false);
|
return Base::getChecked(loc, name, /*opaque=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
|
LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
|
||||||
ArrayRef<LLVMType> types,
|
ArrayRef<LLVMType> types,
|
||||||
bool isPacked) {
|
bool isPacked) {
|
||||||
return Base::get(context, LLVMType::StructType, types, isPacked);
|
return Base::get(context, types, isPacked);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
|
LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
|
||||||
ArrayRef<LLVMType> types,
|
ArrayRef<LLVMType> types,
|
||||||
bool isPacked) {
|
bool isPacked) {
|
||||||
return Base::getChecked(loc, LLVMType::StructType, types, isPacked);
|
return Base::getChecked(loc, types, isPacked);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
|
LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
|
||||||
return Base::get(context, LLVMType::StructType, name, /*opaque=*/true);
|
return Base::get(context, name, /*opaque=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
|
LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
|
||||||
return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/true);
|
return Base::getChecked(loc, name, /*opaque=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
|
LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
|
||||||
|
@ -508,16 +504,14 @@ LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
|
||||||
LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
|
LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
|
||||||
unsigned numElements) {
|
unsigned numElements) {
|
||||||
assert(elementType && "expected non-null subtype");
|
assert(elementType && "expected non-null subtype");
|
||||||
return Base::get(elementType.getContext(), LLVMType::FixedVectorType,
|
return Base::get(elementType.getContext(), elementType, numElements);
|
||||||
elementType, numElements);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
|
LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
|
||||||
LLVMType elementType,
|
LLVMType elementType,
|
||||||
unsigned numElements) {
|
unsigned numElements) {
|
||||||
assert(elementType && "expected non-null subtype");
|
assert(elementType && "expected non-null subtype");
|
||||||
return Base::getChecked(loc, LLVMType::FixedVectorType, elementType,
|
return Base::getChecked(loc, elementType, numElements);
|
||||||
numElements);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned LLVMFixedVectorType::getNumElements() {
|
unsigned LLVMFixedVectorType::getNumElements() {
|
||||||
|
@ -527,16 +521,14 @@ unsigned LLVMFixedVectorType::getNumElements() {
|
||||||
LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
|
LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
|
||||||
unsigned minNumElements) {
|
unsigned minNumElements) {
|
||||||
assert(elementType && "expected non-null subtype");
|
assert(elementType && "expected non-null subtype");
|
||||||
return Base::get(elementType.getContext(), LLVMType::ScalableVectorType,
|
return Base::get(elementType.getContext(), elementType, minNumElements);
|
||||||
elementType, minNumElements);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMScalableVectorType
|
LLVMScalableVectorType
|
||||||
LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
|
LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
|
||||||
unsigned minNumElements) {
|
unsigned minNumElements) {
|
||||||
assert(elementType && "expected non-null subtype");
|
assert(elementType && "expected non-null subtype");
|
||||||
return Base::getChecked(loc, LLVMType::ScalableVectorType, elementType,
|
return Base::getChecked(loc, elementType, minNumElements);
|
||||||
minNumElements);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned LLVMScalableVectorType::getMinNumElements() {
|
unsigned LLVMScalableVectorType::getMinNumElements() {
|
||||||
|
|
|
@ -204,8 +204,8 @@ AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
|
||||||
Type expressedType,
|
Type expressedType,
|
||||||
int64_t storageTypeMin,
|
int64_t storageTypeMin,
|
||||||
int64_t storageTypeMax) {
|
int64_t storageTypeMax) {
|
||||||
return Base::get(storageType.getContext(), QuantizationTypes::Any, flags,
|
return Base::get(storageType.getContext(), flags, storageType, expressedType,
|
||||||
storageType, expressedType, storageTypeMin, storageTypeMax);
|
storageTypeMin, storageTypeMax);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
|
AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
|
||||||
|
@ -213,8 +213,8 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
|
||||||
int64_t storageTypeMin,
|
int64_t storageTypeMin,
|
||||||
int64_t storageTypeMax,
|
int64_t storageTypeMax,
|
||||||
Location location) {
|
Location location) {
|
||||||
return Base::getChecked(location, QuantizationTypes::Any, flags, storageType,
|
return Base::getChecked(location, flags, storageType, expressedType,
|
||||||
expressedType, storageTypeMin, storageTypeMax);
|
storageTypeMin, storageTypeMax);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult AnyQuantizedType::verifyConstructionInvariants(
|
LogicalResult AnyQuantizedType::verifyConstructionInvariants(
|
||||||
|
@ -240,10 +240,8 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
|
||||||
int64_t zeroPoint,
|
int64_t zeroPoint,
|
||||||
int64_t storageTypeMin,
|
int64_t storageTypeMin,
|
||||||
int64_t storageTypeMax) {
|
int64_t storageTypeMax) {
|
||||||
return Base::get(storageType.getContext(),
|
return Base::get(storageType.getContext(), flags, storageType, expressedType,
|
||||||
QuantizationTypes::UniformQuantized, flags, storageType,
|
scale, zeroPoint, storageTypeMin, storageTypeMax);
|
||||||
expressedType, scale, zeroPoint, storageTypeMin,
|
|
||||||
storageTypeMax);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
UniformQuantizedType
|
UniformQuantizedType
|
||||||
|
@ -251,9 +249,8 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
|
||||||
Type expressedType, double scale,
|
Type expressedType, double scale,
|
||||||
int64_t zeroPoint, int64_t storageTypeMin,
|
int64_t zeroPoint, int64_t storageTypeMin,
|
||||||
int64_t storageTypeMax, Location location) {
|
int64_t storageTypeMax, Location location) {
|
||||||
return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags,
|
return Base::getChecked(location, flags, storageType, expressedType, scale,
|
||||||
storageType, expressedType, scale, zeroPoint,
|
zeroPoint, storageTypeMin, storageTypeMax);
|
||||||
storageTypeMin, storageTypeMax);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
|
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
|
||||||
|
@ -295,10 +292,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
|
||||||
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
|
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
|
||||||
int32_t quantizedDimension, int64_t storageTypeMin,
|
int32_t quantizedDimension, int64_t storageTypeMin,
|
||||||
int64_t storageTypeMax) {
|
int64_t storageTypeMax) {
|
||||||
return Base::get(storageType.getContext(),
|
return Base::get(storageType.getContext(), flags, storageType, expressedType,
|
||||||
QuantizationTypes::UniformQuantizedPerAxis, flags,
|
scales, zeroPoints, quantizedDimension, storageTypeMin,
|
||||||
storageType, expressedType, scales, zeroPoints,
|
storageTypeMax);
|
||||||
quantizedDimension, storageTypeMin, storageTypeMax);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
|
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
|
||||||
|
@ -306,9 +302,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
|
||||||
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
|
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
|
||||||
int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
|
int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
|
||||||
Location location) {
|
Location location) {
|
||||||
return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis,
|
return Base::getChecked(location, flags, storageType, expressedType, scales,
|
||||||
flags, storageType, expressedType, scales, zeroPoints,
|
zeroPoints, quantizedDimension, storageTypeMin,
|
||||||
quantizedDimension, storageTypeMin, storageTypeMax);
|
storageTypeMax);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
|
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
|
||||||
|
|
|
@ -13,11 +13,11 @@ using namespace mlir;
|
||||||
|
|
||||||
SDBMDialect::SDBMDialect(MLIRContext *context)
|
SDBMDialect::SDBMDialect(MLIRContext *context)
|
||||||
: Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
|
: Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
|
||||||
uniquer.registerStorageType(TypeID::get<detail::SDBMBinaryExprStorage>());
|
uniquer.registerParametricStorageType<detail::SDBMBinaryExprStorage>();
|
||||||
uniquer.registerStorageType(TypeID::get<detail::SDBMConstantExprStorage>());
|
uniquer.registerParametricStorageType<detail::SDBMConstantExprStorage>();
|
||||||
uniquer.registerStorageType(TypeID::get<detail::SDBMDiffExprStorage>());
|
uniquer.registerParametricStorageType<detail::SDBMDiffExprStorage>();
|
||||||
uniquer.registerStorageType(TypeID::get<detail::SDBMNegExprStorage>());
|
uniquer.registerParametricStorageType<detail::SDBMNegExprStorage>();
|
||||||
uniquer.registerStorageType(TypeID::get<detail::SDBMTermExprStorage>());
|
uniquer.registerParametricStorageType<detail::SDBMTermExprStorage>();
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMDialect::~SDBMDialect() = default;
|
SDBMDialect::~SDBMDialect() = default;
|
||||||
|
|
|
@ -246,7 +246,6 @@ SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
|
||||||
|
|
||||||
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
|
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
|
||||||
return uniquer.get<detail::SDBMBinaryExprStorage>(
|
return uniquer.get<detail::SDBMBinaryExprStorage>(
|
||||||
TypeID::get<detail::SDBMBinaryExprStorage>(),
|
|
||||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
|
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -533,9 +532,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
|
||||||
assert(rhs && "expected SDBM dimension");
|
assert(rhs && "expected SDBM dimension");
|
||||||
|
|
||||||
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
|
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
|
||||||
return uniquer.get<detail::SDBMDiffExprStorage>(
|
return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
|
||||||
TypeID::get<detail::SDBMDiffExprStorage>(),
|
|
||||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
|
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
|
||||||
|
@ -575,7 +572,6 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
|
||||||
|
|
||||||
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
||||||
return uniquer.get<detail::SDBMBinaryExprStorage>(
|
return uniquer.get<detail::SDBMBinaryExprStorage>(
|
||||||
TypeID::get<detail::SDBMBinaryExprStorage>(),
|
|
||||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
|
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
|
||||||
stripeFactor);
|
stripeFactor);
|
||||||
}
|
}
|
||||||
|
@ -611,8 +607,7 @@ SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
|
||||||
|
|
||||||
StorageUniquer &uniquer = dialect->getUniquer();
|
StorageUniquer &uniquer = dialect->getUniquer();
|
||||||
return uniquer.get<detail::SDBMTermExprStorage>(
|
return uniquer.get<detail::SDBMTermExprStorage>(
|
||||||
TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
|
assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
|
||||||
static_cast<unsigned>(SDBMExprKind::DimId), position);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -628,8 +623,7 @@ SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
|
||||||
|
|
||||||
StorageUniquer &uniquer = dialect->getUniquer();
|
StorageUniquer &uniquer = dialect->getUniquer();
|
||||||
return uniquer.get<detail::SDBMTermExprStorage>(
|
return uniquer.get<detail::SDBMTermExprStorage>(
|
||||||
TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
|
assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
|
||||||
static_cast<unsigned>(SDBMExprKind::SymbolId), position);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -644,9 +638,7 @@ SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
|
||||||
};
|
};
|
||||||
|
|
||||||
StorageUniquer &uniquer = dialect->getUniquer();
|
StorageUniquer &uniquer = dialect->getUniquer();
|
||||||
return uniquer.get<detail::SDBMConstantExprStorage>(
|
return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
|
||||||
TypeID::get<detail::SDBMConstantExprStorage>(), assignCtx,
|
|
||||||
static_cast<unsigned>(SDBMExprKind::Constant), value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t SDBMConstantExpr::getValue() const {
|
int64_t SDBMConstantExpr::getValue() const {
|
||||||
|
@ -661,9 +653,7 @@ SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
|
||||||
assert(var && "expected non-null SDBM direct expression");
|
assert(var && "expected non-null SDBM direct expression");
|
||||||
|
|
||||||
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
StorageUniquer &uniquer = var.getDialect()->getUniquer();
|
||||||
return uniquer.get<detail::SDBMNegExprStorage>(
|
return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
|
||||||
TypeID::get<detail::SDBMNegExprStorage>(),
|
|
||||||
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SDBMDirectExpr SDBMNegExpr::getVar() const {
|
SDBMDirectExpr SDBMNegExpr::getVar() const {
|
||||||
|
|
|
@ -25,27 +25,28 @@ namespace detail {
|
||||||
|
|
||||||
// Base storage class for SDBMExpr.
|
// Base storage class for SDBMExpr.
|
||||||
struct SDBMExprStorage : public StorageUniquer::BaseStorage {
|
struct SDBMExprStorage : public StorageUniquer::BaseStorage {
|
||||||
SDBMExprKind getKind() {
|
SDBMExprKind getKind() { return kind; }
|
||||||
return static_cast<SDBMExprKind>(BaseStorage::getKind());
|
|
||||||
}
|
|
||||||
|
|
||||||
SDBMDialect *dialect;
|
SDBMDialect *dialect;
|
||||||
|
SDBMExprKind kind;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Storage class for SDBM sum and stripe expressions.
|
// Storage class for SDBM sum and stripe expressions.
|
||||||
struct SDBMBinaryExprStorage : public SDBMExprStorage {
|
struct SDBMBinaryExprStorage : public SDBMExprStorage {
|
||||||
using KeyTy = std::pair<SDBMDirectExpr, SDBMConstantExpr>;
|
using KeyTy = std::tuple<unsigned, SDBMDirectExpr, SDBMConstantExpr>;
|
||||||
|
|
||||||
bool operator==(const KeyTy &key) const {
|
bool operator==(const KeyTy &key) const {
|
||||||
return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
|
return static_cast<SDBMExprKind>(std::get<0>(key)) == kind &&
|
||||||
|
std::get<1>(key) == lhs && std::get<2>(key) == rhs;
|
||||||
}
|
}
|
||||||
|
|
||||||
static SDBMBinaryExprStorage *
|
static SDBMBinaryExprStorage *
|
||||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||||
auto *result = allocator.allocate<SDBMBinaryExprStorage>();
|
auto *result = allocator.allocate<SDBMBinaryExprStorage>();
|
||||||
result->lhs = std::get<0>(key);
|
result->lhs = std::get<1>(key);
|
||||||
result->rhs = std::get<1>(key);
|
result->rhs = std::get<2>(key);
|
||||||
result->dialect = result->lhs.getDialect();
|
result->dialect = result->lhs.getDialect();
|
||||||
|
result->kind = static_cast<SDBMExprKind>(std::get<0>(key));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,6 +68,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage {
|
||||||
result->lhs = std::get<0>(key);
|
result->lhs = std::get<0>(key);
|
||||||
result->rhs = std::get<1>(key);
|
result->rhs = std::get<1>(key);
|
||||||
result->dialect = result->lhs.getDialect();
|
result->dialect = result->lhs.getDialect();
|
||||||
|
result->kind = SDBMExprKind::Diff;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,6 +86,7 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
|
||||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||||
auto *result = allocator.allocate<SDBMConstantExprStorage>();
|
auto *result = allocator.allocate<SDBMConstantExprStorage>();
|
||||||
result->constant = key;
|
result->constant = key;
|
||||||
|
result->kind = SDBMExprKind::Constant;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,14 +95,18 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
|
||||||
|
|
||||||
// Storage class for SDBM dimension and symbol expressions.
|
// Storage class for SDBM dimension and symbol expressions.
|
||||||
struct SDBMTermExprStorage : public SDBMExprStorage {
|
struct SDBMTermExprStorage : public SDBMExprStorage {
|
||||||
using KeyTy = unsigned;
|
using KeyTy = std::pair<unsigned, unsigned>;
|
||||||
|
|
||||||
bool operator==(const KeyTy &key) const { return position == key; }
|
bool operator==(const KeyTy &key) const {
|
||||||
|
return kind == static_cast<SDBMExprKind>(key.first) &&
|
||||||
|
position == key.second;
|
||||||
|
}
|
||||||
|
|
||||||
static SDBMTermExprStorage *
|
static SDBMTermExprStorage *
|
||||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||||
auto *result = allocator.allocate<SDBMTermExprStorage>();
|
auto *result = allocator.allocate<SDBMTermExprStorage>();
|
||||||
result->position = key;
|
result->kind = static_cast<SDBMExprKind>(key.first);
|
||||||
|
result->position = key.second;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,6 +124,7 @@ struct SDBMNegExprStorage : public SDBMExprStorage {
|
||||||
auto *result = allocator.allocate<SDBMNegExprStorage>();
|
auto *result = allocator.allocate<SDBMNegExprStorage>();
|
||||||
result->expr = key;
|
result->expr = key;
|
||||||
result->dialect = key.getDialect();
|
result->dialect = key.getDialect();
|
||||||
|
result->kind = SDBMExprKind::Neg;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -120,8 +120,7 @@ spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
|
||||||
IntegerAttr storageClass) {
|
IntegerAttr storageClass) {
|
||||||
assert(descriptorSet && binding);
|
assert(descriptorSet && binding);
|
||||||
MLIRContext *context = descriptorSet.getContext();
|
MLIRContext *context = descriptorSet.getContext();
|
||||||
return Base::get(context, spirv::AttrKind::InterfaceVarABI, descriptorSet,
|
return Base::get(context, descriptorSet, binding, storageClass);
|
||||||
binding, storageClass);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
StringRef spirv::InterfaceVarABIAttr::getKindName() {
|
StringRef spirv::InterfaceVarABIAttr::getKindName() {
|
||||||
|
@ -195,8 +194,7 @@ spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
|
||||||
ArrayAttr extensions) {
|
ArrayAttr extensions) {
|
||||||
assert(version && capabilities && extensions);
|
assert(version && capabilities && extensions);
|
||||||
MLIRContext *context = version.getContext();
|
MLIRContext *context = version.getContext();
|
||||||
return Base::get(context, spirv::AttrKind::VerCapExt, version, capabilities,
|
return Base::get(context, version, capabilities, extensions);
|
||||||
extensions);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
|
StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
|
||||||
|
@ -272,7 +270,7 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
|
||||||
DictionaryAttr limits) {
|
DictionaryAttr limits) {
|
||||||
assert(triple && limits && "expected valid triple and limits");
|
assert(triple && limits && "expected valid triple and limits");
|
||||||
MLIRContext *context = triple.getContext();
|
MLIRContext *context = triple.getContext();
|
||||||
return Base::get(context, spirv::AttrKind::TargetEnv, triple, limits);
|
return Base::get(context, triple, limits);
|
||||||
}
|
}
|
||||||
|
|
||||||
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
|
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
|
||||||
|
|
|
@ -124,15 +124,14 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
|
||||||
|
|
||||||
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
|
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
|
||||||
assert(elementCount && "ArrayType needs at least one element");
|
assert(elementCount && "ArrayType needs at least one element");
|
||||||
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
|
return Base::get(elementType.getContext(), elementType, elementCount,
|
||||||
elementCount, /*stride=*/0);
|
/*stride=*/0);
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayType ArrayType::get(Type elementType, unsigned elementCount,
|
ArrayType ArrayType::get(Type elementType, unsigned elementCount,
|
||||||
unsigned stride) {
|
unsigned stride) {
|
||||||
assert(elementCount && "ArrayType needs at least one element");
|
assert(elementCount && "ArrayType needs at least one element");
|
||||||
return Base::get(elementType.getContext(), TypeKind::Array, elementType,
|
return Base::get(elementType.getContext(), elementType, elementCount, stride);
|
||||||
elementCount, stride);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
|
unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
|
||||||
|
@ -285,8 +284,7 @@ struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
|
||||||
CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
|
CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
|
||||||
Scope scope, unsigned rows,
|
Scope scope, unsigned rows,
|
||||||
unsigned columns) {
|
unsigned columns) {
|
||||||
return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix,
|
return Base::get(elementType.getContext(), elementType, scope, rows, columns);
|
||||||
elementType, scope, rows, columns);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Type CooperativeMatrixNVType::getElementType() const {
|
Type CooperativeMatrixNVType::getElementType() const {
|
||||||
|
@ -389,7 +387,7 @@ ImageType
|
||||||
ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
|
ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
|
||||||
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
|
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
|
||||||
value) {
|
value) {
|
||||||
return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value);
|
return Base::get(std::get<0>(value).getContext(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type ImageType::getElementType() const { return getImpl()->elementType; }
|
Type ImageType::getElementType() const { return getImpl()->elementType; }
|
||||||
|
@ -453,8 +451,7 @@ struct spirv::detail::PointerTypeStorage : public TypeStorage {
|
||||||
};
|
};
|
||||||
|
|
||||||
PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
|
PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
|
||||||
return Base::get(pointeeType.getContext(), TypeKind::Pointer, pointeeType,
|
return Base::get(pointeeType.getContext(), pointeeType, storageClass);
|
||||||
storageClass);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
|
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
|
||||||
|
@ -511,13 +508,11 @@ struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
|
||||||
};
|
};
|
||||||
|
|
||||||
RuntimeArrayType RuntimeArrayType::get(Type elementType) {
|
RuntimeArrayType RuntimeArrayType::get(Type elementType) {
|
||||||
return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
|
return Base::get(elementType.getContext(), elementType, /*stride=*/0);
|
||||||
elementType, /*stride=*/0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
|
RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
|
||||||
return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
|
return Base::get(elementType.getContext(), elementType, stride);
|
||||||
elementType, stride);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
|
Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
|
||||||
|
@ -846,12 +841,12 @@ StructType::get(ArrayRef<Type> memberTypes,
|
||||||
SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
|
SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
|
||||||
memberDecorations.begin(), memberDecorations.end());
|
memberDecorations.begin(), memberDecorations.end());
|
||||||
llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
|
llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
|
||||||
return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
|
return Base::get(memberTypes.vec().front().getContext(), memberTypes,
|
||||||
memberTypes, offsetInfo, sortedDecorations);
|
offsetInfo, sortedDecorations);
|
||||||
}
|
}
|
||||||
|
|
||||||
StructType StructType::getEmpty(MLIRContext *context) {
|
StructType StructType::getEmpty(MLIRContext *context) {
|
||||||
return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
|
return Base::get(context, ArrayRef<Type>(),
|
||||||
ArrayRef<StructType::OffsetInfo>(),
|
ArrayRef<StructType::OffsetInfo>(),
|
||||||
ArrayRef<StructType::MemberDecorationInfo>());
|
ArrayRef<StructType::MemberDecorationInfo>());
|
||||||
}
|
}
|
||||||
|
@ -946,13 +941,12 @@ struct spirv::detail::MatrixTypeStorage : public TypeStorage {
|
||||||
};
|
};
|
||||||
|
|
||||||
MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
|
MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
|
||||||
return Base::get(columnType.getContext(), TypeKind::Matrix, columnType,
|
return Base::get(columnType.getContext(), columnType, columnCount);
|
||||||
columnCount);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
|
MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
|
||||||
Location location) {
|
Location location) {
|
||||||
return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount);
|
return Base::getChecked(location, columnType, columnCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
|
LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
|
||||||
|
|
|
@ -20,9 +20,7 @@ using namespace mlir::detail;
|
||||||
|
|
||||||
MLIRContext *AffineExpr::getContext() const { return expr->context; }
|
MLIRContext *AffineExpr::getContext() const { return expr->context; }
|
||||||
|
|
||||||
AffineExprKind AffineExpr::getKind() const {
|
AffineExprKind AffineExpr::getKind() const { return expr->kind; }
|
||||||
return static_cast<AffineExprKind>(expr->getKind());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Walk all of the AffineExprs in this subgraph in postorder.
|
/// Walk all of the AffineExprs in this subgraph in postorder.
|
||||||
void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
|
void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
|
||||||
|
@ -449,8 +447,7 @@ static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
|
||||||
|
|
||||||
StorageUniquer &uniquer = context->getAffineUniquer();
|
StorageUniquer &uniquer = context->getAffineUniquer();
|
||||||
return uniquer.get<AffineDimExprStorage>(
|
return uniquer.get<AffineDimExprStorage>(
|
||||||
TypeID::get<AffineDimExprStorage>(), assignCtx,
|
assignCtx, static_cast<unsigned>(kind), position);
|
||||||
static_cast<unsigned>(kind), position);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
|
AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
|
||||||
|
@ -484,9 +481,7 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
|
||||||
};
|
};
|
||||||
|
|
||||||
StorageUniquer &uniquer = context->getAffineUniquer();
|
StorageUniquer &uniquer = context->getAffineUniquer();
|
||||||
return uniquer.get<AffineConstantExprStorage>(
|
return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
|
||||||
TypeID::get<AffineConstantExprStorage>(), assignCtx,
|
|
||||||
static_cast<unsigned>(AffineExprKind::Constant), constant);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Simplify add expression. Return nullptr if it can't be simplified.
|
/// Simplify add expression. Return nullptr if it can't be simplified.
|
||||||
|
@ -594,7 +589,6 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
|
||||||
|
|
||||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
|
||||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
|
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -655,7 +649,6 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
|
||||||
|
|
||||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
|
||||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
|
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -722,7 +715,6 @@ AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
|
||||||
|
|
||||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
|
||||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
|
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
|
||||||
other);
|
other);
|
||||||
}
|
}
|
||||||
|
@ -766,7 +758,6 @@ AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
|
||||||
|
|
||||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
|
||||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
|
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
|
||||||
other);
|
other);
|
||||||
}
|
}
|
||||||
|
@ -814,7 +805,6 @@ AffineExpr AffineExpr::operator%(AffineExpr other) const {
|
||||||
|
|
||||||
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
StorageUniquer &uniquer = getContext()->getAffineUniquer();
|
||||||
return uniquer.get<AffineBinaryOpExprStorage>(
|
return uniquer.get<AffineBinaryOpExprStorage>(
|
||||||
TypeID::get<AffineBinaryOpExprStorage>(),
|
|
||||||
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
|
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,21 +27,24 @@ namespace detail {
|
||||||
/// Base storage class appearing in an affine expression.
|
/// Base storage class appearing in an affine expression.
|
||||||
struct AffineExprStorage : public StorageUniquer::BaseStorage {
|
struct AffineExprStorage : public StorageUniquer::BaseStorage {
|
||||||
MLIRContext *context;
|
MLIRContext *context;
|
||||||
|
AffineExprKind kind;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A binary operation appearing in an affine expression.
|
/// A binary operation appearing in an affine expression.
|
||||||
struct AffineBinaryOpExprStorage : public AffineExprStorage {
|
struct AffineBinaryOpExprStorage : public AffineExprStorage {
|
||||||
using KeyTy = std::pair<AffineExpr, AffineExpr>;
|
using KeyTy = std::tuple<unsigned, AffineExpr, AffineExpr>;
|
||||||
|
|
||||||
bool operator==(const KeyTy &key) const {
|
bool operator==(const KeyTy &key) const {
|
||||||
return key.first == lhs && key.second == rhs;
|
return static_cast<AffineExprKind>(std::get<0>(key)) == kind &&
|
||||||
|
std::get<1>(key) == lhs && std::get<2>(key) == rhs;
|
||||||
}
|
}
|
||||||
|
|
||||||
static AffineBinaryOpExprStorage *
|
static AffineBinaryOpExprStorage *
|
||||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||||
auto *result = allocator.allocate<AffineBinaryOpExprStorage>();
|
auto *result = allocator.allocate<AffineBinaryOpExprStorage>();
|
||||||
result->lhs = key.first;
|
result->kind = static_cast<AffineExprKind>(std::get<0>(key));
|
||||||
result->rhs = key.second;
|
result->lhs = std::get<1>(key);
|
||||||
|
result->rhs = std::get<2>(key);
|
||||||
result->context = result->lhs.getContext();
|
result->context = result->lhs.getContext();
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -52,14 +55,18 @@ struct AffineBinaryOpExprStorage : public AffineExprStorage {
|
||||||
|
|
||||||
/// A dimensional or symbolic identifier appearing in an affine expression.
|
/// A dimensional or symbolic identifier appearing in an affine expression.
|
||||||
struct AffineDimExprStorage : public AffineExprStorage {
|
struct AffineDimExprStorage : public AffineExprStorage {
|
||||||
using KeyTy = unsigned;
|
using KeyTy = std::pair<unsigned, unsigned>;
|
||||||
|
|
||||||
bool operator==(const KeyTy &key) const { return position == key; }
|
bool operator==(const KeyTy &key) const {
|
||||||
|
return kind == static_cast<AffineExprKind>(key.first) &&
|
||||||
|
position == key.second;
|
||||||
|
}
|
||||||
|
|
||||||
static AffineDimExprStorage *
|
static AffineDimExprStorage *
|
||||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||||
auto *result = allocator.allocate<AffineDimExprStorage>();
|
auto *result = allocator.allocate<AffineDimExprStorage>();
|
||||||
result->position = key;
|
result->kind = static_cast<AffineExprKind>(key.first);
|
||||||
|
result->position = key.second;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,6 +83,7 @@ struct AffineConstantExprStorage : public AffineExprStorage {
|
||||||
static AffineConstantExprStorage *
|
static AffineConstantExprStorage *
|
||||||
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
|
||||||
auto *result = allocator.allocate<AffineConstantExprStorage>();
|
auto *result = allocator.allocate<AffineConstantExprStorage>();
|
||||||
|
result->kind = AffineExprKind::Constant;
|
||||||
result->constant = key;
|
result->constant = key;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -271,7 +271,7 @@ private:
|
||||||
/// Mapping between attribute kind and a pair comprised of a base alias name
|
/// Mapping between attribute kind and a pair comprised of a base alias name
|
||||||
/// and a unique list of attributes belonging to this kind sorted by location
|
/// and a unique list of attributes belonging to this kind sorted by location
|
||||||
/// seen in the module.
|
/// seen in the module.
|
||||||
llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
|
llvm::MapVector<TypeID, std::pair<StringRef, std::vector<Attribute>>>
|
||||||
attrKindToAlias;
|
attrKindToAlias;
|
||||||
|
|
||||||
/// Set of types known to be used within the module.
|
/// Set of types known to be used within the module.
|
||||||
|
@ -301,13 +301,13 @@ void AliasState::initialize(
|
||||||
llvm::StringSet<> usedAliases;
|
llvm::StringSet<> usedAliases;
|
||||||
|
|
||||||
// Collect the set of aliases from each dialect.
|
// Collect the set of aliases from each dialect.
|
||||||
SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
|
SmallVector<std::pair<TypeID, StringRef>, 8> attributeKindAliases;
|
||||||
SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
|
SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
|
||||||
SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
|
SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
|
||||||
|
|
||||||
// AffineMap/Integer set have specific kind aliases.
|
// AffineMap/Integer set have specific kind aliases.
|
||||||
attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
|
attributeKindAliases.emplace_back(AffineMapAttr::getTypeID(), "map");
|
||||||
attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
|
attributeKindAliases.emplace_back(IntegerSetAttr::getTypeID(), "set");
|
||||||
|
|
||||||
for (auto &interface : interfaces) {
|
for (auto &interface : interfaces) {
|
||||||
interface.getAttributeKindAliases(attributeKindAliases);
|
interface.getAttributeKindAliases(attributeKindAliases);
|
||||||
|
@ -317,7 +317,7 @@ void AliasState::initialize(
|
||||||
|
|
||||||
// Setup the attribute kind aliases.
|
// Setup the attribute kind aliases.
|
||||||
StringRef alias;
|
StringRef alias;
|
||||||
unsigned attrKind;
|
TypeID attrKind;
|
||||||
for (auto &attrAliasPair : attributeKindAliases) {
|
for (auto &attrAliasPair : attributeKindAliases) {
|
||||||
std::tie(attrKind, alias) = attrAliasPair;
|
std::tie(attrKind, alias) = attrAliasPair;
|
||||||
assert(!alias.empty() && "expected non-empty alias string");
|
assert(!alias.empty() && "expected non-empty alias string");
|
||||||
|
@ -420,7 +420,7 @@ void AliasState::recordAttributeReference(Attribute attr) {
|
||||||
return;
|
return;
|
||||||
|
|
||||||
// If this attribute kind has an alias, then record one for this attribute.
|
// If this attribute kind has an alias, then record one for this attribute.
|
||||||
auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
|
auto alias = attrKindToAlias.find(attr.getTypeID());
|
||||||
if (alias == attrKindToAlias.end())
|
if (alias == attrKindToAlias.end())
|
||||||
return;
|
return;
|
||||||
std::pair<StringRef, int> attrAlias(alias->second.first,
|
std::pair<StringRef, int> attrAlias(alias->second.first,
|
||||||
|
|
|
@ -57,7 +57,7 @@ Dialect &Attribute::getDialect() const {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
AffineMapAttr AffineMapAttr::get(AffineMap value) {
|
AffineMapAttr AffineMapAttr::get(AffineMap value) {
|
||||||
return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
|
return Base::get(value.getContext(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
|
AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
|
||||||
|
@ -67,7 +67,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
|
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
|
||||||
return Base::get(context, StandardAttributes::Array, value);
|
return Base::get(context, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
|
ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
|
||||||
|
@ -156,7 +156,7 @@ DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
|
||||||
if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
|
if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
|
||||||
value = storage;
|
value = storage;
|
||||||
|
|
||||||
return Base::get(context, StandardAttributes::Dictionary, value);
|
return Base::get(context, value);
|
||||||
}
|
}
|
||||||
/// Construct a dictionary with an array of values that is known to already be
|
/// Construct a dictionary with an array of values that is known to already be
|
||||||
/// sorted by name and uniqued.
|
/// sorted by name and uniqued.
|
||||||
|
@ -175,7 +175,7 @@ DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
|
||||||
return l.first == r.first;
|
return l.first == r.first;
|
||||||
}) == value.end() &&
|
}) == value.end() &&
|
||||||
"DictionaryAttr element names must be unique");
|
"DictionaryAttr element names must be unique");
|
||||||
return Base::get(context, StandardAttributes::Dictionary, value);
|
return Base::get(context, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
|
ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
|
||||||
|
@ -219,19 +219,19 @@ size_t DictionaryAttr::size() const { return getValue().size(); }
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
FloatAttr FloatAttr::get(Type type, double value) {
|
FloatAttr FloatAttr::get(Type type, double value) {
|
||||||
return Base::get(type.getContext(), StandardAttributes::Float, type, value);
|
return Base::get(type.getContext(), type, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
|
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
|
||||||
return Base::getChecked(loc, StandardAttributes::Float, type, value);
|
return Base::getChecked(loc, type, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
|
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
|
||||||
return Base::get(type.getContext(), StandardAttributes::Float, type, value);
|
return Base::get(type.getContext(), type, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
|
FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
|
||||||
return Base::getChecked(loc, StandardAttributes::Float, type, value);
|
return Base::getChecked(loc, type, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
|
APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
|
||||||
|
@ -279,14 +279,13 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
|
FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
|
||||||
return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
|
return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
|
||||||
.cast<FlatSymbolRefAttr>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SymbolRefAttr SymbolRefAttr::get(StringRef value,
|
SymbolRefAttr SymbolRefAttr::get(StringRef value,
|
||||||
ArrayRef<FlatSymbolRefAttr> nestedReferences,
|
ArrayRef<FlatSymbolRefAttr> nestedReferences,
|
||||||
MLIRContext *ctx) {
|
MLIRContext *ctx) {
|
||||||
return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
|
return Base::get(ctx, value, nestedReferences);
|
||||||
}
|
}
|
||||||
|
|
||||||
StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
|
StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
|
||||||
|
@ -307,7 +306,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
|
||||||
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
|
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
|
||||||
if (type.isSignlessInteger(1))
|
if (type.isSignlessInteger(1))
|
||||||
return BoolAttr::get(value.getBoolValue(), type.getContext());
|
return BoolAttr::get(value.getBoolValue(), type.getContext());
|
||||||
return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
|
return Base::get(type.getContext(), type, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
|
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
|
||||||
|
@ -380,8 +379,7 @@ bool BoolAttr::classof(Attribute attr) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
|
IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
|
||||||
return Base::get(value.getConstraint(0).getContext(),
|
return Base::get(value.getConstraint(0).getContext(), value);
|
||||||
StandardAttributes::IntegerSet, value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
|
IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
|
||||||
|
@ -392,14 +390,12 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
|
||||||
|
|
||||||
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
|
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
|
return Base::get(context, dialect, attrData, type);
|
||||||
type);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
|
OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
|
||||||
Type type, Location location) {
|
Type type, Location location) {
|
||||||
return Base::getChecked(location, StandardAttributes::Opaque, dialect,
|
return Base::getChecked(location, dialect, attrData, type);
|
||||||
attrData, type);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the dialect namespace of the opaque attribute.
|
/// Returns the dialect namespace of the opaque attribute.
|
||||||
|
@ -430,7 +426,7 @@ StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
|
||||||
|
|
||||||
/// Get an instance of a StringAttr with the given string and Type.
|
/// Get an instance of a StringAttr with the given string and Type.
|
||||||
StringAttr StringAttr::get(StringRef bytes, Type type) {
|
StringAttr StringAttr::get(StringRef bytes, Type type) {
|
||||||
return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
|
return Base::get(type.getContext(), bytes, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
StringRef StringAttr::getValue() const { return getImpl()->value; }
|
StringRef StringAttr::getValue() const { return getImpl()->value; }
|
||||||
|
@ -440,7 +436,7 @@ StringRef StringAttr::getValue() const { return getImpl()->value; }
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
TypeAttr TypeAttr::get(Type value) {
|
TypeAttr TypeAttr::get(Type value) {
|
||||||
return Base::get(value.getContext(), StandardAttributes::Type, value);
|
return Base::get(value.getContext(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type TypeAttr::getValue() const { return getImpl()->value; }
|
Type TypeAttr::getValue() const { return getImpl()->value; }
|
||||||
|
@ -1036,8 +1032,7 @@ DenseElementsAttr DenseElementsAttr::mapValues(
|
||||||
|
|
||||||
DenseStringElementsAttr
|
DenseStringElementsAttr
|
||||||
DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
|
DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
|
||||||
return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
|
return Base::get(type.getContext(), type, values, (values.size() == 1));
|
||||||
type, values, (values.size() == 1));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1088,8 +1083,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
|
||||||
assert((type.isa<RankedTensorType, VectorType>()) &&
|
assert((type.isa<RankedTensorType, VectorType>()) &&
|
||||||
"type must be ranked tensor or vector");
|
"type must be ranked tensor or vector");
|
||||||
assert(type.hasStaticShape() && "type must have static shape");
|
assert(type.hasStaticShape() && "type must have static shape");
|
||||||
return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
|
return Base::get(type.getContext(), type, data, isSplat);
|
||||||
type, data, isSplat);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Overload of the raw 'get' method that asserts that the given type is of
|
/// Overload of the raw 'get' method that asserts that the given type is of
|
||||||
|
@ -1210,8 +1204,7 @@ OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
|
||||||
StringRef bytes) {
|
StringRef bytes) {
|
||||||
assert(TensorType::isValidElementType(type.getElementType()) &&
|
assert(TensorType::isValidElementType(type.getElementType()) &&
|
||||||
"Input element type should be a valid tensor element type");
|
"Input element type should be a valid tensor element type");
|
||||||
return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
|
return Base::get(type.getContext(), type, dialect, bytes);
|
||||||
dialect, bytes);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
|
StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
|
||||||
|
@ -1248,7 +1241,7 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
|
||||||
assert((type.isa<RankedTensorType, VectorType>()) &&
|
assert((type.isa<RankedTensorType, VectorType>()) &&
|
||||||
"type must be ranked tensor or vector");
|
"type must be ranked tensor or vector");
|
||||||
assert(type.hasStaticShape() && "type must have static shape");
|
assert(type.hasStaticShape() && "type must have static shape");
|
||||||
return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
|
return Base::get(type.getContext(), type,
|
||||||
indices.cast<DenseIntElementsAttr>(), values);
|
indices.cast<DenseIntElementsAttr>(), values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,8 +28,7 @@ bool LocationAttr::classof(Attribute attr) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Location CallSiteLoc::get(Location callee, Location caller) {
|
Location CallSiteLoc::get(Location callee, Location caller) {
|
||||||
return Base::get(callee->getContext(), StandardAttributes::CallSiteLocation,
|
return Base::get(callee->getContext(), callee, caller);
|
||||||
callee, caller);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Location CallSiteLoc::get(Location name, ArrayRef<Location> frames) {
|
Location CallSiteLoc::get(Location name, ArrayRef<Location> frames) {
|
||||||
|
@ -50,8 +49,7 @@ Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
|
||||||
|
|
||||||
Location FileLineColLoc::get(Identifier filename, unsigned line,
|
Location FileLineColLoc::get(Identifier filename, unsigned line,
|
||||||
unsigned column, MLIRContext *context) {
|
unsigned column, MLIRContext *context) {
|
||||||
return Base::get(context, StandardAttributes::FileLineColLocation, filename,
|
return Base::get(context, filename, line, column);
|
||||||
line, column);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
|
Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
|
||||||
|
@ -95,7 +93,7 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
|
||||||
return UnknownLoc::get(context);
|
return UnknownLoc::get(context);
|
||||||
if (locs.size() == 1)
|
if (locs.size() == 1)
|
||||||
return locs.front();
|
return locs.front();
|
||||||
return Base::get(context, StandardAttributes::FusedLocation, locs, metadata);
|
return Base::get(context, locs, metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<Location> FusedLoc::getLocations() const {
|
ArrayRef<Location> FusedLoc::getLocations() const {
|
||||||
|
@ -111,8 +109,7 @@ Attribute FusedLoc::getMetadata() const { return getImpl()->metadata; }
|
||||||
Location NameLoc::get(Identifier name, Location child) {
|
Location NameLoc::get(Identifier name, Location child) {
|
||||||
assert(!child.isa<NameLoc>() &&
|
assert(!child.isa<NameLoc>() &&
|
||||||
"a NameLoc cannot be used as a child of another NameLoc");
|
"a NameLoc cannot be used as a child of another NameLoc");
|
||||||
return Base::get(child->getContext(), StandardAttributes::NameLocation, name,
|
return Base::get(child->getContext(), name, child);
|
||||||
child);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Location NameLoc::get(Identifier name, MLIRContext *context) {
|
Location NameLoc::get(Identifier name, MLIRContext *context) {
|
||||||
|
@ -131,9 +128,8 @@ Location NameLoc::getChildLoc() const { return getImpl()->child; }
|
||||||
|
|
||||||
Location OpaqueLoc::get(uintptr_t underlyingLocation, TypeID typeID,
|
Location OpaqueLoc::get(uintptr_t underlyingLocation, TypeID typeID,
|
||||||
Location fallbackLocation) {
|
Location fallbackLocation) {
|
||||||
return Base::get(fallbackLocation->getContext(),
|
return Base::get(fallbackLocation->getContext(), underlyingLocation, typeID,
|
||||||
StandardAttributes::OpaqueLocation, underlyingLocation,
|
fallbackLocation);
|
||||||
typeID, fallbackLocation);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uintptr_t OpaqueLoc::getUnderlyingLocation() const {
|
uintptr_t OpaqueLoc::getUnderlyingLocation() const {
|
||||||
|
|
|
@ -87,6 +87,10 @@ namespace {
|
||||||
struct BuiltinDialect : public Dialect {
|
struct BuiltinDialect : public Dialect {
|
||||||
BuiltinDialect(MLIRContext *context)
|
BuiltinDialect(MLIRContext *context)
|
||||||
: Dialect(/*name=*/"", context, TypeID::get<BuiltinDialect>()) {
|
: Dialect(/*name=*/"", context, TypeID::get<BuiltinDialect>()) {
|
||||||
|
addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
|
||||||
|
FunctionType, IndexType, IntegerType, MemRefType,
|
||||||
|
UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
|
||||||
|
TupleType, UnrankedTensorType, VectorType>();
|
||||||
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
|
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
|
||||||
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
|
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
|
||||||
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
|
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
|
||||||
|
@ -95,11 +99,6 @@ struct BuiltinDialect : public Dialect {
|
||||||
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
|
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
|
||||||
UnknownLoc>();
|
UnknownLoc>();
|
||||||
|
|
||||||
addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
|
|
||||||
FunctionType, IndexType, IntegerType, MemRefType,
|
|
||||||
UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
|
|
||||||
TupleType, UnrankedTensorType, VectorType>();
|
|
||||||
|
|
||||||
// TODO: These operations should be moved to a different dialect when they
|
// TODO: These operations should be moved to a different dialect when they
|
||||||
// have been fully decoupled from the core.
|
// have been fully decoupled from the core.
|
||||||
addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
|
addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
|
||||||
|
@ -363,56 +362,50 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
|
||||||
|
|
||||||
//// Types.
|
//// Types.
|
||||||
/// Floating-point Types.
|
/// Floating-point Types.
|
||||||
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this, StandardTypes::BF16);
|
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
|
||||||
impl->f16Ty = TypeUniquer::get<Float16Type>(this, StandardTypes::F16);
|
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
|
||||||
impl->f32Ty = TypeUniquer::get<Float32Type>(this, StandardTypes::F32);
|
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
|
||||||
impl->f64Ty = TypeUniquer::get<Float64Type>(this, StandardTypes::F64);
|
impl->f64Ty = TypeUniquer::get<Float64Type>(this);
|
||||||
/// Index Type.
|
/// Index Type.
|
||||||
impl->indexTy = TypeUniquer::get<IndexType>(this, StandardTypes::Index);
|
impl->indexTy = TypeUniquer::get<IndexType>(this);
|
||||||
/// Integer Types.
|
/// Integer Types.
|
||||||
impl->int1Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 1,
|
impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
|
||||||
IntegerType::Signless);
|
impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
|
||||||
impl->int8Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 8,
|
impl->int16Ty =
|
||||||
IntegerType::Signless);
|
TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
|
||||||
impl->int16Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
|
impl->int32Ty =
|
||||||
16, IntegerType::Signless);
|
TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
|
||||||
impl->int32Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
|
impl->int64Ty =
|
||||||
32, IntegerType::Signless);
|
TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
|
||||||
impl->int64Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
|
impl->int128Ty =
|
||||||
64, IntegerType::Signless);
|
TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
|
||||||
impl->int128Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
|
|
||||||
128, IntegerType::Signless);
|
|
||||||
/// None Type.
|
/// None Type.
|
||||||
impl->noneType = TypeUniquer::get<NoneType>(this, StandardTypes::None);
|
impl->noneType = TypeUniquer::get<NoneType>(this);
|
||||||
|
|
||||||
//// Attributes.
|
//// Attributes.
|
||||||
//// Note: These must be registered after the types as they may generate one
|
//// Note: These must be registered after the types as they may generate one
|
||||||
//// of the above types internally.
|
//// of the above types internally.
|
||||||
/// Bool Attributes.
|
/// Bool Attributes.
|
||||||
impl->falseAttr = AttributeUniquer::get<IntegerAttr>(
|
impl->falseAttr = AttributeUniquer::get<IntegerAttr>(
|
||||||
this, StandardAttributes::Integer, impl->int1Ty,
|
this, impl->int1Ty, APInt(/*numBits=*/1, false))
|
||||||
APInt(/*numBits=*/1, false))
|
|
||||||
.cast<BoolAttr>();
|
.cast<BoolAttr>();
|
||||||
impl->trueAttr = AttributeUniquer::get<IntegerAttr>(
|
impl->trueAttr = AttributeUniquer::get<IntegerAttr>(
|
||||||
this, StandardAttributes::Integer, impl->int1Ty,
|
this, impl->int1Ty, APInt(/*numBits=*/1, true))
|
||||||
APInt(/*numBits=*/1, true))
|
|
||||||
.cast<BoolAttr>();
|
.cast<BoolAttr>();
|
||||||
/// Unit Attribute.
|
/// Unit Attribute.
|
||||||
impl->unitAttr =
|
impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
|
||||||
AttributeUniquer::get<UnitAttr>(this, StandardAttributes::Unit);
|
|
||||||
/// Unknown Location Attribute.
|
/// Unknown Location Attribute.
|
||||||
impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(
|
impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
|
||||||
this, StandardAttributes::UnknownLocation);
|
|
||||||
/// The empty dictionary attribute.
|
/// The empty dictionary attribute.
|
||||||
impl->emptyDictionaryAttr = AttributeUniquer::get<DictionaryAttr>(
|
impl->emptyDictionaryAttr =
|
||||||
this, StandardAttributes::Dictionary, ArrayRef<NamedAttribute>());
|
AttributeUniquer::get<DictionaryAttr>(this, ArrayRef<NamedAttribute>());
|
||||||
|
|
||||||
// Register the affine storage objects with the uniquer.
|
// Register the affine storage objects with the uniquer.
|
||||||
impl->affineUniquer.registerStorageType(
|
impl->affineUniquer
|
||||||
TypeID::get<AffineBinaryOpExprStorage>());
|
.registerParametricStorageType<AffineBinaryOpExprStorage>();
|
||||||
impl->affineUniquer.registerStorageType(
|
impl->affineUniquer
|
||||||
TypeID::get<AffineConstantExprStorage>());
|
.registerParametricStorageType<AffineConstantExprStorage>();
|
||||||
impl->affineUniquer.registerStorageType(TypeID::get<AffineDimExprStorage>());
|
impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
|
||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext::~MLIRContext() {}
|
MLIRContext::~MLIRContext() {}
|
||||||
|
@ -582,7 +575,6 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
|
||||||
AbstractType(std::move(typeInfo));
|
AbstractType(std::move(typeInfo));
|
||||||
if (!impl.registeredTypes.insert({typeID, newInfo}).second)
|
if (!impl.registeredTypes.insert({typeID, newInfo}).second)
|
||||||
llvm::report_fatal_error("Dialect Type already registered.");
|
llvm::report_fatal_error("Dialect Type already registered.");
|
||||||
impl.typeUniquer.registerStorageType(typeID);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
|
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
|
||||||
|
@ -592,7 +584,6 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
|
||||||
AbstractAttribute(std::move(attrInfo));
|
AbstractAttribute(std::move(attrInfo));
|
||||||
if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
|
if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
|
||||||
llvm::report_fatal_error("Dialect Attribute already registered.");
|
llvm::report_fatal_error("Dialect Attribute already registered.");
|
||||||
impl.attributeUniquer.registerStorageType(typeID);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the dialect that registered the attribute with the provided typeid.
|
/// Get the dialect that registered the attribute with the provided typeid.
|
||||||
|
@ -718,7 +709,7 @@ IntegerType IntegerType::get(unsigned width,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
if (auto cached = getCachedIntegerType(width, signedness, context))
|
if (auto cached = getCachedIntegerType(width, signedness, context))
|
||||||
return cached;
|
return cached;
|
||||||
return Base::get(context, StandardTypes::Integer, width, signedness);
|
return Base::get(context, width, signedness);
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerType IntegerType::getChecked(unsigned width, Location location) {
|
IntegerType IntegerType::getChecked(unsigned width, Location location) {
|
||||||
|
@ -731,12 +722,16 @@ IntegerType IntegerType::getChecked(unsigned width,
|
||||||
if (auto cached =
|
if (auto cached =
|
||||||
getCachedIntegerType(width, signedness, location->getContext()))
|
getCachedIntegerType(width, signedness, location->getContext()))
|
||||||
return cached;
|
return cached;
|
||||||
return Base::getChecked(location, StandardTypes::Integer, width, signedness);
|
return Base::getChecked(location, width, signedness);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get an instance of the NoneType.
|
/// Get an instance of the NoneType.
|
||||||
NoneType NoneType::get(MLIRContext *context) {
|
NoneType NoneType::get(MLIRContext *context) {
|
||||||
return context->getImpl().noneType;
|
if (NoneType cachedInst = context->getImpl().noneType)
|
||||||
|
return cachedInst;
|
||||||
|
// Note: May happen when initializing the singleton attributes of the builtin
|
||||||
|
// dialect.
|
||||||
|
return Base::get(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -102,12 +102,11 @@ unsigned Type::getIntOrFloatBitWidth() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ComplexType ComplexType::get(Type elementType) {
|
ComplexType ComplexType::get(Type elementType) {
|
||||||
return Base::get(elementType.getContext(), StandardTypes::Complex,
|
return Base::get(elementType.getContext(), elementType);
|
||||||
elementType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ComplexType ComplexType::getChecked(Type elementType, Location location) {
|
ComplexType ComplexType::getChecked(Type elementType, Location location) {
|
||||||
return Base::getChecked(location, StandardTypes::Complex, elementType);
|
return Base::getChecked(location, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verify the construction of an integer type.
|
/// Verify the construction of an integer type.
|
||||||
|
@ -265,13 +264,12 @@ bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
|
VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
|
||||||
return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
|
return Base::get(elementType.getContext(), shape, elementType);
|
||||||
elementType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
|
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
|
||||||
Location location) {
|
Location location) {
|
||||||
return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
|
return Base::getChecked(location, shape, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult VectorType::verifyConstructionInvariants(Location loc,
|
LogicalResult VectorType::verifyConstructionInvariants(Location loc,
|
||||||
|
@ -320,15 +318,13 @@ bool TensorType::isValidElementType(Type type) {
|
||||||
|
|
||||||
RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
|
RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
|
||||||
Type elementType) {
|
Type elementType) {
|
||||||
return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape,
|
return Base::get(elementType.getContext(), shape, elementType);
|
||||||
elementType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
|
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
|
||||||
Type elementType,
|
Type elementType,
|
||||||
Location location) {
|
Location location) {
|
||||||
return Base::getChecked(location, StandardTypes::RankedTensor, shape,
|
return Base::getChecked(location, shape, elementType);
|
||||||
elementType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult RankedTensorType::verifyConstructionInvariants(
|
LogicalResult RankedTensorType::verifyConstructionInvariants(
|
||||||
|
@ -349,13 +345,12 @@ ArrayRef<int64_t> RankedTensorType::getShape() const {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
|
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
|
||||||
return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
|
return Base::get(elementType.getContext(), elementType);
|
||||||
elementType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
|
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
|
||||||
Location location) {
|
Location location) {
|
||||||
return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
|
return Base::getChecked(location, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
|
@ -444,8 +439,8 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
|
||||||
cleanedAffineMapComposition.push_back(map);
|
cleanedAffineMapComposition.push_back(map);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Base::get(context, StandardTypes::MemRef, shape, elementType,
|
return Base::get(context, shape, elementType, cleanedAffineMapComposition,
|
||||||
cleanedAffineMapComposition, memorySpace);
|
memorySpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
|
ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
|
||||||
|
@ -462,15 +457,13 @@ unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
|
||||||
|
|
||||||
UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
|
UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
|
||||||
unsigned memorySpace) {
|
unsigned memorySpace) {
|
||||||
return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
|
return Base::get(elementType.getContext(), elementType, memorySpace);
|
||||||
elementType, memorySpace);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
|
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
|
||||||
unsigned memorySpace,
|
unsigned memorySpace,
|
||||||
Location location) {
|
Location location) {
|
||||||
return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
|
return Base::getChecked(location, elementType, memorySpace);
|
||||||
memorySpace);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned UnrankedMemRefType::getMemorySpace() const {
|
unsigned UnrankedMemRefType::getMemorySpace() const {
|
||||||
|
@ -642,7 +635,7 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
|
||||||
/// Get or create a new TupleType with the provided element types. Assumes the
|
/// Get or create a new TupleType with the provided element types. Assumes the
|
||||||
/// arguments define a well-formed type.
|
/// arguments define a well-formed type.
|
||||||
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
|
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
|
||||||
return Base::get(context, StandardTypes::Tuple, elementTypes);
|
return Base::get(context, elementTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or create an empty tuple type.
|
/// Get or create an empty tuple type.
|
||||||
|
|
|
@ -19,8 +19,6 @@ using namespace mlir::detail;
|
||||||
// Type
|
// Type
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
unsigned Type::getKind() const { return impl->getKind(); }
|
|
||||||
|
|
||||||
Dialect &Type::getDialect() const {
|
Dialect &Type::getDialect() const {
|
||||||
return impl->getAbstractType().getDialect();
|
return impl->getAbstractType().getDialect();
|
||||||
}
|
}
|
||||||
|
@ -33,7 +31,7 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
|
||||||
|
|
||||||
FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
|
FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
return Base::get(context, Type::Kind::Function, inputs, results);
|
return Base::get(context, inputs, results);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
|
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
|
||||||
|
@ -54,12 +52,12 @@ ArrayRef<Type> FunctionType::getResults() const {
|
||||||
|
|
||||||
OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
|
OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
return Base::get(context, Type::Kind::Opaque, dialect, typeData);
|
return Base::get(context, dialect, typeData);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
|
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
|
||||||
MLIRContext *context, Location location) {
|
MLIRContext *context, Location location) {
|
||||||
return Base::getChecked(location, Kind::Opaque, dialect, typeData);
|
return Base::getChecked(location, dialect, typeData);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the dialect namespace of the opaque type.
|
/// Returns the dialect namespace of the opaque type.
|
||||||
|
|
|
@ -16,19 +16,17 @@ using namespace mlir;
|
||||||
using namespace mlir::detail;
|
using namespace mlir::detail;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/// This class represents a uniquer for storage instances of a specific type. It
|
/// This class represents a uniquer for storage instances of a specific type
|
||||||
/// contains all of the necessary data to unique storage instances in a thread
|
/// that has parametric storage. It contains all of the necessary data to unique
|
||||||
/// safe way. This allows for the main uniquer to bucket each of the individual
|
/// storage instances in a thread safe way. This allows for the main uniquer to
|
||||||
/// sub-types removing the need to lock the main uniquer itself.
|
/// bucket each of the individual sub-types removing the need to lock the main
|
||||||
struct InstSpecificUniquer {
|
/// uniquer itself.
|
||||||
|
struct ParametricStorageUniquer {
|
||||||
using BaseStorage = StorageUniquer::BaseStorage;
|
using BaseStorage = StorageUniquer::BaseStorage;
|
||||||
using StorageAllocator = StorageUniquer::StorageAllocator;
|
using StorageAllocator = StorageUniquer::StorageAllocator;
|
||||||
|
|
||||||
/// A lookup key for derived instances of storage objects.
|
/// A lookup key for derived instances of storage objects.
|
||||||
struct LookupKey {
|
struct LookupKey {
|
||||||
/// The known derived kind for the storage.
|
|
||||||
unsigned kind;
|
|
||||||
|
|
||||||
/// The known hash value of the key.
|
/// The known hash value of the key.
|
||||||
unsigned hashValue;
|
unsigned hashValue;
|
||||||
|
|
||||||
|
@ -63,18 +61,14 @@ struct InstSpecificUniquer {
|
||||||
static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
|
static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
|
||||||
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
|
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
|
||||||
return false;
|
return false;
|
||||||
// If the lookup kind matches the kind of the storage, then invoke the
|
// Invoke the equality function on the lookup key.
|
||||||
// equality function on the lookup key.
|
return lhs.isEqual(rhs.storage);
|
||||||
return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Unique types with specific hashing or storage constraints.
|
/// The set containing the allocated storage instances.
|
||||||
using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
|
using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
|
||||||
StorageTypeSet complexInstances;
|
StorageTypeSet instances;
|
||||||
|
|
||||||
/// Instances of this storage object.
|
|
||||||
llvm::SmallDenseMap<unsigned, BaseStorage *, 1> simpleInstances;
|
|
||||||
|
|
||||||
/// Allocator to use when constructing derived instances.
|
/// Allocator to use when constructing derived instances.
|
||||||
StorageAllocator allocator;
|
StorageAllocator allocator;
|
||||||
|
@ -91,107 +85,79 @@ struct StorageUniquerImpl {
|
||||||
using BaseStorage = StorageUniquer::BaseStorage;
|
using BaseStorage = StorageUniquer::BaseStorage;
|
||||||
using StorageAllocator = StorageUniquer::StorageAllocator;
|
using StorageAllocator = StorageUniquer::StorageAllocator;
|
||||||
|
|
||||||
/// Get or create an instance of a complex derived type.
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Parametric Storage
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Get or create an instance of a parametric type.
|
||||||
BaseStorage *
|
BaseStorage *
|
||||||
getOrCreate(TypeID id, unsigned kind, unsigned hashValue,
|
getOrCreate(TypeID id, unsigned hashValue,
|
||||||
function_ref<bool(const BaseStorage *)> isEqual,
|
function_ref<bool(const BaseStorage *)> isEqual,
|
||||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
||||||
assert(instUniquers.count(id) && "creating unregistered storage instance");
|
assert(parametricUniquers.count(id) &&
|
||||||
InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
|
"creating unregistered storage instance");
|
||||||
InstSpecificUniquer &storageUniquer = *instUniquers[id];
|
ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
|
||||||
|
ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
|
||||||
if (!threadingIsEnabled)
|
if (!threadingIsEnabled)
|
||||||
return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
|
return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
|
||||||
|
|
||||||
// Check for an existing instance in read-only mode.
|
// Check for an existing instance in read-only mode.
|
||||||
{
|
{
|
||||||
llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
|
llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
|
||||||
auto it = storageUniquer.complexInstances.find_as(lookupKey);
|
auto it = storageUniquer.instances.find_as(lookupKey);
|
||||||
if (it != storageUniquer.complexInstances.end())
|
if (it != storageUniquer.instances.end())
|
||||||
return it->storage;
|
return it->storage;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Acquire a writer-lock so that we can safely create the new type instance.
|
// Acquire a writer-lock so that we can safely create the new type instance.
|
||||||
llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
|
llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
|
||||||
return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
|
return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
|
||||||
}
|
}
|
||||||
/// Get or create an instance of a complex derived type in an thread-unsafe
|
/// Get or create an instance of a complex derived type in an thread-unsafe
|
||||||
/// fashion.
|
/// fashion.
|
||||||
BaseStorage *
|
BaseStorage *
|
||||||
getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
|
getOrCreateUnsafe(ParametricStorageUniquer &storageUniquer,
|
||||||
InstSpecificUniquer::LookupKey &lookupKey,
|
ParametricStorageUniquer::LookupKey &lookupKey,
|
||||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
||||||
auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey);
|
auto existing = storageUniquer.instances.insert_as({}, lookupKey);
|
||||||
if (!existing.second)
|
if (!existing.second)
|
||||||
return existing.first->storage;
|
return existing.first->storage;
|
||||||
|
|
||||||
// Otherwise, construct and initialize the derived storage for this type
|
// Otherwise, construct and initialize the derived storage for this type
|
||||||
// instance.
|
// instance.
|
||||||
BaseStorage *storage =
|
BaseStorage *storage = ctorFn(storageUniquer.allocator);
|
||||||
initializeStorage(kind, storageUniquer.allocator, ctorFn);
|
|
||||||
*existing.first =
|
*existing.first =
|
||||||
InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage};
|
ParametricStorageUniquer::HashedStorage{lookupKey.hashValue, storage};
|
||||||
return storage;
|
return storage;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or create an instance of a simple derived type.
|
/// Erase an instance of a parametric derived type.
|
||||||
BaseStorage *
|
void erase(TypeID id, unsigned hashValue,
|
||||||
getOrCreate(TypeID id, unsigned kind,
|
|
||||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
|
||||||
assert(instUniquers.count(id) && "creating unregistered storage instance");
|
|
||||||
InstSpecificUniquer &storageUniquer = *instUniquers[id];
|
|
||||||
if (!threadingIsEnabled)
|
|
||||||
return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
|
|
||||||
|
|
||||||
// Check for an existing instance in read-only mode.
|
|
||||||
{
|
|
||||||
llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
|
|
||||||
auto it = storageUniquer.simpleInstances.find(kind);
|
|
||||||
if (it != storageUniquer.simpleInstances.end())
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Acquire a writer-lock so that we can safely create the new type instance.
|
|
||||||
llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
|
|
||||||
return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
|
|
||||||
}
|
|
||||||
/// Get or create an instance of a simple derived type in an thread-unsafe
|
|
||||||
/// fashion.
|
|
||||||
BaseStorage *
|
|
||||||
getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
|
|
||||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
|
||||||
auto &result = storageUniquer.simpleInstances[kind];
|
|
||||||
if (result)
|
|
||||||
return result;
|
|
||||||
|
|
||||||
// Otherwise, create and return a new storage instance.
|
|
||||||
return result = initializeStorage(kind, storageUniquer.allocator, ctorFn);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Erase an instance of a complex derived type.
|
|
||||||
void erase(TypeID id, unsigned kind, unsigned hashValue,
|
|
||||||
function_ref<bool(const BaseStorage *)> isEqual,
|
function_ref<bool(const BaseStorage *)> isEqual,
|
||||||
function_ref<void(BaseStorage *)> cleanupFn) {
|
function_ref<void(BaseStorage *)> cleanupFn) {
|
||||||
assert(instUniquers.count(id) && "erasing unregistered storage instance");
|
assert(parametricUniquers.count(id) &&
|
||||||
InstSpecificUniquer &storageUniquer = *instUniquers[id];
|
"erasing unregistered storage instance");
|
||||||
InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
|
ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
|
||||||
|
ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
|
||||||
|
|
||||||
// Acquire a writer-lock so that we can safely erase the type instance.
|
// Acquire a writer-lock so that we can safely erase the type instance.
|
||||||
llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
|
llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
|
||||||
auto existing = storageUniquer.complexInstances.find_as(lookupKey);
|
auto existing = storageUniquer.instances.find_as(lookupKey);
|
||||||
if (existing == storageUniquer.complexInstances.end())
|
if (existing == storageUniquer.instances.end())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
// Cleanup the storage and remove it from the map.
|
// Cleanup the storage and remove it from the map.
|
||||||
cleanupFn(existing->storage);
|
cleanupFn(existing->storage);
|
||||||
storageUniquer.complexInstances.erase(existing);
|
storageUniquer.instances.erase(existing);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mutates an instance of a derived storage in a thread-safe way.
|
/// Mutates an instance of a derived storage in a thread-safe way.
|
||||||
LogicalResult
|
LogicalResult
|
||||||
mutate(TypeID id,
|
mutate(TypeID id,
|
||||||
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
|
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
|
||||||
assert(instUniquers.count(id) && "mutating unregistered storage instance");
|
assert(parametricUniquers.count(id) &&
|
||||||
InstSpecificUniquer &storageUniquer = *instUniquers[id];
|
"mutating unregistered storage instance");
|
||||||
|
ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
|
||||||
if (!threadingIsEnabled)
|
if (!threadingIsEnabled)
|
||||||
return mutationFn(storageUniquer.allocator);
|
return mutationFn(storageUniquer.allocator);
|
||||||
|
|
||||||
|
@ -199,21 +165,31 @@ struct StorageUniquerImpl {
|
||||||
return mutationFn(storageUniquer.allocator);
|
return mutationFn(storageUniquer.allocator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Singleton Storage
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Get or create an instance of a singleton storage class.
|
||||||
|
BaseStorage *getSingleton(TypeID id) {
|
||||||
|
BaseStorage *singletonInstance = singletonInstances[id];
|
||||||
|
assert(singletonInstance && "expected singleton instance to exist");
|
||||||
|
return singletonInstance;
|
||||||
|
}
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Instance Storage
|
// Instance Storage
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Utility to create and initialize a storage instance.
|
|
||||||
BaseStorage *
|
|
||||||
initializeStorage(unsigned kind, StorageAllocator &allocator,
|
|
||||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
|
||||||
BaseStorage *storage = ctorFn(allocator);
|
|
||||||
storage->kind = kind;
|
|
||||||
return storage;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Map of type ids to the storage uniquer to use for registered objects.
|
/// Map of type ids to the storage uniquer to use for registered objects.
|
||||||
DenseMap<TypeID, std::unique_ptr<InstSpecificUniquer>> instUniquers;
|
DenseMap<TypeID, std::unique_ptr<ParametricStorageUniquer>>
|
||||||
|
parametricUniquers;
|
||||||
|
|
||||||
|
/// Map of type ids to a singleton instance when the storage class is a
|
||||||
|
/// singleton.
|
||||||
|
DenseMap<TypeID, BaseStorage *> singletonInstances;
|
||||||
|
|
||||||
|
/// Allocator used for uniquing singleton instances.
|
||||||
|
StorageAllocator singletonAllocator;
|
||||||
|
|
||||||
/// Flag specifying if multi-threading is enabled within the uniquer.
|
/// Flag specifying if multi-threading is enabled within the uniquer.
|
||||||
bool threadingIsEnabled = true;
|
bool threadingIsEnabled = true;
|
||||||
|
@ -229,41 +205,47 @@ void StorageUniquer::disableMultithreading(bool disable) {
|
||||||
impl->threadingIsEnabled = !disable;
|
impl->threadingIsEnabled = !disable;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a new storage object with this uniquer using the given unique type
|
|
||||||
/// id.
|
|
||||||
void StorageUniquer::registerStorageType(TypeID id) {
|
|
||||||
impl->instUniquers.try_emplace(id, std::make_unique<InstSpecificUniquer>());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Implementation for getting/creating an instance of a derived type with
|
/// Implementation for getting/creating an instance of a derived type with
|
||||||
/// complex storage.
|
/// parametric storage.
|
||||||
auto StorageUniquer::getImpl(
|
auto StorageUniquer::getParametricStorageTypeImpl(
|
||||||
const TypeID &id, unsigned kind, unsigned hashValue,
|
TypeID id, unsigned hashValue,
|
||||||
function_ref<bool(const BaseStorage *)> isEqual,
|
function_ref<bool(const BaseStorage *)> isEqual,
|
||||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
|
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
|
||||||
return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn);
|
return impl->getOrCreate(id, hashValue, isEqual, ctorFn);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implementation for getting/creating an instance of a derived type with
|
/// Implementation for registering an instance of a derived type with
|
||||||
/// default storage.
|
/// parametric storage.
|
||||||
auto StorageUniquer::getImpl(
|
void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) {
|
||||||
const TypeID &id, unsigned kind,
|
impl->parametricUniquers.try_emplace(
|
||||||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
|
id, std::make_unique<ParametricStorageUniquer>());
|
||||||
return impl->getOrCreate(id, kind, ctorFn);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implementation for erasing an instance of a derived type with complex
|
/// Implementation for getting an instance of a derived type with default
|
||||||
/// storage.
|
/// storage.
|
||||||
void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind,
|
auto StorageUniquer::getSingletonImpl(TypeID id) -> BaseStorage * {
|
||||||
unsigned hashValue,
|
return impl->getSingleton(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implementation for registering an instance of a derived type with default
|
||||||
|
/// storage.
|
||||||
|
void StorageUniquer::registerSingletonImpl(
|
||||||
|
TypeID id, function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
|
||||||
|
assert(!impl->singletonInstances.count(id) &&
|
||||||
|
"storage class already registered");
|
||||||
|
impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implementation for erasing an instance of a derived type with parametric
|
||||||
|
/// storage.
|
||||||
|
void StorageUniquer::eraseImpl(TypeID id, unsigned hashValue,
|
||||||
function_ref<bool(const BaseStorage *)> isEqual,
|
function_ref<bool(const BaseStorage *)> isEqual,
|
||||||
function_ref<void(BaseStorage *)> cleanupFn) {
|
function_ref<void(BaseStorage *)> cleanupFn) {
|
||||||
impl->erase(id, kind, hashValue, isEqual, cleanupFn);
|
impl->erase(id, hashValue, isEqual, cleanupFn);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implementation for mutating an instance of a derived storage.
|
/// Implementation for mutating an instance of a derived storage.
|
||||||
LogicalResult StorageUniquer::mutateImpl(
|
LogicalResult StorageUniquer::mutateImpl(
|
||||||
const TypeID &id,
|
TypeID id, function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
|
||||||
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
|
|
||||||
return impl->mutate(id, mutationFn);
|
return impl->mutate(id, mutationFn);
|
||||||
}
|
}
|
||||||
|
|
|
@ -156,7 +156,7 @@ static Type parseTestType(DialectAsmParser &parser,
|
||||||
StringRef name;
|
StringRef name;
|
||||||
if (parser.parseLess() || parser.parseKeyword(&name))
|
if (parser.parseLess() || parser.parseKeyword(&name))
|
||||||
return Type();
|
return Type();
|
||||||
auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name);
|
auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
|
||||||
|
|
||||||
// If this type already has been parsed above in the stack, expect just the
|
// If this type already has been parsed above in the stack, expect just the
|
||||||
// name.
|
// name.
|
||||||
|
|
|
@ -26,10 +26,6 @@ struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
|
||||||
TestTypeInterface::Trait> {
|
TestTypeInterface::Trait> {
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
static TestType get(MLIRContext *context) {
|
|
||||||
return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Provide a definition for the necessary interface methods.
|
/// Provide a definition for the necessary interface methods.
|
||||||
void printTypeC(Location loc) const {
|
void printTypeC(Location loc) const {
|
||||||
emitRemark(loc) << *this << " - TestC";
|
emitRemark(loc) << *this << " - TestC";
|
||||||
|
@ -72,9 +68,8 @@ class TestRecursiveType
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
static TestRecursiveType create(MLIRContext *ctx, StringRef name) {
|
static TestRecursiveType get(MLIRContext *ctx, StringRef name) {
|
||||||
return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1,
|
return Base::get(ctx, name);
|
||||||
name);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Body getter and setter.
|
/// Body getter and setter.
|
||||||
|
|
|
@ -41,7 +41,7 @@ struct TestRecursiveTypesPass
|
||||||
LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
|
LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
|
||||||
MLIRContext *ctx = &getContext();
|
MLIRContext *ctx = &getContext();
|
||||||
FuncOp func = getFunction();
|
FuncOp func = getFunction();
|
||||||
auto type = TestRecursiveType::create(ctx, "some_long_and_unique_name");
|
auto type = TestRecursiveType::get(ctx, "some_long_and_unique_name");
|
||||||
if (failed(type.setBody(type)))
|
if (failed(type.setBody(type)))
|
||||||
return func.emitError("expected to be able to set the type body");
|
return func.emitError("expected to be able to set the type body");
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
|
||||||
"not expected to be able to change function body more than once");
|
"not expected to be able to change function body more than once");
|
||||||
|
|
||||||
// Expecting to get the same type for the same name.
|
// Expecting to get the same type for the same name.
|
||||||
auto other = TestRecursiveType::create(ctx, "some_long_and_unique_name");
|
auto other = TestRecursiveType::get(ctx, "some_long_and_unique_name");
|
||||||
if (type != other)
|
if (type != other)
|
||||||
return func.emitError("expected type name to be the uniquing key");
|
return func.emitError("expected type name to be the uniquing key");
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue