Simplify container type definitions

The passed element type description is usually unnecessary, and it's just as valid to want to pass a description for the entire container. In either case there's an alternative (Separate element type def or a TypeAlias) and we don't need to pollute the main API.

    To allow for this, I cleaned up the TF op definitions and added some additional utilities.

--

PiperOrigin-RevId: 249340979
This commit is contained in:
Geoffrey Martin-Noble 2019-05-21 15:44:17 -07:00 committed by Mehdi Amini
parent 3902cef954
commit 7a4869e316
1 changed files with 25 additions and 14 deletions

View File

@ -34,9 +34,9 @@ class StrJoin<list<string> strings, string sep = ", "> {
!foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur)); !foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur));
} }
// Concatenates a list of integers into a string separated with comma. // Concatenates a list of integers into a string with a separator (default ", ")
class Stringify<list<int> integers> : class StrJoinInt<list<int> integers, string sep = ", "> :
StrJoin<!foreach(i, integers, !cast<string>(i))>; StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Predicate definitions // Predicate definitions
@ -241,6 +241,10 @@ class Dialect {
class Type<Pred condition, string descr = ""> : class Type<Pred condition, string descr = ""> :
TypeConstraint<condition, descr>; TypeConstraint<condition, descr>;
// Allows providing an alternative name and description to an existing type def.
class TypeAlias<Type t, string description = t.description> :
Type<t.predicate, description>;
// A variadic type constraint. It expands to zero or more of the base type. This // A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results. An op can declare no // class is used for supporting variadic operands/results. An op can declare no
// more than one variadic operand/result, and that operand/result must be the // more than one variadic operand/result, and that operand/result must be the
@ -291,6 +295,11 @@ class I<int width>
BuildableType<"getIntegerType(" # width # ")"> { BuildableType<"getIntegerType(" # width # ")"> {
int bitwidth = width; int bitwidth = width;
} }
class IntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, I<w>),
StrJoinInt<widths, "/">.result # "-bit integer">;
def I1 : I<1>; def I1 : I<1>;
def I8 : I<8>; def I8 : I<8>;
def I16 : I<16>; def I16 : I<16>;
@ -310,6 +319,10 @@ class F<int width>
int bitwidth = width; int bitwidth = width;
} }
class FloatOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, F<w>),
StrJoinInt<widths, "/">.result # "-bit float">;
def F16 : F<16>; def F16 : F<16>;
def F32 : F<32>; def F32 : F<32>;
def F64 : F<64>; def F64 : F<64>;
@ -338,24 +351,22 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
code getElementTypeCall = elementTypeCall; code getElementTypeCall = elementTypeCall;
} }
class ShapedContainerType<Type etype, Pred containerPred, string descr> : class ShapedContainerType<list<Type> allowedTypes, Pred containerPred, string descr> :
ContainerType<etype, containerPred, ContainerType<AnyTypeOf<allowedTypes>, containerPred,
"$_self.cast<ShapedType>().getElementType()", descr>; "$_self.cast<ShapedType>().getElementType()", descr>;
// Vector types. // Vector types.
class VectorOf<list<Type> allowedTypes, string elementDescription = ""> : class VectorOf<list<Type> allowedTypes> :
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>, ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
IsVectorTypePred, "vector">;
def AnyVector : VectorOf<[AnyType]>; def AnyVector : VectorOf<[AnyType]>;
// Tensor types. // Tensor types.
// Any tensor type whose element type is from the given `allowedTypes` list // Any tensor type whose element type is from the given `allowedTypes` list
class TensorOf<list<Type> allowedTypes, string elementDescription = ""> : class TensorOf<list<Type> allowedTypes> :
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>, ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
IsTensorTypePred, "tensor">;
def AnyTensor : TensorOf<[AnyType]>; def AnyTensor : TensorOf<[AnyType]>;
@ -381,8 +392,8 @@ def F64Tensor : TensorOf<[F64]>;
// TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType. // TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType.
// Memrefs are blocks of data with fixed type and rank. // Memrefs are blocks of data with fixed type and rank.
class MemRefOf<list<Type> allowedTypes, string elementDescription = ""> : class MemRefOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes, elementDescription>, IsMemRefTypePred, ContainerType<AnyTypeOf<allowedTypes>, IsMemRefTypePred,
"$_self.cast<MemRefType>().getElementType()", "memref">; "$_self.cast<MemRefType>().getElementType()", "memref">;
def AnyMemRef : MemRefOf<[AnyType]>; def AnyMemRef : MemRefOf<[AnyType]>;
@ -992,7 +1003,7 @@ class TCopVTEtAreSameAt<list<int> indices> :
CPred<"llvm::is_splat(mlir::functional::map(" CPred<"llvm::is_splat(mlir::functional::map("
"[this](unsigned i) { return this->getOperand(i)->getType()" "[this](unsigned i) { return this->getOperand(i)->getType()"
".cast<ShapedType>().getElementType(); }, " ".cast<ShapedType>().getElementType(); }, "
"llvm::ArrayRef<unsigned>({" # Stringify<indices>.result # "})))">; "llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "})))">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pattern definitions // Pattern definitions