forked from OSchip/llvm-project
Simplify container type definitions
The passed element type description is usually unnecessary, and it's just as valid to want to pass a description for the entire container. In either case there's an alternative (Separate element type def or a TypeAlias) and we don't need to pollute the main API. To allow for this, I cleaned up the TF op definitions and added some additional utilities. -- PiperOrigin-RevId: 249340979
This commit is contained in:
parent
3902cef954
commit
7a4869e316
|
@ -34,9 +34,9 @@ class StrJoin<list<string> strings, string sep = ", "> {
|
||||||
!foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur));
|
!foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Concatenates a list of integers into a string separated with comma.
|
// Concatenates a list of integers into a string with a separator (default ", ")
|
||||||
class Stringify<list<int> integers> :
|
class StrJoinInt<list<int> integers, string sep = ", "> :
|
||||||
StrJoin<!foreach(i, integers, !cast<string>(i))>;
|
StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Predicate definitions
|
// Predicate definitions
|
||||||
|
@ -241,6 +241,10 @@ class Dialect {
|
||||||
class Type<Pred condition, string descr = ""> :
|
class Type<Pred condition, string descr = ""> :
|
||||||
TypeConstraint<condition, descr>;
|
TypeConstraint<condition, descr>;
|
||||||
|
|
||||||
|
// Allows providing an alternative name and description to an existing type def.
|
||||||
|
class TypeAlias<Type t, string description = t.description> :
|
||||||
|
Type<t.predicate, description>;
|
||||||
|
|
||||||
// A variadic type constraint. It expands to zero or more of the base type. This
|
// A variadic type constraint. It expands to zero or more of the base type. This
|
||||||
// class is used for supporting variadic operands/results. An op can declare no
|
// class is used for supporting variadic operands/results. An op can declare no
|
||||||
// more than one variadic operand/result, and that operand/result must be the
|
// more than one variadic operand/result, and that operand/result must be the
|
||||||
|
@ -291,6 +295,11 @@ class I<int width>
|
||||||
BuildableType<"getIntegerType(" # width # ")"> {
|
BuildableType<"getIntegerType(" # width # ")"> {
|
||||||
int bitwidth = width;
|
int bitwidth = width;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class IntOfWidths<list<int> widths> :
|
||||||
|
AnyTypeOf<!foreach(w, widths, I<w>),
|
||||||
|
StrJoinInt<widths, "/">.result # "-bit integer">;
|
||||||
|
|
||||||
def I1 : I<1>;
|
def I1 : I<1>;
|
||||||
def I8 : I<8>;
|
def I8 : I<8>;
|
||||||
def I16 : I<16>;
|
def I16 : I<16>;
|
||||||
|
@ -310,6 +319,10 @@ class F<int width>
|
||||||
int bitwidth = width;
|
int bitwidth = width;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class FloatOfWidths<list<int> widths> :
|
||||||
|
AnyTypeOf<!foreach(w, widths, F<w>),
|
||||||
|
StrJoinInt<widths, "/">.result # "-bit float">;
|
||||||
|
|
||||||
def F16 : F<16>;
|
def F16 : F<16>;
|
||||||
def F32 : F<32>;
|
def F32 : F<32>;
|
||||||
def F64 : F<64>;
|
def F64 : F<64>;
|
||||||
|
@ -338,24 +351,22 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
|
||||||
code getElementTypeCall = elementTypeCall;
|
code getElementTypeCall = elementTypeCall;
|
||||||
}
|
}
|
||||||
|
|
||||||
class ShapedContainerType<Type etype, Pred containerPred, string descr> :
|
class ShapedContainerType<list<Type> allowedTypes, Pred containerPred, string descr> :
|
||||||
ContainerType<etype, containerPred,
|
ContainerType<AnyTypeOf<allowedTypes>, containerPred,
|
||||||
"$_self.cast<ShapedType>().getElementType()", descr>;
|
"$_self.cast<ShapedType>().getElementType()", descr>;
|
||||||
|
|
||||||
// Vector types.
|
// Vector types.
|
||||||
|
|
||||||
class VectorOf<list<Type> allowedTypes, string elementDescription = ""> :
|
class VectorOf<list<Type> allowedTypes> :
|
||||||
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
|
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
|
||||||
IsVectorTypePred, "vector">;
|
|
||||||
|
|
||||||
def AnyVector : VectorOf<[AnyType]>;
|
def AnyVector : VectorOf<[AnyType]>;
|
||||||
|
|
||||||
// Tensor types.
|
// Tensor types.
|
||||||
|
|
||||||
// Any tensor type whose element type is from the given `allowedTypes` list
|
// Any tensor type whose element type is from the given `allowedTypes` list
|
||||||
class TensorOf<list<Type> allowedTypes, string elementDescription = ""> :
|
class TensorOf<list<Type> allowedTypes> :
|
||||||
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
|
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
|
||||||
IsTensorTypePred, "tensor">;
|
|
||||||
|
|
||||||
def AnyTensor : TensorOf<[AnyType]>;
|
def AnyTensor : TensorOf<[AnyType]>;
|
||||||
|
|
||||||
|
@ -381,8 +392,8 @@ def F64Tensor : TensorOf<[F64]>;
|
||||||
|
|
||||||
// TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType.
|
// TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType.
|
||||||
// Memrefs are blocks of data with fixed type and rank.
|
// Memrefs are blocks of data with fixed type and rank.
|
||||||
class MemRefOf<list<Type> allowedTypes, string elementDescription = ""> :
|
class MemRefOf<list<Type> allowedTypes> :
|
||||||
ContainerType<AnyTypeOf<allowedTypes, elementDescription>, IsMemRefTypePred,
|
ContainerType<AnyTypeOf<allowedTypes>, IsMemRefTypePred,
|
||||||
"$_self.cast<MemRefType>().getElementType()", "memref">;
|
"$_self.cast<MemRefType>().getElementType()", "memref">;
|
||||||
|
|
||||||
def AnyMemRef : MemRefOf<[AnyType]>;
|
def AnyMemRef : MemRefOf<[AnyType]>;
|
||||||
|
@ -992,7 +1003,7 @@ class TCopVTEtAreSameAt<list<int> indices> :
|
||||||
CPred<"llvm::is_splat(mlir::functional::map("
|
CPred<"llvm::is_splat(mlir::functional::map("
|
||||||
"[this](unsigned i) { return this->getOperand(i)->getType()"
|
"[this](unsigned i) { return this->getOperand(i)->getType()"
|
||||||
".cast<ShapedType>().getElementType(); }, "
|
".cast<ShapedType>().getElementType(); }, "
|
||||||
"llvm::ArrayRef<unsigned>({" # Stringify<indices>.result # "})))">;
|
"llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "})))">;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pattern definitions
|
// Pattern definitions
|
||||||
|
|
Loading…
Reference in New Issue