Simplify container type definitions

The passed element type description is usually unnecessary, and it's just as valid to want to pass a description for the entire container. In either case there's an alternative (Separate element type def or a TypeAlias) and we don't need to pollute the main API.

    To allow for this, I cleaned up the TF op definitions and added some additional utilities.

--

PiperOrigin-RevId: 249340979
This commit is contained in:
Geoffrey Martin-Noble 2019-05-21 15:44:17 -07:00 committed by Mehdi Amini
parent 3902cef954
commit 7a4869e316
1 changed files with 25 additions and 14 deletions

View File

@ -34,9 +34,9 @@ class StrJoin<list<string> strings, string sep = ", "> {
!foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur));
}
// Concatenates a list of integers into a string separated with comma.
class Stringify<list<int> integers> :
StrJoin<!foreach(i, integers, !cast<string>(i))>;
// Concatenates a list of integers into a string with a separator (default ", ")
class StrJoinInt<list<int> integers, string sep = ", "> :
StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
//===----------------------------------------------------------------------===//
// Predicate definitions
@ -241,6 +241,10 @@ class Dialect {
class Type<Pred condition, string descr = ""> :
TypeConstraint<condition, descr>;
// Allows providing an alternative name and description to an existing type def.
class TypeAlias<Type t, string description = t.description> :
Type<t.predicate, description>;
// A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results. An op can declare no
// more than one variadic operand/result, and that operand/result must be the
@ -291,6 +295,11 @@ class I<int width>
BuildableType<"getIntegerType(" # width # ")"> {
int bitwidth = width;
}
class IntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, I<w>),
StrJoinInt<widths, "/">.result # "-bit integer">;
def I1 : I<1>;
def I8 : I<8>;
def I16 : I<16>;
@ -310,6 +319,10 @@ class F<int width>
int bitwidth = width;
}
class FloatOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, F<w>),
StrJoinInt<widths, "/">.result # "-bit float">;
def F16 : F<16>;
def F32 : F<32>;
def F64 : F<64>;
@ -338,24 +351,22 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
code getElementTypeCall = elementTypeCall;
}
class ShapedContainerType<Type etype, Pred containerPred, string descr> :
ContainerType<etype, containerPred,
class ShapedContainerType<list<Type> allowedTypes, Pred containerPred, string descr> :
ContainerType<AnyTypeOf<allowedTypes>, containerPred,
"$_self.cast<ShapedType>().getElementType()", descr>;
// Vector types.
class VectorOf<list<Type> allowedTypes, string elementDescription = ""> :
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
IsVectorTypePred, "vector">;
class VectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
def AnyVector : VectorOf<[AnyType]>;
// Tensor types.
// Any tensor type whose element type is from the given `allowedTypes` list
class TensorOf<list<Type> allowedTypes, string elementDescription = ""> :
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
IsTensorTypePred, "tensor">;
class TensorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
def AnyTensor : TensorOf<[AnyType]>;
@ -381,8 +392,8 @@ def F64Tensor : TensorOf<[F64]>;
// TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType.
// Memrefs are blocks of data with fixed type and rank.
class MemRefOf<list<Type> allowedTypes, string elementDescription = ""> :
ContainerType<AnyTypeOf<allowedTypes, elementDescription>, IsMemRefTypePred,
class MemRefOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsMemRefTypePred,
"$_self.cast<MemRefType>().getElementType()", "memref">;
def AnyMemRef : MemRefOf<[AnyType]>;
@ -992,7 +1003,7 @@ class TCopVTEtAreSameAt<list<int> indices> :
CPred<"llvm::is_splat(mlir::functional::map("
"[this](unsigned i) { return this->getOperand(i)->getType()"
".cast<ShapedType>().getElementType(); }, "
"llvm::ArrayRef<unsigned>({" # Stringify<indices>.result # "})))">;
"llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "})))">;
//===----------------------------------------------------------------------===//
// Pattern definitions