forked from OSchip/llvm-project
Support TF Variant type in the tf/mlir roundtrip pass.
PiperOrigin-RevId: 213748573
This commit is contained in:
parent
4bc5dc9602
commit
5f69643cbf
|
@ -77,6 +77,7 @@ public:
|
||||||
OtherType *getTFControlType();
|
OtherType *getTFControlType();
|
||||||
OtherType *getTFStringType();
|
OtherType *getTFStringType();
|
||||||
OtherType *getTFResourceType();
|
OtherType *getTFResourceType();
|
||||||
|
OtherType *getTFVariantType();
|
||||||
IntegerType *getIntegerType(unsigned width);
|
IntegerType *getIntegerType(unsigned width);
|
||||||
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
|
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
|
||||||
ArrayRef<Type *> results);
|
ArrayRef<Type *> results);
|
||||||
|
|
|
@ -41,6 +41,7 @@ public:
|
||||||
// TensorFlow types.
|
// TensorFlow types.
|
||||||
TFControl,
|
TFControl,
|
||||||
TFResource,
|
TFResource,
|
||||||
|
TFVariant,
|
||||||
TFString,
|
TFString,
|
||||||
|
|
||||||
/// These are marker for the first and last 'other' type.
|
/// These are marker for the first and last 'other' type.
|
||||||
|
@ -75,6 +76,7 @@ public:
|
||||||
bool isAffineInt() const { return getKind() == Kind::AffineInt; }
|
bool isAffineInt() const { return getKind() == Kind::AffineInt; }
|
||||||
bool isTFControl() const { return getKind() == Kind::TFControl; }
|
bool isTFControl() const { return getKind() == Kind::TFControl; }
|
||||||
bool isTFResource() const { return getKind() == Kind::TFResource; }
|
bool isTFResource() const { return getKind() == Kind::TFResource; }
|
||||||
|
bool isTFVariant() const { return getKind() == Kind::TFVariant; }
|
||||||
bool isTFString() const { return getKind() == Kind::TFString; }
|
bool isTFString() const { return getKind() == Kind::TFString; }
|
||||||
bool isBF16() const { return getKind() == Kind::BF16; }
|
bool isBF16() const { return getKind() == Kind::BF16; }
|
||||||
bool isF16() const { return getKind() == Kind::F16; }
|
bool isF16() const { return getKind() == Kind::F16; }
|
||||||
|
@ -94,6 +96,7 @@ public:
|
||||||
static OtherType *getTFControl(MLIRContext *ctx);
|
static OtherType *getTFControl(MLIRContext *ctx);
|
||||||
static OtherType *getTFString(MLIRContext *ctx);
|
static OtherType *getTFString(MLIRContext *ctx);
|
||||||
static OtherType *getTFResource(MLIRContext *ctx);
|
static OtherType *getTFResource(MLIRContext *ctx);
|
||||||
|
static OtherType *getTFVariant(MLIRContext *ctx);
|
||||||
|
|
||||||
/// Print the current type.
|
/// Print the current type.
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os) const;
|
||||||
|
@ -224,6 +227,9 @@ inline OtherType *Type::getTFResource(MLIRContext *ctx) {
|
||||||
inline OtherType *Type::getTFString(MLIRContext *ctx) {
|
inline OtherType *Type::getTFString(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::TFString, ctx);
|
return OtherType::get(Kind::TFString, ctx);
|
||||||
}
|
}
|
||||||
|
inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
|
||||||
|
return OtherType::get(Kind::TFVariant, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
/// Function types map from a list of inputs to a list of results.
|
/// Function types map from a list of inputs to a list of results.
|
||||||
class FunctionType : public Type {
|
class FunctionType : public Type {
|
||||||
|
@ -432,6 +438,12 @@ private:
|
||||||
~MemRefType() = delete;
|
~MemRefType() = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Return true if the specified element type is ok in a tensor.
|
||||||
|
static bool isValidTensorElementType(Type *type) {
|
||||||
|
return isa<FloatType>(type) || isa<VectorType>(type) ||
|
||||||
|
isa<IntegerType>(type) || isa<OtherType>(type);
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_IR_TYPES_H
|
#endif // MLIR_IR_TYPES_H
|
||||||
|
|
|
@ -480,6 +480,9 @@ void ModulePrinter::printType(const Type *type) {
|
||||||
case Type::Kind::TFResource:
|
case Type::Kind::TFResource:
|
||||||
os << "tf_resource";
|
os << "tf_resource";
|
||||||
return;
|
return;
|
||||||
|
case Type::Kind::TFVariant:
|
||||||
|
os << "tf_variant";
|
||||||
|
return;
|
||||||
case Type::Kind::TFString:
|
case Type::Kind::TFString:
|
||||||
os << "tf_string";
|
os << "tf_string";
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -66,6 +66,8 @@ OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
|
||||||
|
|
||||||
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
|
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
|
||||||
|
|
||||||
|
OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
|
||||||
|
|
||||||
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
|
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
|
||||||
|
|
||||||
IntegerType *Builder::getIntegerType(unsigned width) {
|
IntegerType *Builder::getIntegerType(unsigned width) {
|
||||||
|
|
|
@ -60,16 +60,9 @@ VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
|
||||||
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
|
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
|
||||||
shapeElements(shape.data()) {}
|
shapeElements(shape.data()) {}
|
||||||
|
|
||||||
/// Return true if the specified element type is ok in a tensor.
|
|
||||||
static bool isValidTensorElementType(Type *type, MLIRContext *context) {
|
|
||||||
return isa<FloatType>(type) || isa<VectorType>(type) ||
|
|
||||||
isa<IntegerType>(type) || type == Type::getTFString(context) ||
|
|
||||||
type == Type::getTFResource(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
|
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
|
||||||
: VectorOrTensorType(kind, context, elementType) {
|
: VectorOrTensorType(kind, context, elementType) {
|
||||||
assert(isValidTensorElementType(elementType, context));
|
assert(isValidTensorElementType(elementType));
|
||||||
}
|
}
|
||||||
|
|
||||||
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||||
|
|
|
@ -336,6 +336,9 @@ Type *Parser::parseType() {
|
||||||
case Token::kw_tf_resource:
|
case Token::kw_tf_resource:
|
||||||
consumeToken(Token::kw_tf_resource);
|
consumeToken(Token::kw_tf_resource);
|
||||||
return builder.getTFResourceType();
|
return builder.getTFResourceType();
|
||||||
|
case Token::kw_tf_variant:
|
||||||
|
consumeToken(Token::kw_tf_variant);
|
||||||
|
return builder.getTFVariantType();
|
||||||
case Token::kw_tf_string:
|
case Token::kw_tf_string:
|
||||||
consumeToken(Token::kw_tf_string);
|
consumeToken(Token::kw_tf_string);
|
||||||
return builder.getTFStringType();
|
return builder.getTFStringType();
|
||||||
|
@ -468,8 +471,7 @@ Type *Parser::parseTensorType() {
|
||||||
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
|
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
|
if (!isValidTensorElementType(elementType))
|
||||||
!isa<VectorType>(elementType))
|
|
||||||
return (emitError(typeLoc, "invalid tensor element type"), nullptr);
|
return (emitError(typeLoc, "invalid tensor element type"), nullptr);
|
||||||
|
|
||||||
if (isUnranked)
|
if (isUnranked)
|
||||||
|
|
|
@ -114,6 +114,7 @@ TOK_KEYWORD(step)
|
||||||
TOK_KEYWORD(tensor)
|
TOK_KEYWORD(tensor)
|
||||||
TOK_KEYWORD(tf_control)
|
TOK_KEYWORD(tf_control)
|
||||||
TOK_KEYWORD(tf_resource)
|
TOK_KEYWORD(tf_resource)
|
||||||
|
TOK_KEYWORD(tf_variant)
|
||||||
TOK_KEYWORD(tf_string)
|
TOK_KEYWORD(tf_string)
|
||||||
TOK_KEYWORD(to)
|
TOK_KEYWORD(to)
|
||||||
TOK_KEYWORD(true)
|
TOK_KEYWORD(true)
|
||||||
|
|
Loading…
Reference in New Issue