Eliminate "primitive" types from being a thing, splitting them into FloatType

and OtherType.  Other type is now the thing that holds AffineInt, Control,
eventually Resource, Variant, String, etc.  FloatType holds the floating point
types, and allows convenient query of isa<FloatType>().

This fixes issues where we allowed control to be the element type of tensor,
memref, vector.  At the same time, ban AffineInt from being an element of a
vector/memref/tensor as well since we don't need it.

I updated the spec to match this as well.

PiperOrigin-RevId: 206361942
This commit is contained in:
Chris Lattner 2018-07-27 13:09:58 -07:00 committed by jpienaar
parent 6e89270b2d
commit c77f39f55c
8 changed files with 200 additions and 158 deletions

View File

@ -57,12 +57,13 @@ public:
Module *createModule();
// Types.
PrimitiveType *getAffineIntType();
PrimitiveType *getBF16Type();
PrimitiveType *getF16Type();
PrimitiveType *getF32Type();
PrimitiveType *getF64Type();
PrimitiveType *getTFControlType();
FloatType *getBF16Type();
FloatType *getF16Type();
FloatType *getF32Type();
FloatType *getF64Type();
OtherType *getAffineIntType();
OtherType *getTFControlType();
IntegerType *getIntegerType(unsigned width);
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
ArrayRef<Type *> results);

View File

