forked from OSchip/llvm-project
Enhance the type system to support arbitrary precision integers, which are
important for low-bitwidth inference cases and hardware synthesis targets. Rename 'int' to 'affineint' to avoid confusion between "the integers" and "the int type". PiperOrigin-RevId: 202751508
This commit is contained in:
parent
fdf7bc4e25
commit
6af866c58d
|
@ -24,18 +24,12 @@
|
|||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class PrimitiveType;
|
||||
class IntegerType;
|
||||
|
||||
/// Integer identifier for all the concrete type kinds.
|
||||
enum class TypeKind {
|
||||
// Integer.
|
||||
I1,
|
||||
I8,
|
||||
I16,
|
||||
I32,
|
||||
I64,
|
||||
|
||||
// Target pointer sized integer.
|
||||
Int,
|
||||
AffineInt,
|
||||
|
||||
// Floating point.
|
||||
BF16,
|
||||
|
@ -48,6 +42,7 @@ enum class TypeKind {
|
|||
LAST_PRIMITIVE_TYPE = F64,
|
||||
|
||||
// Derived types.
|
||||
Integer,
|
||||
Function,
|
||||
Vector,
|
||||
RankedTensor,
|
||||
|
@ -80,12 +75,8 @@ public:
|
|||
void dump() const;
|
||||
|
||||
// Convenience factories.
|
||||
static PrimitiveType *getI1(MLIRContext *ctx);
|
||||
static PrimitiveType *getI8(MLIRContext *ctx);
|
||||
static PrimitiveType *getI16(MLIRContext *ctx);
|
||||
static PrimitiveType *getI32(MLIRContext *ctx);
|
||||
static PrimitiveType *getI64(MLIRContext *ctx);
|
||||
static PrimitiveType *getInt(MLIRContext *ctx);
|
||||
static IntegerType *getInt(unsigned width, MLIRContext *ctx);
|
||||
static PrimitiveType *getAffineInt(MLIRContext *ctx);
|
||||
static PrimitiveType *getBF16(MLIRContext *ctx);
|
||||
static PrimitiveType *getF16(MLIRContext *ctx);
|
||||
static PrimitiveType *getF32(MLIRContext *ctx);
|
||||
|
@ -140,23 +131,9 @@ private:
|
|||
PrimitiveType(TypeKind kind, MLIRContext *context);
|
||||
};
|
||||
|
||||
inline PrimitiveType *Type::getI1(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::I1, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getI8(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::I8, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getI16(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::I16, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getI32(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::I32, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getI64(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::I64, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getInt(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::Int, ctx);
|
||||
|
||||
inline PrimitiveType *Type::getAffineInt(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::AffineInt, ctx);
|
||||
}
|
||||
inline PrimitiveType *Type::getBF16(MLIRContext *ctx) {
|
||||
return PrimitiveType::get(TypeKind::BF16, ctx);
|
||||
|
@ -171,6 +148,30 @@ inline PrimitiveType *Type::getF64(MLIRContext *ctx) {
|
|||
return PrimitiveType::get(TypeKind::F64, ctx);
|
||||
}
|
||||
|
||||
/// Integer types can have arbitrary bitwidth up to a large fixed limit of 4096.
|
||||
class IntegerType : public Type {
|
||||
public:
|
||||
static IntegerType *get(unsigned width, MLIRContext *context);
|
||||
|
||||
/// Return the bitwidth of this integer type.
|
||||
unsigned getWidth() const {
|
||||
return width;
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::Integer;
|
||||
}
|
||||
private:
|
||||
unsigned width;
|
||||
IntegerType(unsigned width, MLIRContext *context);
|
||||
};
|
||||
|
||||
inline IntegerType *Type::getInt(unsigned width, MLIRContext *ctx) {
|
||||
return IntegerType::get(width, ctx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Function types map from a list of inputs to a list of results.
|
||||
class FunctionType : public Type {
|
||||
|
|
|
@ -120,6 +120,9 @@ public:
|
|||
using AffineMapSet = DenseSet<AffineMap *, AffineMapKeyInfo>;
|
||||
AffineMapSet affineMaps;
|
||||
|
||||
/// Integer type uniquing.
|
||||
DenseMap<unsigned, IntegerType*> integers;
|
||||
|
||||
/// Function type uniquing.
|
||||
using FunctionTypeSet = DenseSet<FunctionType*, FunctionTypeKeyInfo>;
|
||||
FunctionTypeSet functions;
|
||||
|
@ -173,15 +176,10 @@ Identifier Identifier::get(StringRef str, const MLIRContext *context) {
|
|||
return Identifier(it->getKeyData());
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
|
||||
: Type(kind, context) {
|
||||
}
|
||||
|
||||
PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
|
||||
assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
|
||||
auto &impl = context->getImpl();
|
||||
|
@ -200,10 +198,16 @@ PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
|
|||
return impl.primitives[(int)kind] = ptr;
|
||||
}
|
||||
|
||||
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
|
||||
unsigned numResults, MLIRContext *context)
|
||||
: Type(TypeKind::Function, context, numInputs),
|
||||
numResults(numResults), inputsAndResults(inputsAndResults) {
|
||||
IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
auto *&result = impl.integers[width];
|
||||
if (!result) {
|
||||
result = impl.allocator.Allocate<IntegerType>();
|
||||
new (result) IntegerType(width, context);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
FunctionType *FunctionType::get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
|
||||
|
@ -236,18 +240,9 @@ FunctionType *FunctionType::get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
|
|||
return *existing.first = result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
|
||||
MLIRContext *context)
|
||||
: Type(TypeKind::Vector, context, shape.size()),
|
||||
shapeElements(shape.data()), elementType(elementType) {
|
||||
}
|
||||
|
||||
|
||||
VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
|
||||
assert(!shape.empty() && "vector types must have at least one dimension");
|
||||
assert(isa<PrimitiveType>(elementType) &&
|
||||
assert((isa<PrimitiveType>(elementType) || isa<IntegerType>(elementType)) &&
|
||||
"vectors elements must be primitives");
|
||||
|
||||
auto *context = elementType->getContext();
|
||||
|
@ -277,22 +272,12 @@ VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
|
|||
|
||||
TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
|
||||
: Type(kind, context), elementType(elementType) {
|
||||
assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
|
||||
assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType) ||
|
||||
isa<IntegerType>(elementType)) &&
|
||||
"tensor elements must be primitives or vectors");
|
||||
assert(isa<TensorType>(this));
|
||||
}
|
||||
|
||||
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context)
|
||||
: TensorType(TypeKind::RankedTensor, elementType, context),
|
||||
shapeElements(shape.data()) {
|
||||
setSubclassData(shape.size());
|
||||
}
|
||||
|
||||
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
|
||||
: TensorType(TypeKind::UnrankedTensor, elementType, context) {
|
||||
}
|
||||
|
||||
RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
|
||||
Type *elementType) {
|
||||
auto *context = elementType->getContext();
|
||||
|
|
|
@ -21,18 +21,50 @@
|
|||
#include "mlir/Support/STLExtras.h"
|
||||
using namespace mlir;
|
||||
|
||||
PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
|
||||
: Type(kind, context) {
|
||||
}
|
||||
|
||||
IntegerType::IntegerType(unsigned width, MLIRContext *context)
|
||||
: Type(TypeKind::Integer, context), width(width) {
|
||||
}
|
||||
|
||||
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
|
||||
unsigned numResults, MLIRContext *context)
|
||||
: Type(TypeKind::Function, context, numInputs),
|
||||
numResults(numResults), inputsAndResults(inputsAndResults) {
|
||||
}
|
||||
|
||||
VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
|
||||
MLIRContext *context)
|
||||
: Type(TypeKind::Vector, context, shape.size()),
|
||||
shapeElements(shape.data()), elementType(elementType) {
|
||||
}
|
||||
|
||||
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context)
|
||||
: TensorType(TypeKind::RankedTensor, elementType, context),
|
||||
shapeElements(shape.data()) {
|
||||
setSubclassData(shape.size());
|
||||
}
|
||||
|
||||
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
|
||||
: TensorType(TypeKind::UnrankedTensor, elementType, context) {
|
||||
}
|
||||
|
||||
void Type::print(raw_ostream &os) const {
|
||||
switch (getKind()) {
|
||||
case TypeKind::I1: os << "i1"; return;
|
||||
case TypeKind::I8: os << "i8"; return;
|
||||
case TypeKind::I16: os << "i16"; return;
|
||||
case TypeKind::I32: os << "i32"; return;
|
||||
case TypeKind::I64: os << "i64"; return;
|
||||
case TypeKind::Int: os << "int"; return;
|
||||
case TypeKind::AffineInt: os << "affineint"; return;
|
||||
case TypeKind::BF16: os << "bf16"; return;
|
||||
case TypeKind::F16: os << "f16"; return;
|
||||
case TypeKind::F32: os << "f32"; return;
|
||||
case TypeKind::F64: os << "f64"; return;
|
||||
|
||||
case TypeKind::Integer: {
|
||||
auto *integer = cast<IntegerType>(this);
|
||||
os << 'i' << integer->getWidth();
|
||||
return;
|
||||
}
|
||||
case TypeKind::Function: {
|
||||
auto *func = cast<FunctionType>(this);
|
||||
os << '(';
|
||||
|
|
|
@ -140,6 +140,7 @@ Token Lexer::lexComment() {
|
|||
/// Lex a bare identifier or keyword that starts with a letter.
|
||||
///
|
||||
/// bare-id ::= letter (letter|digit|[_])*
|
||||
/// integer-type ::= `i[1-9][0-9]*`
|
||||
///
|
||||
Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
|
||||
// Match the rest of the identifier regex: [0-9a-zA-Z_]*
|
||||
|
@ -149,6 +150,15 @@ Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
|
|||
// Check to see if this identifier is a keyword.
|
||||
StringRef spelling(tokStart, curPtr-tokStart);
|
||||
|
||||
// Check for i123.
|
||||
if (tokStart[0] == 'i') {
|
||||
bool allDigits = true;
|
||||
for (auto c : spelling.drop_front())
|
||||
allDigits &= isdigit(c) != 0;
|
||||
if (allDigits && spelling.size() != 1)
|
||||
return Token(Token::inttype, spelling);
|
||||
}
|
||||
|
||||
Token::Kind kind = llvm::StringSwitch<Token::Kind>(spelling)
|
||||
#define TOK_KEYWORD(SPELLING) \
|
||||
.Case(#SPELLING, Token::kw_##SPELLING)
|
||||
|
|
|
@ -121,7 +121,7 @@ private:
|
|||
// as the results of their action.
|
||||
|
||||
// Type parsing.
|
||||
PrimitiveType *parsePrimitiveType();
|
||||
Type *parsePrimitiveType();
|
||||
Type *parseElementType();
|
||||
VectorType *parseVectorType();
|
||||
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
|
||||
|
@ -218,12 +218,11 @@ parseCommaSeparatedList(Token::Kind rightToken,
|
|||
|
||||
/// Parse the low-level fixed dtypes in the system.
|
||||
///
|
||||
/// primitive-type
|
||||
/// ::= `f16` | `bf16` | `f32` | `f64` // Floating point
|
||||
/// | `i1` | `i8` | `i16` | `i32` | `i64` // Sized integers
|
||||
/// | `int`
|
||||
/// primitive-type ::= `f16` | `bf16` | `f32` | `f64`
|
||||
/// primitive-type ::= integer-type
|
||||
/// primitive-type ::= `affineint`
|
||||
///
|
||||
PrimitiveType *Parser::parsePrimitiveType() {
|
||||
Type *Parser::parsePrimitiveType() {
|
||||
switch (curToken.getKind()) {
|
||||
default:
|
||||
return (emitError("expected type"), nullptr);
|
||||
|
@ -239,24 +238,16 @@ PrimitiveType *Parser::parsePrimitiveType() {
|
|||
case Token::kw_f64:
|
||||
consumeToken(Token::kw_f64);
|
||||
return Type::getF64(context);
|
||||
case Token::kw_i1:
|
||||
consumeToken(Token::kw_i1);
|
||||
return Type::getI1(context);
|
||||
case Token::kw_i8:
|
||||
consumeToken(Token::kw_i8);
|
||||
return Type::getI8(context);
|
||||
case Token::kw_i16:
|
||||
consumeToken(Token::kw_i16);
|
||||
return Type::getI16(context);
|
||||
case Token::kw_i32:
|
||||
consumeToken(Token::kw_i32);
|
||||
return Type::getI32(context);
|
||||
case Token::kw_i64:
|
||||
consumeToken(Token::kw_i64);
|
||||
return Type::getI64(context);
|
||||
case Token::kw_int:
|
||||
consumeToken(Token::kw_int);
|
||||
return Type::getInt(context);
|
||||
case Token::kw_affineint:
|
||||
consumeToken(Token::kw_affineint);
|
||||
return Type::getAffineInt(context);
|
||||
case Token::inttype: {
|
||||
auto width = curToken.getIntTypeBitwidth();
|
||||
if (!width.hasValue())
|
||||
return (emitError("invalid integer width"), nullptr);
|
||||
consumeToken(Token::inttype);
|
||||
return Type::getInt(width.getValue(), context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -419,11 +410,9 @@ Type *Parser::parseMemRefType() {
|
|||
return (emitError("expected '>' in memref type"), nullptr);
|
||||
|
||||
// FIXME: Add an IR representation for memref types.
|
||||
return Type::getI1(context);
|
||||
return Type::getInt(1, context);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Parse a function type.
|
||||
///
|
||||
/// function-type ::= type-list-parens `->` type-list
|
||||
|
@ -445,7 +434,6 @@ Type *Parser::parseFunctionType() {
|
|||
return FunctionType::get(arguments, results, context);
|
||||
}
|
||||
|
||||
|
||||
/// Parse an arbitrary type.
|
||||
///
|
||||
/// type ::= primitive-type
|
||||
|
|
|
@ -48,6 +48,18 @@ Optional<unsigned> Token::getUnsignedIntegerValue() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
/// For an inttype token, return its bitwidth.
|
||||
Optional<unsigned> Token::getIntTypeBitwidth() const {
|
||||
unsigned result = 0;
|
||||
if (spelling[1] == '0' ||
|
||||
spelling.drop_front().getAsInteger(10, result) ||
|
||||
// Arbitrary but large limit on bitwidth.
|
||||
result > 4096 || result == 0)
|
||||
return None;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/// Given a 'string' token, return its value, including removing the quote
|
||||
/// characters and unescaping the contents of the string.
|
||||
std::string Token::getStringValue() const {
|
||||
|
|
|
@ -73,6 +73,9 @@ public:
|
|||
/// return None.
|
||||
Optional<unsigned> getUnsignedIntegerValue() const;
|
||||
|
||||
/// For an inttype token, return its bitwidth.
|
||||
Optional<unsigned> getIntTypeBitwidth() const;
|
||||
|
||||
/// Given a 'string' token, return its value, including removing the quote
|
||||
/// characters and unescaping the contents of the string.
|
||||
std::string getStringValue() const;
|
||||
|
|
|
@ -58,6 +58,7 @@ TOK_IDENTIFIER(affine_map_identifier) // #foo
|
|||
// Literals
|
||||
TOK_LITERAL(integer) // 42
|
||||
TOK_LITERAL(string) // "foo"
|
||||
TOK_LITERAL(inttype) // i421
|
||||
|
||||
// Punctuation.
|
||||
TOK_PUNCTUATION(arrow, "->")
|
||||
|
@ -84,6 +85,7 @@ TOK_OPERATOR(floordiv, "floordiv")
|
|||
// TODO: More operator tokens
|
||||
|
||||
// Keywords. These turn "foo" into Token::kw_foo enums.
|
||||
TOK_KEYWORD(affineint)
|
||||
TOK_KEYWORD(bf16)
|
||||
TOK_KEYWORD(br)
|
||||
TOK_KEYWORD(cfgfunc)
|
||||
|
@ -91,12 +93,6 @@ TOK_KEYWORD(extfunc)
|
|||
TOK_KEYWORD(f16)
|
||||
TOK_KEYWORD(f32)
|
||||
TOK_KEYWORD(f64)
|
||||
TOK_KEYWORD(i1)
|
||||
TOK_KEYWORD(i16)
|
||||
TOK_KEYWORD(i32)
|
||||
TOK_KEYWORD(i64)
|
||||
TOK_KEYWORD(i8)
|
||||
TOK_KEYWORD(int)
|
||||
TOK_KEYWORD(memref)
|
||||
TOK_KEYWORD(mlfunc)
|
||||
TOK_KEYWORD(return)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
; Check different error cases.
|
||||
; -----
|
||||
|
||||
extfunc @illegaltype(i42) ; expected-error {{expected type}}
|
||||
extfunc @illegaltype(i) ; expected-error {{expected type}}
|
||||
|
||||
; -----
|
||||
|
||||
|
@ -19,7 +19,7 @@ cfgfunc @bar() ; expected-error {{expected '{' in CFG function}}
|
|||
|
||||
; -----
|
||||
|
||||
extfunc missingsigil() -> (i1, int, f32) ; expected-error {{expected a function identifier like}}
|
||||
extfunc missingsigil() -> (i1, affineint, f32) ; expected-error {{expected a function identifier like}}
|
||||
|
||||
|
||||
; -----
|
||||
|
@ -75,3 +75,9 @@ bb40:
|
|||
""() ; expected-error {{empty operation name is invalid}}
|
||||
return
|
||||
}
|
||||
|
||||
; -----
|
||||
|
||||
extfunc @illegaltype(i0) ; expected-error {{invalid integer width}}
|
||||
|
||||
|
||||
|
|
|
@ -10,25 +10,28 @@ extfunc @foo(i32, i64) -> f32
|
|||
; CHECK: extfunc @bar()
|
||||
extfunc @bar() -> ()
|
||||
|
||||
; CHECK: extfunc @baz() -> (i1, int, f32)
|
||||
extfunc @baz() -> (i1, int, f32)
|
||||
; CHECK: extfunc @baz() -> (i1, affineint, f32)
|
||||
extfunc @baz() -> (i1, affineint, f32)
|
||||
|
||||
; CHECK: extfunc @missingReturn()
|
||||
extfunc @missingReturn()
|
||||
|
||||
; CHECK: extfunc @int_types(i1, i2, i4, i7, i87) -> (i1, affineint, i19)
|
||||
extfunc @int_types(i1, i2, i4, i7, i87) -> (i1, affineint, i19)
|
||||
|
||||
|
||||
; CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>)
|
||||
extfunc @vectors(vector<1 x f32>, vector<2x4xf32>)
|
||||
|
||||
; CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xint>, tensor<i8>)
|
||||
; CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xaffineint>, tensor<i8>)
|
||||
extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
|
||||
tensor<1x?x4x?x?xint>, tensor<i8>)
|
||||
tensor<1x?x4x?x?xaffineint>, tensor<i8>)
|
||||
|
||||
; CHECK: extfunc @memrefs(i1, i1)
|
||||
extfunc @memrefs(memref<1x?x4x?x?xint>, memref<i8>)
|
||||
extfunc @memrefs(memref<1x?x4x?x?xaffineint>, memref<i8>)
|
||||
|
||||
; CHECK: extfunc @functions((i1, i1) -> (), () -> ())
|
||||
extfunc @functions((memref<1x?x4x?x?xint>, memref<i8>) -> (), ()->())
|
||||
extfunc @functions((memref<1x?x4x?x?xaffineint>, memref<i8>) -> (), ()->())
|
||||
|
||||
|
||||
; CHECK-LABEL: cfgfunc @simpleCFG() {
|
||||
|
|
Loading…
Reference in New Issue