forked from OSchip/llvm-project
Clean up the implementation of Type, making it structurally more similar to
Instruction and AffineExpr. NFC. PiperOrigin-RevId: 203287117
This commit is contained in:
parent
bd7c1f9566
commit
ad4ea23278
|
@ -26,8 +26,13 @@ namespace mlir {
|
|||
class PrimitiveType;
|
||||
class IntegerType;
|
||||
|
||||
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
|
||||
/// MLIRContext. As such, they are passed around by raw non-const pointer.
|
||||
///
|
||||
class Type {
|
||||
public:
|
||||
/// Integer identifier for all the concrete type kinds.
|
||||
enum class TypeKind {
|
||||
enum class Kind {
|
||||
// Target pointer sized integer.
|
||||
AffineInt,
|
||||
|
||||
|
@ -51,22 +56,12 @@ enum class TypeKind {
|
|||
// TODO: MemRef types.
|
||||
};
|
||||
|
||||
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
|
||||
/// MLIRContext. As such, they are passed around by raw non-const pointer.
|
||||
///
|
||||
class Type {
|
||||
public:
|
||||
|
||||
/// Return the classification for this type.
|
||||
TypeKind getKind() const {
|
||||
Kind getKind() const {
|
||||
return kind;
|
||||
}
|
||||
|
||||
/// Return true if this type is the specified kind.
|
||||
bool is(TypeKind k) const {
|
||||
return kind == k;
|
||||
}
|
||||
|
||||
/// Return the LLVMContext in which this type was uniqued.
|
||||
MLIRContext *getContext() const { return context; }
|
||||
|
||||
|
@ -83,10 +78,10 @@ public:
|
|||
static PrimitiveType *getF64(MLIRContext *ctx);
|
||||
|
||||
protected:
|
||||
explicit Type(TypeKind kind, MLIRContext *context)
|
||||
explicit Type(Kind kind, MLIRContext *context)
|
||||
: context(context), kind(kind), subclassData(0) {
|
||||
}
|
||||
explicit Type(TypeKind kind, MLIRContext *context, unsigned subClassData)
|
||||
explicit Type(Kind kind, MLIRContext *context, unsigned subClassData)
|
||||
: Type(kind, context) {
|
||||
setSubclassData(subClassData);
|
||||
}
|
||||
|
@ -102,11 +97,13 @@ protected:
|
|||
}
|
||||
|
||||
private:
|
||||
Type(const Type&) = delete;
|
||||
void operator=(const Type&) = delete;
|
||||
/// This refers to the MLIRContext in which this type was uniqued.
|
||||
MLIRContext *const context;
|
||||
|
||||
/// Classification of the subclass, used for type checking.
|
||||
TypeKind kind : 8;
|
||||
Kind kind : 8;
|
||||
|
||||
// Space for subclasses to store data.
|
||||
unsigned subclassData : 24;
|
||||
|
@ -121,31 +118,31 @@ inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
|
|||
/// and floating point values.
|
||||
class PrimitiveType : public Type {
|
||||
public:
|
||||
static PrimitiveType *get(TypeKind kind, MLIRContext *context);
|
||||
static PrimitiveType *get(Kind kind, MLIRContext *context);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() <= TypeKind::LAST_PRIMITIVE_TYPE;
|
||||
return type->getKind() <= Kind::LAST_PRIMITIVE_TYPE;
|
||||
}
|
||||
private:
|
||||
PrimitiveType(TypeKind kind, MLIRContext *context);
|
||||
PrimitiveType(Kind kind, MLIRContext *context);
|
||||
};
|
||||
|
||||
|
||||
inline PrimitiveType *Type::getAffineInt(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::AffineInt, ctx);
|
||||
return PrimitiveType::get(Kind::AffineInt, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getBF16(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::BF16, ctx);
|
||||
return PrimitiveType::get(Kind::BF16, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getF16(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::F16, ctx);
|
||||
return PrimitiveType::get(Kind::F16, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getF32(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::F32, ctx);
|
||||
return PrimitiveType::get(Kind::F32, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getF64(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::F64, ctx);
|
||||
return PrimitiveType::get(Kind::F64, ctx);
|
||||
}
|
||||
|
||||
/// Integer types can have arbitrary bitwidth up to a large fixed limit of 4096.
|
||||
|
@ -160,7 +157,7 @@ public:
|
|||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::Integer;
|
||||
return type->getKind() == Kind::Integer;
|
||||
}
|
||||
private:
|
||||
unsigned width;
|
||||
|
@ -189,7 +186,7 @@ public:
|
|||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::Function;
|
||||
return type->getKind() == Kind::Function;
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -217,7 +214,7 @@ public:
|
|||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::Vector;
|
||||
return type->getKind() == Kind::Vector;
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -236,15 +233,15 @@ public:
|
|||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::RankedTensor ||
|
||||
type->getKind() == TypeKind::UnrankedTensor;
|
||||
return type->getKind() == Kind::RankedTensor ||
|
||||
type->getKind() == Kind::UnrankedTensor;
|
||||
}
|
||||
|
||||
protected:
|
||||
/// The type of each scalar element of the tensor.
|
||||
Type *elementType;
|
||||
|
||||
TensorType(TypeKind kind, Type *elementType, MLIRContext *context);
|
||||
TensorType(Kind kind, Type *elementType, MLIRContext *context);
|
||||
};
|
||||
|
||||
/// Ranked tensor types represent multi-dimensional arrays that have a shape
|
||||
|
@ -262,7 +259,7 @@ public:
|
|||
unsigned getRank() const { return getShape().size(); }
|
||||
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::RankedTensor;
|
||||
return type->getKind() == Kind::RankedTensor;
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -279,7 +276,7 @@ public:
|
|||
static UnrankedTensorType *get(Type *elementType);
|
||||
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::UnrankedTensor;
|
||||
return type->getKind() == Kind::UnrankedTensor;
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -121,7 +121,7 @@ public:
|
|||
llvm::StringMap<char, llvm::BumpPtrAllocator&> identifiers;
|
||||
|
||||
// Primitive type uniquing.
|
||||
PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = { nullptr };
|
||||
PrimitiveType *primitives[int(Type::Kind::LAST_PRIMITIVE_TYPE)+1] = {nullptr};
|
||||
|
||||
// Affine map uniquing.
|
||||
using AffineMapSet = DenseSet<AffineMap *, AffineMapKeyInfo>;
|
||||
|
@ -194,8 +194,8 @@ Identifier Identifier::get(StringRef str, const MLIRContext *context) {
|
|||
// Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
|
||||
assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
|
||||
PrimitiveType *PrimitiveType::get(Kind kind, MLIRContext *context) {
|
||||
assert(kind <= Kind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// We normally have these types.
|
||||
|
@ -284,7 +284,7 @@ VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
|
|||
}
|
||||
|
||||
|
||||
TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
|
||||
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
|
||||
: Type(kind, context), elementType(elementType) {
|
||||
assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType) ||
|
||||
isa<IntegerType>(elementType)) &&
|
||||
|
|
|
@ -16,56 +16,55 @@
|
|||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
using namespace mlir;
|
||||
|
||||
PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
|
||||
PrimitiveType::PrimitiveType(Kind kind, MLIRContext *context)
|
||||
: Type(kind, context) {
|
||||
}
|
||||
|
||||
IntegerType::IntegerType(unsigned width, MLIRContext *context)
|
||||
: Type(TypeKind::Integer, context), width(width) {
|
||||
: Type(Kind::Integer, context), width(width) {
|
||||
}
|
||||
|
||||
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
|
||||
unsigned numResults, MLIRContext *context)
|
||||
: Type(TypeKind::Function, context, numInputs),
|
||||
: Type(Kind::Function, context, numInputs),
|
||||
numResults(numResults), inputsAndResults(inputsAndResults) {
|
||||
}
|
||||
|
||||
VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
|
||||
MLIRContext *context)
|
||||
: Type(TypeKind::Vector, context, shape.size()),
|
||||
: Type(Kind::Vector, context, shape.size()),
|
||||
shapeElements(shape.data()), elementType(elementType) {
|
||||
}
|
||||
|
||||
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context)
|
||||
: TensorType(TypeKind::RankedTensor, elementType, context),
|
||||
: TensorType(Kind::RankedTensor, elementType, context),
|
||||
shapeElements(shape.data()) {
|
||||
setSubclassData(shape.size());
|
||||
}
|
||||
|
||||
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
|
||||
: TensorType(TypeKind::UnrankedTensor, elementType, context) {
|
||||
: TensorType(Kind::UnrankedTensor, elementType, context) {
|
||||
}
|
||||
|
||||
void Type::print(raw_ostream &os) const {
|
||||
switch (getKind()) {
|
||||
case TypeKind::AffineInt: os << "affineint"; return;
|
||||
case TypeKind::BF16: os << "bf16"; return;
|
||||
case TypeKind::F16: os << "f16"; return;
|
||||
case TypeKind::F32: os << "f32"; return;
|
||||
case TypeKind::F64: os << "f64"; return;
|
||||
case Kind::AffineInt: os << "affineint"; return;
|
||||
case Kind::BF16: os << "bf16"; return;
|
||||
case Kind::F16: os << "f16"; return;
|
||||
case Kind::F32: os << "f32"; return;
|
||||
case Kind::F64: os << "f64"; return;
|
||||
|
||||
case TypeKind::Integer: {
|
||||
case Kind::Integer: {
|
||||
auto *integer = cast<IntegerType>(this);
|
||||
os << 'i' << integer->getWidth();
|
||||
return;
|
||||
}
|
||||
case TypeKind::Function: {
|
||||
case Kind::Function: {
|
||||
auto *func = cast<FunctionType>(this);
|
||||
os << '(';
|
||||
interleave(func->getInputs(),
|
||||
|
@ -84,7 +83,7 @@ void Type::print(raw_ostream &os) const {
|
|||
}
|
||||
return;
|
||||
}
|
||||
case TypeKind::Vector: {
|
||||
case Kind::Vector: {
|
||||
auto *v = cast<VectorType>(this);
|
||||
os << "vector<";
|
||||
for (auto dim : v->getShape())
|
||||
|
@ -92,7 +91,7 @@ void Type::print(raw_ostream &os) const {
|
|||
os << *v->getElementType() << '>';
|
||||
return;
|
||||
}
|
||||
case TypeKind::RankedTensor: {
|
||||
case Kind::RankedTensor: {
|
||||
auto *v = cast<RankedTensorType>(this);
|
||||
os << "tensor<";
|
||||
for (auto dim : v->getShape()) {
|
||||
|
@ -105,7 +104,7 @@ void Type::print(raw_ostream &os) const {
|
|||
os << *v->getElementType() << '>';
|
||||
return;
|
||||
}
|
||||
case TypeKind::UnrankedTensor: {
|
||||
case Kind::UnrankedTensor: {
|
||||
auto *v = cast<UnrankedTensorType>(this);
|
||||
os << "tensor<??" << *v->getElementType() << '>';
|
||||
return;
|
||||
|
|
Loading…
Reference in New Issue