From 6af866c58d21813fb243906611d02bb2a8ffa43a Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Fri, 29 Jun 2018 22:08:05 -0700 Subject: [PATCH] Enhance the type system to support arbitrary precision integers, which are important for low-bitwidth inference cases and hardware synthesis targets. Rename 'int' to 'affineint' to avoid confusion between "the integers" and "the int type". PiperOrigin-RevId: 202751508 --- mlir/include/mlir/IR/Types.h | 63 +++++++++++++++++---------------- mlir/lib/IR/MLIRContext.cpp | 47 +++++++++--------------- mlir/lib/IR/Types.cpp | 44 +++++++++++++++++++---- mlir/lib/Parser/Lexer.cpp | 10 ++++++ mlir/lib/Parser/Parser.cpp | 44 +++++++++-------------- mlir/lib/Parser/Token.cpp | 12 +++++++ mlir/lib/Parser/Token.h | 3 ++ mlir/lib/Parser/TokenKinds.def | 8 ++--- mlir/test/IR/parser-errors.mlir | 10 ++++-- mlir/test/IR/parser.mlir | 15 ++++---- 10 files changed, 146 insertions(+), 110 deletions(-) diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index a9ee0394d060..5f2ca9d3fb19 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -24,18 +24,12 @@ namespace mlir { class MLIRContext; class PrimitiveType; + class IntegerType; /// Integer identifier for all the concrete type kinds. enum class TypeKind { - // Integer. - I1, - I8, - I16, - I32, - I64, - // Target pointer sized integer. - Int, + AffineInt, // Floating point. BF16, @@ -48,6 +42,7 @@ enum class TypeKind { LAST_PRIMITIVE_TYPE = F64, // Derived types. + Integer, Function, Vector, RankedTensor, @@ -80,12 +75,8 @@ public: void dump() const; // Convenience factories. - static PrimitiveType *getI1(MLIRContext *ctx); - static PrimitiveType *getI8(MLIRContext *ctx); - static PrimitiveType *getI16(MLIRContext *ctx); - static PrimitiveType *getI32(MLIRContext *ctx); - static PrimitiveType *getI64(MLIRContext *ctx); - static PrimitiveType *getInt(MLIRContext *ctx); + static IntegerType *getInt(unsigned width, MLIRContext *ctx); + static PrimitiveType *getAffineInt(MLIRContext *ctx); static PrimitiveType *getBF16(MLIRContext *ctx); static PrimitiveType *getF16(MLIRContext *ctx); static PrimitiveType *getF32(MLIRContext *ctx); @@ -140,23 +131,9 @@ private: PrimitiveType(TypeKind kind, MLIRContext *context); }; -inline PrimitiveType *Type::getI1(MLIRContext *ctx) { - return PrimitiveType::get(TypeKind::I1, ctx); -} -inline PrimitiveType *Type::getI8(MLIRContext *ctx) { - return PrimitiveType::get(TypeKind::I8, ctx); -} -inline PrimitiveType *Type::getI16(MLIRContext *ctx) { - return PrimitiveType::get(TypeKind::I16, ctx); -} -inline PrimitiveType *Type::getI32(MLIRContext *ctx) { - return PrimitiveType::get(TypeKind::I32, ctx); -} -inline PrimitiveType *Type::getI64(MLIRContext *ctx) { - return PrimitiveType::get(TypeKind::I64, ctx); -} -inline PrimitiveType *Type::getInt(MLIRContext *ctx) { - return PrimitiveType::get(TypeKind::Int, ctx); + +inline PrimitiveType *Type::getAffineInt(MLIRContext *ctx) { + return PrimitiveType::get(TypeKind::AffineInt, ctx); } inline PrimitiveType *Type::getBF16(MLIRContext *ctx) { return PrimitiveType::get(TypeKind::BF16, ctx); @@ -171,6 +148,30 @@ inline PrimitiveType *Type::getF64(MLIRContext *ctx) { return PrimitiveType::get(TypeKind::F64, ctx); } +/// Integer types can have arbitrary bitwidth up to a large fixed limit of 4096. +class IntegerType : public Type { +public: + static IntegerType *get(unsigned width, MLIRContext *context); + + /// Return the bitwidth of this integer type. + unsigned getWidth() const { + return width; + } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const Type *type) { + return type->getKind() == TypeKind::Integer; + } +private: + unsigned width; + IntegerType(unsigned width, MLIRContext *context); +}; + +inline IntegerType *Type::getInt(unsigned width, MLIRContext *ctx) { + return IntegerType::get(width, ctx); +} + + /// Function types map from a list of inputs to a list of results. class FunctionType : public Type { diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 7c1112bc2300..85ff432c88b9 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -120,6 +120,9 @@ public: using AffineMapSet = DenseSet; AffineMapSet affineMaps; + /// Integer type uniquing. + DenseMap integers; + /// Function type uniquing. using FunctionTypeSet = DenseSet; FunctionTypeSet functions; @@ -173,15 +176,10 @@ Identifier Identifier::get(StringRef str, const MLIRContext *context) { return Identifier(it->getKeyData()); } - //===----------------------------------------------------------------------===// // Types //===----------------------------------------------------------------------===// -PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context) - : Type(kind, context) { -} - PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) { assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind"); auto &impl = context->getImpl(); @@ -200,10 +198,16 @@ PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) { return impl.primitives[(int)kind] = ptr; } -FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs, - unsigned numResults, MLIRContext *context) - : Type(TypeKind::Function, context, numInputs), - numResults(numResults), inputsAndResults(inputsAndResults) { +IntegerType *IntegerType::get(unsigned width, MLIRContext *context) { + auto &impl = context->getImpl(); + + auto *&result = impl.integers[width]; + if (!result) { + result = impl.allocator.Allocate(); + new (result) IntegerType(width, context); + } + + return result; } FunctionType *FunctionType::get(ArrayRef inputs, ArrayRef results, @@ -236,18 +240,9 @@ FunctionType *FunctionType::get(ArrayRef inputs, ArrayRef results, return *existing.first = result; } - - -VectorType::VectorType(ArrayRef shape, PrimitiveType *elementType, - MLIRContext *context) - : Type(TypeKind::Vector, context, shape.size()), - shapeElements(shape.data()), elementType(elementType) { -} - - VectorType *VectorType::get(ArrayRef shape, Type *elementType) { assert(!shape.empty() && "vector types must have at least one dimension"); - assert(isa(elementType) && + assert((isa(elementType) || isa(elementType)) && "vectors elements must be primitives"); auto *context = elementType->getContext(); @@ -277,22 +272,12 @@ VectorType *VectorType::get(ArrayRef shape, Type *elementType) { TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context) : Type(kind, context), elementType(elementType) { - assert((isa(elementType) || isa(elementType)) && + assert((isa(elementType) || isa(elementType) || + isa(elementType)) && "tensor elements must be primitives or vectors"); assert(isa(this)); } -RankedTensorType::RankedTensorType(ArrayRef shape, Type *elementType, - MLIRContext *context) - : TensorType(TypeKind::RankedTensor, elementType, context), - shapeElements(shape.data()) { - setSubclassData(shape.size()); -} - -UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context) - : TensorType(TypeKind::UnrankedTensor, elementType, context) { -} - RankedTensorType *RankedTensorType::get(ArrayRef shape, Type *elementType) { auto *context = elementType->getContext(); diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index a5578b875c79..e16c6eb0ccf2 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -21,18 +21,50 @@ #include "mlir/Support/STLExtras.h" using namespace mlir; +PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context) + : Type(kind, context) { +} + +IntegerType::IntegerType(unsigned width, MLIRContext *context) + : Type(TypeKind::Integer, context), width(width) { +} + +FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs, + unsigned numResults, MLIRContext *context) + : Type(TypeKind::Function, context, numInputs), + numResults(numResults), inputsAndResults(inputsAndResults) { +} + +VectorType::VectorType(ArrayRef shape, PrimitiveType *elementType, + MLIRContext *context) + : Type(TypeKind::Vector, context, shape.size()), + shapeElements(shape.data()), elementType(elementType) { +} + +RankedTensorType::RankedTensorType(ArrayRef shape, Type *elementType, + MLIRContext *context) + : TensorType(TypeKind::RankedTensor, elementType, context), + shapeElements(shape.data()) { + setSubclassData(shape.size()); +} + +UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context) + : TensorType(TypeKind::UnrankedTensor, elementType, context) { +} + void Type::print(raw_ostream &os) const { switch (getKind()) { - case TypeKind::I1: os << "i1"; return; - case TypeKind::I8: os << "i8"; return; - case TypeKind::I16: os << "i16"; return; - case TypeKind::I32: os << "i32"; return; - case TypeKind::I64: os << "i64"; return; - case TypeKind::Int: os << "int"; return; + case TypeKind::AffineInt: os << "affineint"; return; case TypeKind::BF16: os << "bf16"; return; case TypeKind::F16: os << "f16"; return; case TypeKind::F32: os << "f32"; return; case TypeKind::F64: os << "f64"; return; + + case TypeKind::Integer: { + auto *integer = cast(this); + os << 'i' << integer->getWidth(); + return; + } case TypeKind::Function: { auto *func = cast(this); os << '('; diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp index 17755e0291f9..89432003b2c0 100644 --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -140,6 +140,7 @@ Token Lexer::lexComment() { /// Lex a bare identifier or keyword that starts with a letter. /// /// bare-id ::= letter (letter|digit|[_])* +/// integer-type ::= `i[1-9][0-9]*` /// Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) { // Match the rest of the identifier regex: [0-9a-zA-Z_]* @@ -149,6 +150,15 @@ Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) { // Check to see if this identifier is a keyword. StringRef spelling(tokStart, curPtr-tokStart); + // Check for i123. + if (tokStart[0] == 'i') { + bool allDigits = true; + for (auto c : spelling.drop_front()) + allDigits &= isdigit(c) != 0; + if (allDigits && spelling.size() != 1) + return Token(Token::inttype, spelling); + } + Token::Kind kind = llvm::StringSwitch(spelling) #define TOK_KEYWORD(SPELLING) \ .Case(#SPELLING, Token::kw_##SPELLING) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 692705018aa8..1bfa331256dc 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -121,7 +121,7 @@ private: // as the results of their action. // Type parsing. - PrimitiveType *parsePrimitiveType(); + Type *parsePrimitiveType(); Type *parseElementType(); VectorType *parseVectorType(); ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions); @@ -218,12 +218,11 @@ parseCommaSeparatedList(Token::Kind rightToken, /// Parse the low-level fixed dtypes in the system. /// -/// primitive-type -/// ::= `f16` | `bf16` | `f32` | `f64` // Floating point -/// | `i1` | `i8` | `i16` | `i32` | `i64` // Sized integers -/// | `int` +/// primitive-type ::= `f16` | `bf16` | `f32` | `f64` +/// primitive-type ::= integer-type +/// primitive-type ::= `affineint` /// -PrimitiveType *Parser::parsePrimitiveType() { +Type *Parser::parsePrimitiveType() { switch (curToken.getKind()) { default: return (emitError("expected type"), nullptr); @@ -239,24 +238,16 @@ PrimitiveType *Parser::parsePrimitiveType() { case Token::kw_f64: consumeToken(Token::kw_f64); return Type::getF64(context); - case Token::kw_i1: - consumeToken(Token::kw_i1); - return Type::getI1(context); - case Token::kw_i8: - consumeToken(Token::kw_i8); - return Type::getI8(context); - case Token::kw_i16: - consumeToken(Token::kw_i16); - return Type::getI16(context); - case Token::kw_i32: - consumeToken(Token::kw_i32); - return Type::getI32(context); - case Token::kw_i64: - consumeToken(Token::kw_i64); - return Type::getI64(context); - case Token::kw_int: - consumeToken(Token::kw_int); - return Type::getInt(context); + case Token::kw_affineint: + consumeToken(Token::kw_affineint); + return Type::getAffineInt(context); + case Token::inttype: { + auto width = curToken.getIntTypeBitwidth(); + if (!width.hasValue()) + return (emitError("invalid integer width"), nullptr); + consumeToken(Token::inttype); + return Type::getInt(width.getValue(), context); + } } } @@ -419,11 +410,9 @@ Type *Parser::parseMemRefType() { return (emitError("expected '>' in memref type"), nullptr); // FIXME: Add an IR representation for memref types. - return Type::getI1(context); + return Type::getInt(1, context); } - - /// Parse a function type. /// /// function-type ::= type-list-parens `->` type-list @@ -445,7 +434,6 @@ Type *Parser::parseFunctionType() { return FunctionType::get(arguments, results, context); } - /// Parse an arbitrary type. /// /// type ::= primitive-type diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp index 5563255b4444..e1e4bedf28ff 100644 --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -48,6 +48,18 @@ Optional Token::getUnsignedIntegerValue() const { return result; } +/// For an inttype token, return its bitwidth. +Optional Token::getIntTypeBitwidth() const { + unsigned result = 0; + if (spelling[1] == '0' || + spelling.drop_front().getAsInteger(10, result) || + // Arbitrary but large limit on bitwidth. + result > 4096 || result == 0) + return None; + return result; +} + + /// Given a 'string' token, return its value, including removing the quote /// characters and unescaping the contents of the string. std::string Token::getStringValue() const { diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h index e5e4fc41886e..bc9e8e4a694f 100644 --- a/mlir/lib/Parser/Token.h +++ b/mlir/lib/Parser/Token.h @@ -73,6 +73,9 @@ public: /// return None. Optional getUnsignedIntegerValue() const; + /// For an inttype token, return its bitwidth. + Optional getIntTypeBitwidth() const; + /// Given a 'string' token, return its value, including removing the quote /// characters and unescaping the contents of the string. std::string getStringValue() const; diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 72d769a90031..73a30df11f66 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -58,6 +58,7 @@ TOK_IDENTIFIER(affine_map_identifier) // #foo // Literals TOK_LITERAL(integer) // 42 TOK_LITERAL(string) // "foo" +TOK_LITERAL(inttype) // i421 // Punctuation. TOK_PUNCTUATION(arrow, "->") @@ -84,6 +85,7 @@ TOK_OPERATOR(floordiv, "floordiv") // TODO: More operator tokens // Keywords. These turn "foo" into Token::kw_foo enums. +TOK_KEYWORD(affineint) TOK_KEYWORD(bf16) TOK_KEYWORD(br) TOK_KEYWORD(cfgfunc) @@ -91,12 +93,6 @@ TOK_KEYWORD(extfunc) TOK_KEYWORD(f16) TOK_KEYWORD(f32) TOK_KEYWORD(f64) -TOK_KEYWORD(i1) -TOK_KEYWORD(i16) -TOK_KEYWORD(i32) -TOK_KEYWORD(i64) -TOK_KEYWORD(i8) -TOK_KEYWORD(int) TOK_KEYWORD(memref) TOK_KEYWORD(mlfunc) TOK_KEYWORD(return) diff --git a/mlir/test/IR/parser-errors.mlir b/mlir/test/IR/parser-errors.mlir index 408fe139965a..fb4c2cc2c43f 100644 --- a/mlir/test/IR/parser-errors.mlir +++ b/mlir/test/IR/parser-errors.mlir @@ -6,7 +6,7 @@ ; Check different error cases. ; ----- -extfunc @illegaltype(i42) ; expected-error {{expected type}} +extfunc @illegaltype(i) ; expected-error {{expected type}} ; ----- @@ -19,7 +19,7 @@ cfgfunc @bar() ; expected-error {{expected '{' in CFG function}} ; ----- -extfunc missingsigil() -> (i1, int, f32) ; expected-error {{expected a function identifier like}} +extfunc missingsigil() -> (i1, affineint, f32) ; expected-error {{expected a function identifier like}} ; ----- @@ -75,3 +75,9 @@ bb40: ""() ; expected-error {{empty operation name is invalid}} return } + +; ----- + +extfunc @illegaltype(i0) ; expected-error {{invalid integer width}} + + diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index d69192dc9bbd..1b9d9cbff1da 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -10,25 +10,28 @@ extfunc @foo(i32, i64) -> f32 ; CHECK: extfunc @bar() extfunc @bar() -> () -; CHECK: extfunc @baz() -> (i1, int, f32) -extfunc @baz() -> (i1, int, f32) +; CHECK: extfunc @baz() -> (i1, affineint, f32) +extfunc @baz() -> (i1, affineint, f32) ; CHECK: extfunc @missingReturn() extfunc @missingReturn() +; CHECK: extfunc @int_types(i1, i2, i4, i7, i87) -> (i1, affineint, i19) +extfunc @int_types(i1, i2, i4, i7, i87) -> (i1, affineint, i19) + ; CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>) extfunc @vectors(vector<1 x f32>, vector<2x4xf32>) -; CHECK: extfunc @tensors(tensor, tensor>, tensor<1x?x4x?x?xint>, tensor) +; CHECK: extfunc @tensors(tensor, tensor>, tensor<1x?x4x?x?xaffineint>, tensor) extfunc @tensors(tensor, tensor>, - tensor<1x?x4x?x?xint>, tensor) + tensor<1x?x4x?x?xaffineint>, tensor) ; CHECK: extfunc @memrefs(i1, i1) -extfunc @memrefs(memref<1x?x4x?x?xint>, memref) +extfunc @memrefs(memref<1x?x4x?x?xaffineint>, memref) ; CHECK: extfunc @functions((i1, i1) -> (), () -> ()) -extfunc @functions((memref<1x?x4x?x?xint>, memref) -> (), ()->()) +extfunc @functions((memref<1x?x4x?x?xaffineint>, memref) -> (), ()->()) ; CHECK-LABEL: cfgfunc @simpleCFG() {