forked from OSchip/llvm-project
Simplify container type definitions
The passed element type description is usually unnecessary, and it's just as valid to want to pass a description for the entire container. In either case there's an alternative (Separate element type def or a TypeAlias) and we don't need to pollute the main API. To allow for this, I cleaned up the TF op definitions and added some additional utilities. -- PiperOrigin-RevId: 249340979
This commit is contained in:
parent
3902cef954
commit
7a4869e316
|
@ -34,9 +34,9 @@ class StrJoin<list<string> strings, string sep = ", "> {
|
|||
!foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur));
|
||||
}
|
||||
|
||||
// Concatenates a list of integers into a string separated with comma.
|
||||
class Stringify<list<int> integers> :
|
||||
StrJoin<!foreach(i, integers, !cast<string>(i))>;
|
||||
// Concatenates a list of integers into a string with a separator (default ", ")
|
||||
class StrJoinInt<list<int> integers, string sep = ", "> :
|
||||
StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Predicate definitions
|
||||
|
@ -241,6 +241,10 @@ class Dialect {
|
|||
class Type<Pred condition, string descr = ""> :
|
||||
TypeConstraint<condition, descr>;
|
||||
|
||||
// Allows providing an alternative name and description to an existing type def.
|
||||
class TypeAlias<Type t, string description = t.description> :
|
||||
Type<t.predicate, description>;
|
||||
|
||||
// A variadic type constraint. It expands to zero or more of the base type. This
|
||||
// class is used for supporting variadic operands/results. An op can declare no
|
||||
// more than one variadic operand/result, and that operand/result must be the
|
||||
|
@ -291,6 +295,11 @@ class I<int width>
|
|||
BuildableType<"getIntegerType(" # width # ")"> {
|
||||
int bitwidth = width;
|
||||
}
|
||||
|
||||
class IntOfWidths<list<int> widths> :
|
||||
AnyTypeOf<!foreach(w, widths, I<w>),
|
||||
StrJoinInt<widths, "/">.result # "-bit integer">;
|
||||
|
||||
def I1 : I<1>;
|
||||
def I8 : I<8>;
|
||||
def I16 : I<16>;
|
||||
|
@ -310,6 +319,10 @@ class F<int width>
|
|||
int bitwidth = width;
|
||||
}
|
||||
|
||||
class FloatOfWidths<list<int> widths> :
|
||||
AnyTypeOf<!foreach(w, widths, F<w>),
|
||||
StrJoinInt<widths, "/">.result # "-bit float">;
|
||||
|
||||
def F16 : F<16>;
|
||||
def F32 : F<32>;
|
||||
def F64 : F<64>;
|
||||
|
@ -338,24 +351,22 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
|
|||
code getElementTypeCall = elementTypeCall;
|
||||
}
|
||||
|
||||
class ShapedContainerType<Type etype, Pred containerPred, string descr> :
|
||||
ContainerType<etype, containerPred,
|
||||
class ShapedContainerType<list<Type> allowedTypes, Pred containerPred, string descr> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>, containerPred,
|
||||
"$_self.cast<ShapedType>().getElementType()", descr>;
|
||||
|
||||
// Vector types.
|
||||
|
||||
class VectorOf<list<Type> allowedTypes, string elementDescription = ""> :
|
||||
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
|
||||
IsVectorTypePred, "vector">;
|
||||
class VectorOf<list<Type> allowedTypes> :
|
||||
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
|
||||
|
||||
def AnyVector : VectorOf<[AnyType]>;
|
||||
|
||||
// Tensor types.
|
||||
|
||||
// Any tensor type whose element type is from the given `allowedTypes` list
|
||||
class TensorOf<list<Type> allowedTypes, string elementDescription = ""> :
|
||||
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
|
||||
IsTensorTypePred, "tensor">;
|
||||
class TensorOf<list<Type> allowedTypes> :
|
||||
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
|
||||
|
||||
def AnyTensor : TensorOf<[AnyType]>;
|
||||
|
||||
|
@ -381,8 +392,8 @@ def F64Tensor : TensorOf<[F64]>;
|
|||
|
||||
// TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType.
|
||||
// Memrefs are blocks of data with fixed type and rank.
|
||||
class MemRefOf<list<Type> allowedTypes, string elementDescription = ""> :
|
||||
ContainerType<AnyTypeOf<allowedTypes, elementDescription>, IsMemRefTypePred,
|
||||
class MemRefOf<list<Type> allowedTypes> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>, IsMemRefTypePred,
|
||||
"$_self.cast<MemRefType>().getElementType()", "memref">;
|
||||
|
||||
def AnyMemRef : MemRefOf<[AnyType]>;
|
||||
|
@ -992,7 +1003,7 @@ class TCopVTEtAreSameAt<list<int> indices> :
|
|||
CPred<"llvm::is_splat(mlir::functional::map("
|
||||
"[this](unsigned i) { return this->getOperand(i)->getType()"
|
||||
".cast<ShapedType>().getElementType(); }, "
|
||||
"llvm::ArrayRef<unsigned>({" # Stringify<indices>.result # "})))">;
|
||||
"llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "})))">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern definitions
|
||||
|
|
Loading…
Reference in New Issue