[TF] Improve verification for integer and floating-point tensor types

TensorFlow does not allow integers of random bitwidths. It only accepts 8-,
16-, 32-, and 64-bit integer types. Similarly for floating point types, only
half, single, double, and bfloat16 types.

PiperOrigin-RevId: 237483913
This commit is contained in:
Lei Zhang 2019-03-08 11:12:32 -08:00 committed by jpienaar
parent 2c78469a93
commit 999a0c8736
1 changed files with 15 additions and 0 deletions

View File

@ -126,6 +126,18 @@ class BuildableType<code builder> {
code builderCall = builder;
}
// Any type from the given list
class AnyTypeOf<list<Type> allowedTypes, string description> : Type<
// Satisfy any of the allowed type's condition
AnyOf<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
!if(!eq(description, ""),
// Join all allowed types' descriptions with " or " as the description
// if not provided during template specialization
!foldl(/*init*/"", /*list*/allowedTypes, prev, cur,
prev # " or " # cur.description),
// Otherwise use the provided one
description)>;
// Integer types.
class IntegerBase<CPred pred, string descr> : Type<pred, descr>;
@ -166,6 +178,9 @@ def F16 : F<16>;
def F32 : F<32>;
def F64 : F<64>;
def BF16 : Type<CPred<"{0}.isBF16()">, "bfloat16 type">,
BuildableType<"getBF16Type()">;
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr> :