diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 4fb61573b309..b71808476fef 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -76,6 +76,7 @@ public: OtherType *getAffineIntType(); OtherType *getTFControlType(); OtherType *getTFStringType(); + OtherType *getTFResourceType(); IntegerType *getIntegerType(unsigned width); FunctionType *getFunctionType(ArrayRef inputs, ArrayRef results); diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index c7e4543ec742..04a1e26f7ef5 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -40,6 +40,7 @@ public: // TensorFlow types. TFControl, + TFResource, TFString, /// These are marker for the first and last 'other' type. @@ -64,9 +65,7 @@ public: }; /// Return the classification for this type. - Kind getKind() const { - return kind; - } + Kind getKind() const { return kind; } /// Return the LLVMContext in which this type was uniqued. MLIRContext *getContext() const { return context; } @@ -75,6 +74,7 @@ public: // derived types should use isa/dyn_cast. bool isAffineInt() const { return getKind() == Kind::AffineInt; } bool isTFControl() const { return getKind() == Kind::TFControl; } + bool isTFResource() const { return getKind() == Kind::TFResource; } bool isTFString() const { return getKind() == Kind::TFString; } bool isBF16() const { return getKind() == Kind::BF16; } bool isF16() const { return getKind() == Kind::F16; } @@ -93,6 +93,7 @@ public: static OtherType *getAffineInt(MLIRContext *ctx); static OtherType *getTFControl(MLIRContext *ctx); static OtherType *getTFString(MLIRContext *ctx); + static OtherType *getTFResource(MLIRContext *ctx); /// Print the current type. void print(raw_ostream &os) const; @@ -217,6 +218,9 @@ inline OtherType *Type::getAffineInt(MLIRContext *ctx) { inline OtherType *Type::getTFControl(MLIRContext *ctx) { return OtherType::get(Kind::TFControl, ctx); } +inline OtherType *Type::getTFResource(MLIRContext *ctx) { + return OtherType::get(Kind::TFResource, ctx); +} inline OtherType *Type::getTFString(MLIRContext *ctx) { return OtherType::get(Kind::TFString, ctx); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 5737b8aa00e5..e9104472ed66 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -477,6 +477,9 @@ void ModulePrinter::printType(const Type *type) { case Type::Kind::TFControl: os << "tf_control"; return; + case Type::Kind::TFResource: + os << "tf_resource"; + return; case Type::Kind::TFString: os << "tf_string"; return; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 5ded7e8a6522..bc8f82fff140 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -64,6 +64,8 @@ OtherType *Builder::getAffineIntType() { return Type::getAffineInt(context); } OtherType *Builder::getTFControlType() { return Type::getTFControl(context); } +OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); } + OtherType *Builder::getTFStringType() { return Type::getTFString(context); } IntegerType *Builder::getIntegerType(unsigned width) { diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 02b7d3542821..400ef0037228 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -63,7 +63,8 @@ VectorType::VectorType(ArrayRef shape, Type *elementType, /// Return true if the specified element type is ok in a tensor. static bool isValidTensorElementType(Type *type, MLIRContext *context) { return isa(type) || isa(type) || - isa(type) || type == Type::getTFString(context); + isa(type) || type == Type::getTFString(context) || + type == Type::getTFResource(context); } TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 2302676ac393..e68e86cae0f6 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -333,6 +333,9 @@ Type *Parser::parseType() { case Token::kw_tf_control: consumeToken(Token::kw_tf_control); return builder.getTFControlType(); + case Token::kw_tf_resource: + consumeToken(Token::kw_tf_resource); + return builder.getTFResourceType(); case Token::kw_tf_string: consumeToken(Token::kw_tf_string); return builder.getTFStringType(); diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 3ef732822f49..1a98eed90c3d 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -113,6 +113,7 @@ TOK_KEYWORD(size) TOK_KEYWORD(step) TOK_KEYWORD(tensor) TOK_KEYWORD(tf_control) +TOK_KEYWORD(tf_resource) TOK_KEYWORD(tf_string) TOK_KEYWORD(to) TOK_KEYWORD(true)