From 7a4869e3164fea27f9b9ec946b1d4af301ac3fc0 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Tue, 21 May 2019 15:44:17 -0700 Subject: [PATCH] Simplify container type definitions The passed element type description is usually unnecessary, and it's just as valid to want to pass a description for the entire container. In either case there's an alternative (Separate element type def or a TypeAlias) and we don't need to pollute the main API. To allow for this, I cleaned up the TF op definitions and added some additional utilities. -- PiperOrigin-RevId: 249340979 --- mlir/include/mlir/IR/OpBase.td | 39 ++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index a7b69b63b398..e49e7dc6fdf7 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -34,9 +34,9 @@ class StrJoin strings, string sep = ", "> { !foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur)); } -// Concatenates a list of integers into a string separated with comma. -class Stringify integers> : - StrJoin(i))>; +// Concatenates a list of integers into a string with a separator (default ", ") +class StrJoinInt integers, string sep = ", "> : + StrJoin(i)), sep>; //===----------------------------------------------------------------------===// // Predicate definitions @@ -241,6 +241,10 @@ class Dialect { class Type : TypeConstraint; +// Allows providing an alternative name and description to an existing type def. +class TypeAlias : + Type; + // A variadic type constraint. It expands to zero or more of the base type. This // class is used for supporting variadic operands/results. An op can declare no // more than one variadic operand/result, and that operand/result must be the @@ -291,6 +295,11 @@ class I BuildableType<"getIntegerType(" # width # ")"> { int bitwidth = width; } + +class IntOfWidths widths> : + AnyTypeOf), + StrJoinInt.result # "-bit integer">; + def I1 : I<1>; def I8 : I<8>; def I16 : I<16>; @@ -310,6 +319,10 @@ class F int bitwidth = width; } +class FloatOfWidths widths> : + AnyTypeOf), + StrJoinInt.result # "-bit float">; + def F16 : F<16>; def F32 : F<32>; def F64 : F<64>; @@ -338,24 +351,22 @@ class ContainerType : - ContainerType allowedTypes, Pred containerPred, string descr> : + ContainerType, containerPred, "$_self.cast().getElementType()", descr>; // Vector types. -class VectorOf allowedTypes, string elementDescription = ""> : - ShapedContainerType, - IsVectorTypePred, "vector">; +class VectorOf allowedTypes> : + ShapedContainerType; def AnyVector : VectorOf<[AnyType]>; // Tensor types. // Any tensor type whose element type is from the given `allowedTypes` list -class TensorOf allowedTypes, string elementDescription = ""> : - ShapedContainerType, - IsTensorTypePred, "tensor">; +class TensorOf allowedTypes> : + ShapedContainerType; def AnyTensor : TensorOf<[AnyType]>; @@ -381,8 +392,8 @@ def F64Tensor : TensorOf<[F64]>; // TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType. // Memrefs are blocks of data with fixed type and rank. -class MemRefOf allowedTypes, string elementDescription = ""> : - ContainerType, IsMemRefTypePred, +class MemRefOf allowedTypes> : + ContainerType, IsMemRefTypePred, "$_self.cast().getElementType()", "memref">; def AnyMemRef : MemRefOf<[AnyType]>; @@ -992,7 +1003,7 @@ class TCopVTEtAreSameAt indices> : CPred<"llvm::is_splat(mlir::functional::map(" "[this](unsigned i) { return this->getOperand(i)->getType()" ".cast().getElementType(); }, " - "llvm::ArrayRef({" # Stringify.result # "})))">; + "llvm::ArrayRef({" # StrJoinInt.result # "})))">; //===----------------------------------------------------------------------===// // Pattern definitions