Clean up the implementation of Type, making it structurally more similar to

Instruction and AffineExpr.  NFC.

PiperOrigin-RevId: 203287117
This commit is contained in:
Chris Lattner 2018-07-04 09:13:39 -07:00 committed by jpienaar
parent bd7c1f9566
commit ad4ea23278
3 changed files with 67 additions and 71 deletions

View File

@ -26,8 +26,13 @@ namespace mlir {
class PrimitiveType;
class IntegerType;
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
/// MLIRContext. As such, they are passed around by raw non-const pointer.
///
class Type {
public:
/// Integer identifier for all the concrete type kinds.
enum class TypeKind {
enum class Kind {
// Target pointer sized integer.
AffineInt,
@ -51,22 +56,12 @@ enum class TypeKind {
// TODO: MemRef types.
};
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
/// MLIRContext. As such, they are passed around by raw non-const pointer.
///
class Type {
public:
/// Return the classification for this type.
TypeKind getKind() const {
Kind getKind() const {
return kind;
}
/// Return true if this type is the specified kind.
bool is(TypeKind k) const {
return kind == k;
}
/// Return the LLVMContext in which this type was uniqued.
MLIRContext *getContext() const { return context; }
@ -83,10 +78,10 @@ public:
static PrimitiveType *getF64(MLIRContext *ctx);
protected:
explicit Type(TypeKind kind, MLIRContext *context)
explicit Type(Kind kind, MLIRContext *context)
: context(context), kind(kind), subclassData(0) {
}
explicit Type(TypeKind kind, MLIRContext *context, unsigned subClassData)
explicit Type(Kind kind, MLIRContext *context, unsigned subClassData)
: Type(kind, context) {
setSubclassData(subClassData);
}
@ -102,11 +97,13 @@ protected:
}
private:
Type(const Type&) = delete;
void operator=(const Type&) = delete;
/// This refers to the MLIRContext in which this type was uniqued.
MLIRContext *const context;
/// Classification of the subclass, used for type checking.
TypeKind kind : 8;
Kind kind : 8;
// Space for subclasses to store data.
unsigned subclassData : 24;
@ -121,31 +118,31 @@ inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
/// and floating point values.
class PrimitiveType : public Type {
public:
static PrimitiveType *get(TypeKind kind, MLIRContext *context);
static PrimitiveType *get(Kind kind, MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() <= TypeKind::LAST_PRIMITIVE_TYPE;
return type->getKind() <= Kind::LAST_PRIMITIVE_TYPE;
}
private:
PrimitiveType(TypeKind kind, MLIRContext *context);
PrimitiveType(Kind kind, MLIRContext *context);
};
inline PrimitiveType *Type::getAffineInt(MLIRContext *ctx) {
return PrimitiveType::get(TypeKind::AffineInt, ctx);
return PrimitiveType::get(Kind::AffineInt, ctx);
}
inline PrimitiveType *Type::getBF16(MLIRContext *ctx) {
return PrimitiveType::get(TypeKind::BF16, ctx);
return PrimitiveType::get(Kind::BF16, ctx);
}
inline PrimitiveType *Type::getF16(MLIRContext *ctx) {
return PrimitiveType::get(TypeKind::F16, ctx);
return PrimitiveType::get(Kind::F16, ctx);
}
inline PrimitiveType *Type::getF32(MLIRContext *ctx) {
return PrimitiveType::get(TypeKind::F32, ctx);
return PrimitiveType::get(Kind::F32, ctx);
}
inline PrimitiveType *Type::getF64(MLIRContext *ctx) {
return PrimitiveType::get(TypeKind::F64, ctx);
return PrimitiveType::get(Kind::F64, ctx);
}
/// Integer types can have arbitrary bitwidth up to a large fixed limit of 4096.
@ -160,7 +157,7 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == TypeKind::Integer;
return type->getKind() == Kind::Integer;
}
private:
unsigned width;
@ -189,7 +186,7 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == TypeKind::Function;
return type->getKind() == Kind::Function;
}
private:
@ -217,7 +214,7 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == TypeKind::Vector;
return type->getKind() == Kind::Vector;
}
private:
@ -236,15 +233,15 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == TypeKind::RankedTensor ||
type->getKind() == TypeKind::UnrankedTensor;
return type->getKind() == Kind::RankedTensor ||
type->getKind() == Kind::UnrankedTensor;
}
protected:
/// The type of each scalar element of the tensor.
Type *elementType;
TensorType(TypeKind kind, Type *elementType, MLIRContext *context);
TensorType(Kind kind, Type *elementType, MLIRContext *context);
};
/// Ranked tensor types represent multi-dimensional arrays that have a shape
@ -262,7 +259,7 @@ public:
unsigned getRank() const { return getShape().size(); }
static bool classof(const Type *type) {
return type->getKind() == TypeKind::RankedTensor;
return type->getKind() == Kind::RankedTensor;
}
private:
@ -279,7 +276,7 @@ public:
static UnrankedTensorType *get(Type *elementType);
static bool classof(const Type *type) {
return type->getKind() == TypeKind::UnrankedTensor;
return type->getKind() == Kind::UnrankedTensor;
}
private:

