forked from OSchip/llvm-project
Handle the TF resource data type in the TF/XLA roundtrip pass.
PiperOrigin-RevId: 213650346
This commit is contained in:
parent
14ca1be9a7
commit
4bc5dc9602
|
@ -76,6 +76,7 @@ public:
|
|||
OtherType *getAffineIntType();
|
||||
OtherType *getTFControlType();
|
||||
OtherType *getTFStringType();
|
||||
OtherType *getTFResourceType();
|
||||
IntegerType *getIntegerType(unsigned width);
|
||||
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
|
||||
ArrayRef<Type *> results);
|
||||
|
|
|
@ -40,6 +40,7 @@ public:
|
|||
|
||||
// TensorFlow types.
|
||||
TFControl,
|
||||
TFResource,
|
||||
TFString,
|
||||
|
||||
/// These are marker for the first and last 'other' type.
|
||||
|
@ -64,9 +65,7 @@ public:
|
|||
};
|
||||
|
||||
/// Return the classification for this type.
|
||||
Kind getKind() const {
|
||||
return kind;
|
||||
}
|
||||
Kind getKind() const { return kind; }
|
||||
|
||||
/// Return the LLVMContext in which this type was uniqued.
|
||||
MLIRContext *getContext() const { return context; }
|
||||
|
@ -75,6 +74,7 @@ public:
|
|||
// derived types should use isa/dyn_cast.
|
||||
bool isAffineInt() const { return getKind() == Kind::AffineInt; }
|
||||
bool isTFControl() const { return getKind() == Kind::TFControl; }
|
||||
bool isTFResource() const { return getKind() == Kind::TFResource; }
|
||||
bool isTFString() const { return getKind() == Kind::TFString; }
|
||||
bool isBF16() const { return getKind() == Kind::BF16; }
|
||||
bool isF16() const { return getKind() == Kind::F16; }
|
||||
|
@ -93,6 +93,7 @@ public:
|
|||
static OtherType *getAffineInt(MLIRContext *ctx);
|
||||
static OtherType *getTFControl(MLIRContext *ctx);
|
||||
static OtherType *getTFString(MLIRContext *ctx);
|
||||
static OtherType *getTFResource(MLIRContext *ctx);
|
||||
|
||||
/// Print the current type.
|
||||
void print(raw_ostream &os) const;
|
||||
|
@ -217,6 +218,9 @@ inline OtherType *Type::getAffineInt(MLIRContext *ctx) {
|
|||
inline OtherType *Type::getTFControl(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFControl, ctx);
|
||||
}
|
||||
inline OtherType *Type::getTFResource(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFResource, ctx);
|
||||
}
|
||||
inline OtherType *Type::getTFString(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFString, ctx);
|
||||
}
|
||||
|
|
|
@ -477,6 +477,9 @@ void ModulePrinter::printType(const Type *type) {
|
|||
case Type::Kind::TFControl:
|
||||
os << "tf_control";
|
||||
return;
|
||||
case Type::Kind::TFResource:
|
||||
os << "tf_resource";
|
||||
return;
|
||||
case Type::Kind::TFString:
|
||||
os << "tf_string";
|
||||
return;
|
||||
|
|
|
@ -64,6 +64,8 @@ OtherType *Builder::getAffineIntType() { return Type::getAffineInt(context); }
|
|||
|
||||
OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
|
||||
|
||||
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
|
||||
|
||||
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
|
||||
|
||||
IntegerType *Builder::getIntegerType(unsigned width) {
|
||||
|
|
|
@ -63,7 +63,8 @@ VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
|
|||
/// Return true if the specified element type is ok in a tensor.
|
||||
static bool isValidTensorElementType(Type *type, MLIRContext *context) {
|
||||
return isa<FloatType>(type) || isa<VectorType>(type) ||
|
||||
isa<IntegerType>(type) || type == Type::getTFString(context);
|
||||
isa<IntegerType>(type) || type == Type::getTFString(context) ||
|
||||
type == Type::getTFResource(context);
|
||||
}
|
||||
|
||||
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
|
||||
|
|
|
@ -333,6 +333,9 @@ Type *Parser::parseType() {
|
|||
case Token::kw_tf_control:
|
||||
consumeToken(Token::kw_tf_control);
|
||||
return builder.getTFControlType();
|
||||
case Token::kw_tf_resource:
|
||||
consumeToken(Token::kw_tf_resource);
|
||||
return builder.getTFResourceType();
|
||||
case Token::kw_tf_string:
|
||||
consumeToken(Token::kw_tf_string);
|
||||
return builder.getTFStringType();
|
||||
|
|
|
@ -113,6 +113,7 @@ TOK_KEYWORD(size)
|
|||
TOK_KEYWORD(step)
|
||||
TOK_KEYWORD(tensor)
|
||||
TOK_KEYWORD(tf_control)
|
||||
TOK_KEYWORD(tf_resource)
|
||||
TOK_KEYWORD(tf_string)
|
||||
TOK_KEYWORD(to)
|
||||
TOK_KEYWORD(true)
|
||||
|
|
Loading…
Reference in New Issue