Support TF Variant type in the tf/mlir roundtrip pass.

PiperOrigin-RevId: 213748573
This commit is contained in:
Feng Liu 2018-09-19 21:15:43 -07:00 committed by jpienaar
parent 4bc5dc9602
commit 5f69643cbf
7 changed files with 24 additions and 10 deletions

View File

@ -77,6 +77,7 @@ public:
OtherType *getTFControlType(); OtherType *getTFControlType();
OtherType *getTFStringType(); OtherType *getTFStringType();
OtherType *getTFResourceType(); OtherType *getTFResourceType();
OtherType *getTFVariantType();
IntegerType *getIntegerType(unsigned width); IntegerType *getIntegerType(unsigned width);
FunctionType *getFunctionType(ArrayRef<Type *> inputs, FunctionType *getFunctionType(ArrayRef<Type *> inputs,
ArrayRef<Type *> results); ArrayRef<Type *> results);

View File

@ -41,6 +41,7 @@ public:
// TensorFlow types. // TensorFlow types.
TFControl, TFControl,
TFResource, TFResource,
TFVariant,
TFString, TFString,
/// These are marker for the first and last 'other' type. /// These are marker for the first and last 'other' type.
@ -75,6 +76,7 @@ public:
bool isAffineInt() const { return getKind() == Kind::AffineInt; } bool isAffineInt() const { return getKind() == Kind::AffineInt; }
bool isTFControl() const { return getKind() == Kind::TFControl; } bool isTFControl() const { return getKind() == Kind::TFControl; }
bool isTFResource() const { return getKind() == Kind::TFResource; } bool isTFResource() const { return getKind() == Kind::TFResource; }
bool isTFVariant() const { return getKind() == Kind::TFVariant; }
bool isTFString() const { return getKind() == Kind::TFString; } bool isTFString() const { return getKind() == Kind::TFString; }
bool isBF16() const { return getKind() == Kind::BF16; } bool isBF16() const { return getKind() == Kind::BF16; }
bool isF16() const { return getKind() == Kind::F16; } bool isF16() const { return getKind() == Kind::F16; }
@ -94,6 +96,7 @@ public:
static OtherType *getTFControl(MLIRContext *ctx); static OtherType *getTFControl(MLIRContext *ctx);
static OtherType *getTFString(MLIRContext *ctx); static OtherType *getTFString(MLIRContext *ctx);
static OtherType *getTFResource(MLIRContext *ctx); static OtherType *getTFResource(MLIRContext *ctx);
static OtherType *getTFVariant(MLIRContext *ctx);
/// Print the current type. /// Print the current type.
void print(raw_ostream &os) const; void print(raw_ostream &os) const;
@ -224,6 +227,9 @@ inline OtherType *Type::getTFResource(MLIRContext *ctx) {
inline OtherType *Type::getTFString(MLIRContext *ctx) { inline OtherType *Type::getTFString(MLIRContext *ctx) {
return OtherType::get(Kind::TFString, ctx); return OtherType::get(Kind::TFString, ctx);
} }
inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
return OtherType::get(Kind::TFVariant, ctx);
}
/// Function types map from a list of inputs to a list of results. /// Function types map from a list of inputs to a list of results.
class FunctionType : public Type { class FunctionType : public Type {
@ -432,6 +438,12 @@ private:
~MemRefType() = delete; ~MemRefType() = delete;
}; };
/// Return true if the specified element type is ok in a tensor.
static bool isValidTensorElementType(Type *type) {
return isa<FloatType>(type) || isa<VectorType>(type) ||
isa<IntegerType>(type) || isa<OtherType>(type);
}
} // end namespace mlir } // end namespace mlir
#endif // MLIR_IR_TYPES_H #endif // MLIR_IR_TYPES_H

View File

@ -480,6 +480,9 @@ void ModulePrinter::printType(const Type *type) {
case Type::Kind::TFResource: case Type::Kind::TFResource:
os << "tf_resource"; os << "tf_resource";
return; return;
case Type::Kind::TFVariant:
os << "tf_variant";
return;
case Type::Kind::TFString: case Type::Kind::TFString:
os << "tf_string"; os << "tf_string";
return; return;

View File

@ -66,6 +66,8 @@ OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); } OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
OtherType *Builder::getTFStringType() { return Type::getTFString(context); } OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
IntegerType *Builder::getIntegerType(unsigned width) { IntegerType *Builder::getIntegerType(unsigned width) {

View File

@ -60,16 +60,9 @@ VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()), : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
shapeElements(shape.data()) {} shapeElements(shape.data()) {}
/// Return true if the specified element type is ok in a tensor.
static bool isValidTensorElementType(Type *type, MLIRContext *context) {
return isa<FloatType>(type) || isa<VectorType>(type) ||
isa<IntegerType>(type) || type == Type::getTFString(context) ||
type == Type::getTFResource(context);
}
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context) TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
: VectorOrTensorType(kind, context, elementType) { : VectorOrTensorType(kind, context, elementType) {
assert(isValidTensorElementType(elementType, context)); assert(isValidTensorElementType(elementType));
} }
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType, RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,

View File

@ -336,6 +336,9 @@ Type *Parser::parseType() {
case Token::kw_tf_resource: case Token::kw_tf_resource:
consumeToken(Token::kw_tf_resource); consumeToken(Token::kw_tf_resource);
return builder.getTFResourceType(); return builder.getTFResourceType();
case Token::kw_tf_variant:
consumeToken(Token::kw_tf_variant);
return builder.getTFVariantType();
case Token::kw_tf_string: case Token::kw_tf_string:
consumeToken(Token::kw_tf_string); consumeToken(Token::kw_tf_string);
return builder.getTFStringType(); return builder.getTFStringType();
@ -468,8 +471,7 @@ Type *Parser::parseTensorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
return nullptr; return nullptr;
if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) && if (!isValidTensorElementType(elementType))
!isa<VectorType>(elementType))
return (emitError(typeLoc, "invalid tensor element type"), nullptr); return (emitError(typeLoc, "invalid tensor element type"), nullptr);
if (isUnranked) if (isUnranked)

View File

@ -114,6 +114,7 @@ TOK_KEYWORD(step)
TOK_KEYWORD(tensor) TOK_KEYWORD(tensor)
TOK_KEYWORD(tf_control) TOK_KEYWORD(tf_control)
TOK_KEYWORD(tf_resource) TOK_KEYWORD(tf_resource)
TOK_KEYWORD(tf_variant)
TOK_KEYWORD(tf_string) TOK_KEYWORD(tf_string)
TOK_KEYWORD(to) TOK_KEYWORD(to)
TOK_KEYWORD(true) TOK_KEYWORD(true)