Handle the TF resource data type in the TF/XLA roundtrip pass.

PiperOrigin-RevId: 213650346
This commit is contained in:
Feng Liu 2018-09-19 10:28:46 -07:00 committed by jpienaar
parent 14ca1be9a7
commit 4bc5dc9602
7 changed files with 19 additions and 4 deletions

View File

@ -76,6 +76,7 @@ public:
OtherType *getAffineIntType();
OtherType *getTFControlType();
OtherType *getTFStringType();
OtherType *getTFResourceType();
IntegerType *getIntegerType(unsigned width);
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
ArrayRef<Type *> results);

View File

@ -40,6 +40,7 @@ public:
// TensorFlow types.
TFControl,
TFResource,
TFString,
/// These are marker for the first and last 'other' type.
@ -64,9 +65,7 @@ public:
};
/// Return the classification for this type.
Kind getKind() const {
return kind;
}
Kind getKind() const { return kind; }
/// Return the LLVMContext in which this type was uniqued.
MLIRContext *getContext() const { return context; }
@ -75,6 +74,7 @@ public:
// derived types should use isa/dyn_cast.
bool isAffineInt() const { return getKind() == Kind::AffineInt; }
bool isTFControl() const { return getKind() == Kind::TFControl; }
bool isTFResource() const { return getKind() == Kind::TFResource; }
bool isTFString() const { return getKind() == Kind::TFString; }
bool isBF16() const { return getKind() == Kind::BF16; }
bool isF16() const { return getKind() == Kind::F16; }
@ -93,6 +93,7 @@ public:
static OtherType *getAffineInt(MLIRContext *ctx);
static OtherType *getTFControl(MLIRContext *ctx);
static OtherType *getTFString(MLIRContext *ctx);
static OtherType *getTFResource(MLIRContext *ctx);
/// Print the current type.
void print(raw_ostream &os) const;
@ -217,6 +218,9 @@ inline OtherType *Type::getAffineInt(MLIRContext *ctx) {
inline OtherType *Type::getTFControl(MLIRContext *ctx) {
return OtherType::get(Kind::TFControl, ctx);
}
inline OtherType *Type::getTFResource(MLIRContext *ctx) {
return OtherType::get(Kind::TFResource, ctx);
}
inline OtherType *Type::getTFString(MLIRContext *ctx) {
return OtherType::get(Kind::TFString, ctx);
}

View File

@ -477,6 +477,9 @@ void ModulePrinter::printType(const Type *type) {
case Type::Kind::TFControl:
os << "tf_control";
return;
case Type::Kind::TFResource:
os << "tf_resource";
return;
case Type::Kind::TFString:
os << "tf_string";
return;

View File

@ -64,6 +64,8 @@ OtherType *Builder::getAffineIntType() { return Type::getAffineInt(context); }
OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
IntegerType *Builder::getIntegerType(unsigned width) {

View File

@ -63,7 +63,8 @@ VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
/// Return true if the specified element type is ok in a tensor.
static bool isValidTensorElementType(Type *type, MLIRContext *context) {
return isa<FloatType>(type) || isa<VectorType>(type) ||
isa<IntegerType>(type) || type == Type::getTFString(context);
isa<IntegerType>(type) || type == Type::getTFString(context) ||
type == Type::getTFResource(context);
}
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)

View File

@ -333,6 +333,9 @@ Type *Parser::parseType() {
case Token::kw_tf_control:
consumeToken(Token::kw_tf_control);
return builder.getTFControlType();
case Token::kw_tf_resource:
consumeToken(Token::kw_tf_resource);
return builder.getTFResourceType();
case Token::kw_tf_string:
consumeToken(Token::kw_tf_string);
return builder.getTFStringType();

View File

@ -113,6 +113,7 @@ TOK_KEYWORD(size)
TOK_KEYWORD(step)
TOK_KEYWORD(tensor)
TOK_KEYWORD(tf_control)
TOK_KEYWORD(tf_resource)
TOK_KEYWORD(tf_string)
TOK_KEYWORD(to)
TOK_KEYWORD(true)