forked from OSchip/llvm-project
[TF] Improve verification for integer and floating-point tensor types
TensorFlow does not allow integers of random bitwidths. It only accepts 8-, 16-, 32-, and 64-bit integer types. Similarly for floating point types, only half, single, double, and bfloat16 types. PiperOrigin-RevId: 237483913
This commit is contained in:
parent
2c78469a93
commit
999a0c8736
|
@ -126,6 +126,18 @@ class BuildableType<code builder> {
|
|||
code builderCall = builder;
|
||||
}
|
||||
|
||||
// Any type from the given list
|
||||
class AnyTypeOf<list<Type> allowedTypes, string description> : Type<
|
||||
// Satisfy any of the allowed type's condition
|
||||
AnyOf<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
|
||||
!if(!eq(description, ""),
|
||||
// Join all allowed types' descriptions with " or " as the description
|
||||
// if not provided during template specialization
|
||||
!foldl(/*init*/"", /*list*/allowedTypes, prev, cur,
|
||||
prev # " or " # cur.description),
|
||||
// Otherwise use the provided one
|
||||
description)>;
|
||||
|
||||
// Integer types.
|
||||
class IntegerBase<CPred pred, string descr> : Type<pred, descr>;
|
||||
|
||||
|
@ -166,6 +178,9 @@ def F16 : F<16>;
|
|||
def F32 : F<32>;
|
||||
def F64 : F<64>;
|
||||
|
||||
def BF16 : Type<CPred<"{0}.isBF16()">, "bfloat16 type">,
|
||||
BuildableType<"getBF16Type()">;
|
||||
|
||||
// A container type is a type that has another type embedded within it.
|
||||
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
|
||||
string descr> :
|
||||
|
|
Loading…
Reference in New Issue