@ -24,8 +24,9 @@
namespace mlir {
class AffineMap;
class MLIRContext;
class PrimitiveType;
class IntegerType;
class FloatType;
class OtherType;
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
/// MLIRContext. As such, they are passed around by raw non-const pointer.
@ -34,21 +35,23 @@ class Type {
public:
/// Integer identifier for all the concrete type kinds.
enum class Kind {
// Target pointer sized integer.
// Target pointer sized integer, used (e.g.) in affine mappings.
AffineInt,
// TensorFlow types.
TFControl,
/// These are marker for the first and last 'other' type.
FIRST_OTHER_TYPE = AffineInt,
LAST_OTHER_TYPE = TFControl,
// Floating point.
BF16,
F16,
F32,
F64,
// TensorFlow types.
TFControl,
/// This is a marker for the last primitive type. The range of primitive
/// types is expected to be this element and earlier.
LAST_PRIMITIVE_TYPE = TFControl,
FIRST_FLOATING_POINT_TYPE = BF16,
LAST_FLOATING_POINT_TYPE = F64,
// Derived types.
Integer,
@ -67,9 +70,10 @@ public:
/// Return the LLVMContext in which this type was uniqued.
MLIRContext *getContext() const { return context; }
// Convenience predicates. This is only for primitive types, derived types
// should use isa/dyn_cast.
// Convenience predicates. This is only for 'other' and floating point types,
// derived types should use isa/dyn_cast.
bool isAffineInt() const { return getKind() == Kind::AffineInt; }
bool isTFControl() const { return getKind() == Kind::TFControl; }
bool isBF16() const { return getKind() == Kind::BF16; }
bool isF16() const { return getKind() == Kind::F16; }
bool isF32() const { return getKind() == Kind::F32; }
@ -77,12 +81,12 @@ public:
// Convenience factories.
static IntegerType *getInteger(unsigned width, MLIRContext *ctx);
static PrimitiveType *getAffineInt(MLIRContext *ctx);
static PrimitiveType *getBF16(MLIRContext *ctx);
static PrimitiveType *getF16(MLIRContext *ctx);
static PrimitiveType *getF32(MLIRContext *ctx);
static PrimitiveType *getF64(MLIRContext *ctx);
static PrimitiveType *getTFControl(MLIRContext *ctx);
static FloatType *getBF16(MLIRContext *ctx);
static FloatType *getF16(MLIRContext *ctx);
static FloatType *getF32(MLIRContext *ctx);
static FloatType *getF64(MLIRContext *ctx);
static OtherType *getAffineInt(MLIRContext *ctx);
static OtherType *getTFControl(MLIRContext *ctx);
/// Print the current type.
void print(raw_ostream &os) const;
@ -124,40 +128,6 @@ inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
return os;
}
/// Primitive types are the atomic base of the type system, including affine
/// integer and floating point values.
class PrimitiveType : public Type {
public:
static PrimitiveType *get(Kind kind, MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() <= Kind::LAST_PRIMITIVE_TYPE;
}
private:
PrimitiveType(Kind kind, MLIRContext *context);
~PrimitiveType() = delete;
};
inline PrimitiveType *Type::getAffineInt(MLIRContext *ctx) {
return PrimitiveType::get(Kind::AffineInt, ctx);
}
inline PrimitiveType *Type::getBF16(MLIRContext *ctx) {
return PrimitiveType::get(Kind::BF16, ctx);
}
inline PrimitiveType *Type::getF16(MLIRContext *ctx) {
return PrimitiveType::get(Kind::F16, ctx);
}
inline PrimitiveType *Type::getF32(MLIRContext *ctx) {
return PrimitiveType::get(Kind::F32, ctx);
}
inline PrimitiveType *Type::getF64(MLIRContext *ctx) {
return PrimitiveType::get(Kind::F64, ctx);
}
inline PrimitiveType *Type::getTFControl(MLIRContext *ctx) {
return PrimitiveType::get(Kind::TFControl, ctx);
}
/// Integer types can have arbitrary bitwidth up to a large fixed limit of 4096.
class IntegerType : public Type {
public:
@ -182,6 +152,56 @@ inline IntegerType *Type::getInteger(unsigned width, MLIRContext *ctx) {
return IntegerType::get(width, ctx);
}
class FloatType : public Type {
public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() >= Kind::FIRST_FLOATING_POINT_TYPE &&
type->getKind() <= Kind::LAST_FLOATING_POINT_TYPE;
}
static FloatType *get(Kind kind, MLIRContext *context);
private:
FloatType(Kind kind, MLIRContext *context);
~FloatType() = delete;
};
inline FloatType *Type::getBF16(MLIRContext *ctx) {
return FloatType::get(Kind::BF16, ctx);
}
inline FloatType *Type::getF16(MLIRContext *ctx) {
return FloatType::get(Kind::F16, ctx);
}
inline FloatType *Type::getF32(MLIRContext *ctx) {
return FloatType::get(Kind::F32, ctx);
}
inline FloatType *Type::getF64(MLIRContext *ctx) {
return FloatType::get(Kind::F64, ctx);
}
/// This is a type for the random collection of special base types.
class OtherType : public Type {
public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() >= Kind::FIRST_OTHER_TYPE &&
type->getKind() <= Kind::LAST_OTHER_TYPE;
}
static OtherType *get(Kind kind, MLIRContext *context);
private:
OtherType(Kind kind, MLIRContext *context);
~OtherType() = delete;
};
inline OtherType *Type::getAffineInt(MLIRContext *ctx) {
return OtherType::get(Kind::AffineInt, ctx);
}
inline OtherType *Type::getTFControl(MLIRContext *ctx) {
return OtherType::get(Kind::TFControl, ctx);
}
/// Function types map from a list of inputs to a list of results.
class FunctionType : public Type {
public:
@ -210,7 +230,6 @@ private:
~FunctionType() = delete;
};
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
/// known constant shape with one or more dimension.
class VectorType : public Type {

View File

@ -35,21 +35,17 @@ Module *Builder::createModule() { return new Module(context); }
// Types.
//===----------------------------------------------------------------------===//
PrimitiveType *Builder::getAffineIntType() {
return Type::getAffineInt(context);
}
FloatType *Builder::getBF16Type() { return Type::getBF16(context); }
PrimitiveType *Builder::getBF16Type() { return Type::getBF16(context); }
FloatType *Builder::getF16Type() { return Type::getF16(context); }
PrimitiveType *Builder::getF16Type() { return Type::getF16(context); }
FloatType *Builder::getF32Type() { return Type::getF32(context); }
PrimitiveType *Builder::getF32Type() { return Type::getF32(context); }
FloatType *Builder::getF64Type() { return Type::getF64(context); }
PrimitiveType *Builder::getF64Type() { return Type::getF64(context); }
OtherType *Builder::getAffineIntType() { return Type::getAffineInt(context); }
PrimitiveType *Builder::getTFControlType() {
return Type::getTFControl(context);
}
OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
IntegerType *Builder::getIntegerType(unsigned width) {
return Type::getInteger(width, context);

View File

@ -187,8 +187,13 @@ public:
/// These are identifiers uniqued into this MLIRContext.
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
// Primitive type uniquing.
PrimitiveType *primitives[int(Type::Kind::LAST_PRIMITIVE_TYPE) + 1] = {
// Uniquing table for 'other' types.
OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr};
// Uniquing table for 'float' types.
FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = {
nullptr};
// Affine map uniquing.
@ -293,24 +298,6 @@ Identifier Identifier::get(StringRef str, const MLIRContext *context) {
// Type uniquing
//===----------------------------------------------------------------------===//
PrimitiveType *PrimitiveType::get(Kind kind, MLIRContext *context) {
assert(kind <= Kind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
auto &impl = context->getImpl();
// We normally have these types.
if (impl.primitives[(int)kind])
return impl.primitives[(int)kind];
// On the first use, we allocate them into the bump pointer.
auto *ptr = impl.allocator.Allocate<PrimitiveType>();
// Initialize the memory using placement new.
new (ptr) PrimitiveType(kind, context);
// Cache and return it.
return impl.primitives[(int)kind] = ptr;
}
IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
auto &impl = context->getImpl();
@ -323,6 +310,47 @@ IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
return result;
}
FloatType *FloatType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
auto &impl = context->getImpl();
// We normally have these types.
auto *&entry =
impl.floatTypes[(int)kind - int(Kind::FIRST_FLOATING_POINT_TYPE)];
if (entry)
return entry;
// On the first use, we allocate them into the bump pointer.
auto *ptr = impl.allocator.Allocate<FloatType>();
// Initialize the memory using placement new.
new (ptr) FloatType(kind, context);
// Cache and return it.
return entry = ptr;
}
OtherType *OtherType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
"Not an 'other' type kind");
auto &impl = context->getImpl();
// We normally have these types.
auto *&entry = impl.otherTypes[(int)kind - int(Kind::FIRST_OTHER_TYPE)];
if (entry)
return entry;
// On the first use, we allocate them into the bump pointer.
auto *ptr = impl.allocator.Allocate<OtherType>();
// Initialize the memory using placement new.
new (ptr) OtherType(kind, context);
// Cache and return it.
return entry = ptr;
}
FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
ArrayRef<Type *> results,
MLIRContext *context) {
@ -356,7 +384,7 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
assert(!shape.empty() && "vector types must have at least one dimension");
assert((isa<PrimitiveType>(elementType) || isa<IntegerType>(elementType)) &&
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
"vectors elements must be primitives");
auto *context = elementType->getContext();
@ -385,7 +413,7 @@ VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
: Type(kind, context), elementType(elementType) {
assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType) ||
assert((isa<FloatType>(elementType) || isa<VectorType>(elementType) ||
isa<IntegerType>(elementType)) &&
"tensor elements must be primitives or vectors");
assert(isa<TensorType>(this));

View File

@ -21,13 +21,14 @@
#include "mlir/Support/STLExtras.h"
using namespace mlir;
PrimitiveType::PrimitiveType(Kind kind, MLIRContext *context)
: Type(kind, context) {}
IntegerType::IntegerType(unsigned width, MLIRContext *context)
: Type(Kind::Integer, context), width(width) {
}
FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {}
OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {}
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
unsigned numResults, MLIRContext *context)
: Type(Kind::Function, context, numInputs),