View File

@ -121,7 +121,7 @@ public:
llvm::StringMap<char, llvm::BumpPtrAllocator&> identifiers;
// Primitive type uniquing.
PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = { nullptr };
PrimitiveType *primitives[int(Type::Kind::LAST_PRIMITIVE_TYPE)+1] = {nullptr};
// Affine map uniquing.
using AffineMapSet = DenseSet<AffineMap *, AffineMapKeyInfo>;
@ -194,8 +194,8 @@ Identifier Identifier::get(StringRef str, const MLIRContext *context) {
// Types
//===----------------------------------------------------------------------===//
PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
PrimitiveType *PrimitiveType::get(Kind kind, MLIRContext *context) {
assert(kind <= Kind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
auto &impl = context->getImpl();
// We normally have these types.
@ -284,7 +284,7 @@ VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
}
TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
: Type(kind, context), elementType(elementType) {
assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType) ||
isa<IntegerType>(elementType)) &&

View File

@ -16,56 +16,55 @@
// =============================================================================
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Support/STLExtras.h"
using namespace mlir;
PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
PrimitiveType::PrimitiveType(Kind kind, MLIRContext *context)
: Type(kind, context) {
}
IntegerType::IntegerType(unsigned width, MLIRContext *context)
: Type(TypeKind::Integer, context), width(width) {
: Type(Kind::Integer, context), width(width) {
}
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
unsigned numResults, MLIRContext *context)
: Type(TypeKind::Function, context, numInputs),
: Type(Kind::Function, context, numInputs),
numResults(numResults), inputsAndResults(inputsAndResults) {
}
VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
MLIRContext *context)
: Type(TypeKind::Vector, context, shape.size()),
: Type(Kind::Vector, context, shape.size()),
shapeElements(shape.data()), elementType(elementType) {
}
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context)
: TensorType(TypeKind::RankedTensor, elementType, context),
: TensorType(Kind::RankedTensor, elementType, context),
shapeElements(shape.data()) {
setSubclassData(shape.size());
}
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
: TensorType(TypeKind::UnrankedTensor, elementType, context) {
: TensorType(Kind::UnrankedTensor, elementType, context) {
}
void Type::print(raw_ostream &os) const {
switch (getKind()) {
case TypeKind::AffineInt: os << "affineint"; return;
case TypeKind::BF16: os << "bf16"; return;
case TypeKind::F16: os << "f16"; return;
case TypeKind::F32: os << "f32"; return;
case TypeKind::F64: os << "f64"; return;
case Kind::AffineInt: os << "affineint"; return;
case Kind::BF16: os << "bf16"; return;
case Kind::F16: os << "f16"; return;
case Kind::F32: os << "f32"; return;
case Kind::F64: os << "f64"; return;
case TypeKind::Integer: {
case Kind::Integer: {
auto *integer = cast<IntegerType>(this);
os << 'i' << integer->getWidth();
return;
}
case TypeKind::Function: {
case Kind::Function: {
auto *func = cast<FunctionType>(this);
os << '(';
interleave(func->getInputs(),
@ -84,7 +83,7 @@ void Type::print(raw_ostream &os) const {
}
return;
}
case TypeKind::Vector: {
case Kind::Vector: {
auto *v = cast<VectorType>(this);
os << "vector<";
for (auto dim : v->getShape())
@ -92,7 +91,7 @@ void Type::print(raw_ostream &os) const {
os << *v->getElementType() << '>';
return;
}
case TypeKind::RankedTensor: {
case Kind::RankedTensor: {
auto *v = cast<RankedTensorType>(this);
os << "tensor<";
for (auto dim : v->getShape()) {
@ -105,7 +104,7 @@ void Type::print(raw_ostream &os) const {
os << *v->getElementType() << '>';
return;
}
case TypeKind::UnrankedTensor: {
case Kind::UnrankedTensor: {
auto *v = cast<UnrankedTensorType>(this);
os << "tensor<??" << *v->getElementType() << '>';
return;