forked from OSchip/llvm-project
Supports TF Complex64/Complex128 types in the tf/mlir roundtrip pass.
Alternatively, we can defined a TFComplexType with a width parameter in the mlir, then both types can be converted to the same mlir type with different width (like IntegerType). We chose to use a direct mapping because there are only two TF Complex types. PiperOrigin-RevId: 213856651
This commit is contained in:
parent
aa0309d704
commit
948dea045b
|
@ -78,6 +78,9 @@ public:
|
||||||
OtherType *getTFStringType();
|
OtherType *getTFStringType();
|
||||||
OtherType *getTFResourceType();
|
OtherType *getTFResourceType();
|
||||||
OtherType *getTFVariantType();
|
OtherType *getTFVariantType();
|
||||||
|
OtherType *getTFComplex64Type();
|
||||||
|
OtherType *getTFComplex128Type();
|
||||||
|
|
||||||
IntegerType *getIntegerType(unsigned width);
|
IntegerType *getIntegerType(unsigned width);
|
||||||
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
|
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
|
||||||
ArrayRef<Type *> results);
|
ArrayRef<Type *> results);
|
||||||
|
|
|
@ -42,6 +42,8 @@ public:
|
||||||
TFControl,
|
TFControl,
|
||||||
TFResource,
|
TFResource,
|
||||||
TFVariant,
|
TFVariant,
|
||||||
|
TFComplex64,
|
||||||
|
TFComplex128,
|
||||||
TFString,
|
TFString,
|
||||||
|
|
||||||
/// These are marker for the first and last 'other' type.
|
/// These are marker for the first and last 'other' type.
|
||||||
|
@ -77,6 +79,8 @@ public:
|
||||||
bool isTFControl() const { return getKind() == Kind::TFControl; }
|
bool isTFControl() const { return getKind() == Kind::TFControl; }
|
||||||
bool isTFResource() const { return getKind() == Kind::TFResource; }
|
bool isTFResource() const { return getKind() == Kind::TFResource; }
|
||||||
bool isTFVariant() const { return getKind() == Kind::TFVariant; }
|
bool isTFVariant() const { return getKind() == Kind::TFVariant; }
|
||||||
|
bool isTFComplex64() const { return getKind() == Kind::TFComplex64; }
|
||||||
|
bool isTFComplex128() const { return getKind() == Kind::TFComplex128; }
|
||||||
bool isTFString() const { return getKind() == Kind::TFString; }
|
bool isTFString() const { return getKind() == Kind::TFString; }
|
||||||
bool isBF16() const { return getKind() == Kind::BF16; }
|
bool isBF16() const { return getKind() == Kind::BF16; }
|
||||||
bool isF16() const { return getKind() == Kind::F16; }
|
bool isF16() const { return getKind() == Kind::F16; }
|
||||||
|
@ -97,6 +101,8 @@ public:
|
||||||
static OtherType *getTFString(MLIRContext *ctx);
|
static OtherType *getTFString(MLIRContext *ctx);
|
||||||
static OtherType *getTFResource(MLIRContext *ctx);
|
static OtherType *getTFResource(MLIRContext *ctx);
|
||||||
static OtherType *getTFVariant(MLIRContext *ctx);
|
static OtherType *getTFVariant(MLIRContext *ctx);
|
||||||
|
static OtherType *getTFComplex64(MLIRContext *ctx);
|
||||||
|
static OtherType *getTFComplex128(MLIRContext *ctx);
|
||||||
|
|
||||||
/// Print the current type.
|
/// Print the current type.
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os) const;
|
||||||
|
@ -230,6 +236,12 @@ inline OtherType *Type::getTFString(MLIRContext *ctx) {
|
||||||
inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
|
inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
|
||||||
return OtherType::get(Kind::TFVariant, ctx);
|
return OtherType::get(Kind::TFVariant, ctx);
|
||||||
}
|
}
|
||||||
|
inline OtherType *Type::getTFComplex64(MLIRContext *ctx) {
|
||||||
|
return OtherType::get(Kind::TFComplex64, ctx);
|
||||||
|
}
|
||||||
|
inline OtherType *Type::getTFComplex128(MLIRContext *ctx) {
|
||||||
|
return OtherType::get(Kind::TFComplex128, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
/// Function types map from a list of inputs to a list of results.
|
/// Function types map from a list of inputs to a list of results.
|
||||||
class FunctionType : public Type {
|
class FunctionType : public Type {
|
||||||
|
|
|
@ -483,6 +483,12 @@ void ModulePrinter::printType(const Type *type) {
|
||||||
case Type::Kind::TFVariant:
|
case Type::Kind::TFVariant:
|
||||||
os << "tf_variant";
|
os << "tf_variant";
|
||||||
return;
|
return;
|
||||||
|
case Type::Kind::TFComplex64:
|
||||||
|
os << "tf_complex64";
|
||||||
|
return;
|
||||||
|
case Type::Kind::TFComplex128:
|
||||||
|
os << "tf_complex128";
|
||||||
|
return;
|
||||||
case Type::Kind::TFString:
|
case Type::Kind::TFString:
|
||||||
os << "tf_string";
|
os << "tf_string";
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -68,6 +68,14 @@ OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
|
||||||
|
|
||||||
OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
|
OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
|
||||||
|
|
||||||
|
OtherType *Builder::getTFComplex64Type() {
|
||||||
|
return Type::getTFComplex64(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
OtherType *Builder::getTFComplex128Type() {
|
||||||
|
return Type::getTFComplex128(context);
|
||||||
|
}
|
||||||
|
|
||||||
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
|
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
|
||||||
|
|
||||||
IntegerType *Builder::getIntegerType(unsigned width) {
|
IntegerType *Builder::getIntegerType(unsigned width) {
|
||||||
|
|
|
@ -339,6 +339,12 @@ Type *Parser::parseType() {
|
||||||
case Token::kw_tf_variant:
|
case Token::kw_tf_variant:
|
||||||
consumeToken(Token::kw_tf_variant);
|
consumeToken(Token::kw_tf_variant);
|
||||||
return builder.getTFVariantType();
|
return builder.getTFVariantType();
|
||||||
|
case Token::kw_tf_complex64:
|
||||||
|
consumeToken(Token::kw_tf_complex64);
|
||||||
|
return builder.getTFComplex64Type();
|
||||||
|
case Token::kw_tf_complex128:
|
||||||
|
consumeToken(Token::kw_tf_complex128);
|
||||||
|
return builder.getTFComplex128Type();
|
||||||
case Token::kw_tf_string:
|
case Token::kw_tf_string:
|
||||||
consumeToken(Token::kw_tf_string);
|
consumeToken(Token::kw_tf_string);
|
||||||
return builder.getTFStringType();
|
return builder.getTFStringType();
|
||||||
|
|
|
@ -115,6 +115,8 @@ TOK_KEYWORD(tensor)
|
||||||
TOK_KEYWORD(tf_control)
|
TOK_KEYWORD(tf_control)
|
||||||
TOK_KEYWORD(tf_resource)
|
TOK_KEYWORD(tf_resource)
|
||||||
TOK_KEYWORD(tf_variant)
|
TOK_KEYWORD(tf_variant)
|
||||||
|
TOK_KEYWORD(tf_complex64)
|
||||||
|
TOK_KEYWORD(tf_complex128)
|
||||||
TOK_KEYWORD(tf_string)
|
TOK_KEYWORD(tf_string)
|
||||||
TOK_KEYWORD(to)
|
TOK_KEYWORD(to)
|
||||||
TOK_KEYWORD(true)
|
TOK_KEYWORD(true)
|
||||||
|
|
Loading…
Reference in New Issue