View File

@ -161,8 +161,6 @@ public:
// as the results of their action.
// Type parsing.
Type *parsePrimitiveType();
Type *parseElementType();
VectorType *parseVectorType();
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
Type *parseTensorType();
@ -256,16 +254,41 @@ ParseResult Parser::parseCommaSeparatedListUntil(
// Type Parsing
//===----------------------------------------------------------------------===//
/// Parse the low-level fixed dtypes in the system.
/// Parse an arbitrary type.
///
/// primitive-type ::= `f16` | `bf16` | `f32` | `f64`
/// primitive-type ::= integer-type
/// primitive-type ::= `affineint`
/// type ::= integer-type
/// | float-type
/// | other-type
/// | vector-type
/// | tensor-type
/// | memref-type
/// | function-type
///
Type *Parser::parsePrimitiveType() {
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
/// other-type ::= `affineint` | `tf_control`
///
Type *Parser::parseType() {
switch (getToken().getKind()) {
default:
return (emitError("expected type"), nullptr);
case Token::kw_memref:
return parseMemRefType();
case Token::kw_tensor:
return parseTensorType();
case Token::kw_vector:
return parseVectorType();
case Token::l_paren:
return parseFunctionType();
// integer-type
case Token::inttype: {
auto width = getToken().getIntTypeBitwidth();
if (!width.hasValue())
return (emitError("invalid integer width"), nullptr);
consumeToken(Token::inttype);
return builder.getIntegerType(width.getValue());
}
// float-type
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();
@ -278,31 +301,15 @@ Type *Parser::parsePrimitiveType() {
case Token::kw_f64:
consumeToken(Token::kw_f64);
return builder.getF64Type();
// other-type
case Token::kw_affineint:
consumeToken(Token::kw_affineint);
return builder.getAffineIntType();
case Token::kw_tf_control:
consumeToken(Token::kw_tf_control);
return builder.getTFControlType();
case Token::inttype: {
auto width = getToken().getIntTypeBitwidth();
if (!width.hasValue())
return (emitError("invalid integer width"), nullptr);
consumeToken(Token::inttype);
return builder.getIntegerType(width.getValue());
}
}
}
/// Parse the element type of a tensor or memref type.
///
/// element-type ::= primitive-type | vector-type
///
Type *Parser::parseElementType() {
if (getToken().is(Token::kw_vector))
return parseVectorType();
return parsePrimitiveType();
}
/// Parse a vector type.
@ -343,12 +350,13 @@ VectorType *Parser::parseVectorType() {
}
// Parse the element type.
auto *elementType = parsePrimitiveType();
if (!elementType)
auto typeLoc = getToken().getLoc();
auto *elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
if (parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
return (emitError(typeLoc, "invalid vector element type"), nullptr);
return VectorType::get(dimensions, elementType);
}
@ -411,12 +419,14 @@ Type *Parser::parseTensorType() {
}
// Parse the element type.
auto elementType = parseElementType();
if (!elementType)
auto typeLoc = getToken().getLoc();
auto *elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
return nullptr;
if (parseToken(Token::greater, "expected '>' in tensor type"))
return nullptr;
if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
!isa<VectorType>(elementType))
return (emitError(typeLoc, "invalid tensor element type"), nullptr);
if (isUnranked)
return builder.getTensorType(elementType);
@ -442,10 +452,15 @@ Type *Parser::parseMemRefType() {
return nullptr;
// Parse the element type.
auto elementType = parseElementType();
auto typeLoc = getToken().getLoc();
auto *elementType = parseType();
if (!elementType)
return nullptr;
if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
!isa<VectorType>(elementType))
return (emitError(typeLoc, "invalid memref element type"), nullptr);
// Parse semi-affine-map-composition.
SmallVector<AffineMap *, 2> affineMapComposition;
unsigned memorySpace = 0;
@ -506,29 +521,6 @@ Type *Parser::parseFunctionType() {
return builder.getFunctionType(arguments, results);
}
/// Parse an arbitrary type.
///
/// type ::= primitive-type
/// | vector-type
/// | tensor-type
/// | memref-type
/// | function-type
/// element-type ::= primitive-type | vector-type
///
Type *Parser::parseType() {
switch (getToken().getKind()) {
case Token::kw_memref:
return parseMemRefType();
case Token::kw_tensor:
return parseTensorType();
case Token::kw_vector:
return parseVectorType();
case Token::l_paren:
return parseFunctionType();
default:
return parsePrimitiveType();
}
}
/// Parse a list of types without an enclosing parenthesis. The list must have
/// at least one member.

View File

@ -10,7 +10,7 @@ extfunc @illegaltype(i) // expected-error {{expected type}}
// -----
extfunc @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{expected type}}
extfunc @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{invalid tensor element type}}
// -----
// Test no map in memref type.
@ -240,7 +240,12 @@ bb1(%x: i17):
// Test no nested vector.
extfunc @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
// expected-error@-1 {{expected type}}
// expected-error@-1 {{invalid vector element type}}
// -----
// affineint is not allowed in a vector.
extfunc @vectors(vector<1 x affineint>) // expected-error {{invalid vector element type}}
// -----

View File

@ -37,12 +37,12 @@ extfunc @int_types(i1, i2, i4, i7, i87) -> (i1, affineint, i19)
// CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>)
extfunc @vectors(vector<1 x f32>, vector<2x4xf32>)
// CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xaffineint>, tensor<i8>)
// CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
tensor<1x?x4x?x?xaffineint>, tensor<i8>)
tensor<1x?x4x?x?xi32>, tensor<i8>)
// CHECK: extfunc @memrefs(memref<1x?x4x?x?xaffineint, #map{{[0-9]+}}>, memref<i8, #map{{[0-9]+}}>)
extfunc @memrefs(memref<1x?x4x?x?xaffineint, #map0>, memref<i8, #map1>)
// CHECK: extfunc @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<i8, #map{{[0-9]+}}>)
extfunc @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<i8, #map1>)
// Test memref affine map compositions.
@ -63,8 +63,8 @@ extfunc @memrefs2(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2)>)
// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 1>)
extfunc @memrefs23(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 1>)
// CHECK: extfunc @functions((memref<1x?x4x?x?xaffineint, #map0>, memref<i8, #map1>) -> (), () -> ())
extfunc @functions((memref<1x?x4x?x?xaffineint, #map0, 0>, memref<i8, #map1, 0>) -> (), ()->())
// CHECK: extfunc @functions((memref<1x?x4x?x?xi32, #map0>, memref<i8, #map1>) -> (), () -> ())
extfunc @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<i8, #map1, 0>) -> (), ()->())
// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) -> i1 {
cfgfunc @simpleCFG(i32, f32) -> i1 {