forked from OSchip/llvm-project
Clean up container type names in OpBase
Establish the following convention: 1. Container class types end in "Of" (e.g. TensorOf) and take a list of allowed types. 2. An X container where only a single type is allowed is called TypeX (e.g. I32Tensor). 3. An X container where any type is allowed is called AnyX (e.g. AnyTensor). -- PiperOrigin-RevId: 249281018
This commit is contained in:
parent
6aae7b2e9a
commit
da37b0a536
|
@ -450,7 +450,7 @@ the entity's declaration place as described in
|
|||
|
||||
To help modelling constraints of common types, a set of `TypeConstraint`s are
|
||||
created; they are the `Type` subclass hierarchy. It includes `F32` for the
|
||||
constraints of being a float, `TypedTensor<F32>` for the constraints of being
|
||||
constraints of being a float, `TensorOf<[F32]>` for the constraints of being
|
||||
a float tensor, and so on.
|
||||
|
||||
Similarly, a set of `AttrConstraint`s are created for helping modelling
|
||||
|
|
|
@ -28,8 +28,8 @@
|
|||
|
||||
class quant_TypedPrimitiveOrContainer<Type etype> :
|
||||
Type<Or<[etype.predicate,
|
||||
Tensor<etype>.predicate,
|
||||
Vector<etype>.predicate]>,
|
||||
TensorOf<[etype]>.predicate,
|
||||
VectorOf<[etype]>.predicate]>,
|
||||
"primitive/tensor/vector of " # etype.description>;
|
||||
|
||||
// An implementation of QuantizedType.
|
||||
|
|
|
@ -344,49 +344,72 @@ class ShapedContainerType<Type etype, Pred containerPred, string descr> :
|
|||
|
||||
// Vector types.
|
||||
|
||||
class Vector<Type t> : ShapedContainerType<t, IsVectorTypePred, "vector">;
|
||||
class VectorOf<list<Type> allowedTypes, string elementDescription = ""> :
|
||||
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
|
||||
IsVectorTypePred, "vector">;
|
||||
|
||||
def AnyVector : Vector<AnyType>;
|
||||
def AnyVector : VectorOf<[AnyType]>;
|
||||
|
||||
// Tensor types.
|
||||
|
||||
class Tensor<Type t> : ShapedContainerType<t, IsTensorTypePred, "tensor">;
|
||||
|
||||
def AnyTensor : Tensor<AnyType>;
|
||||
|
||||
// Any tensor type whose element type is from the given `allowedTypes` list
|
||||
class AnyTensorOf<list<Type> allowedTypes, string elementDescription = ""> :
|
||||
Tensor<AnyTypeOf<allowedTypes, elementDescription>>;
|
||||
class TensorOf<list<Type> allowedTypes, string elementDescription = ""> :
|
||||
ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
|
||||
IsTensorTypePred, "tensor">;
|
||||
|
||||
def AnyTensor : TensorOf<[AnyType]>;
|
||||
|
||||
// TODO(b/130807343) Fix description to contain element information.
|
||||
class StaticShapeTensor<Type t>
|
||||
: Type<And<[ Tensor<t>.predicate, HasStaticShapePred ]>,
|
||||
: Type<And<[ TensorOf<[t]>.predicate, HasStaticShapePred ]>,
|
||||
"statically shaped tensor">;
|
||||
|
||||
def AnyStaticShapeTensor : StaticShapeTensor<AnyType>;
|
||||
|
||||
def I1Tensor : Tensor<I1>;
|
||||
def I8Tensor : Tensor<I8>;
|
||||
def I16Tensor : Tensor<I16>;
|
||||
def I32Tensor : Tensor<I32>;
|
||||
def I64Tensor : Tensor<I64>;
|
||||
def I1Tensor : TensorOf<[I1]>;
|
||||
def I8Tensor : TensorOf<[I8]>;
|
||||
def I16Tensor : TensorOf<[I16]>;
|
||||
def I32Tensor : TensorOf<[I32]>;
|
||||
def I64Tensor : TensorOf<[I64]>;
|
||||
|
||||
def BF16Tensor : Tensor<BF16>;
|
||||
def F16Tensor : Tensor<F16>;
|
||||
def F32Tensor : Tensor<F32>;
|
||||
def F64Tensor : Tensor<F64>;
|
||||
def BF16Tensor : TensorOf<[BF16]>;
|
||||
def F16Tensor : TensorOf<[F16]>;
|
||||
def F32Tensor : TensorOf<[F32]>;
|
||||
def F64Tensor : TensorOf<[F64]>;
|
||||
|
||||
// This represents a generic tuple without any constraints on elemental type,
|
||||
// ranks, or size. As Tuples can contain tensors, vectors, or scalar values
|
||||
// there is not only a single elemental type.
|
||||
def Tuple : Type<IsTupleTypePred, "tuple">;
|
||||
// Memref type.
|
||||
|
||||
// TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType.
|
||||
// Memrefs are blocks of data with fixed type and rank.
|
||||
class MemRefOf<list<Type> allowedTypes, string elementDescription = ""> :
|
||||
ContainerType<AnyTypeOf<allowedTypes, elementDescription>, IsMemRefTypePred,
|
||||
"$_self.cast<MemRefType>().getElementType()", "memref">;
|
||||
|
||||
def AnyMemRef : MemRefOf<[AnyType]>;
|
||||
|
||||
// Memref declarations handle any memref, independent of rank, size, (static or
|
||||
// dynamic), layout, or memory space.
|
||||
def I1MemRef : MemRefOf<[I1]>;
|
||||
def I8MemRef : MemRefOf<[I8]>;
|
||||
def I16MemRef : MemRefOf<[I16]>;
|
||||
def I32MemRef : MemRefOf<[I32]>;
|
||||
def I64MemRef : MemRefOf<[I64]>;
|
||||
|
||||
def BF16MemRef : MemRefOf<[BF16]>;
|
||||
def F16MemRef : MemRefOf<[F16]>;
|
||||
def F32MemRef : MemRefOf<[F32]>;
|
||||
def F64MemRef : MemRefOf<[F64]>;
|
||||
|
||||
// This represents a generic tuple without any constraints on element type.
|
||||
def AnyTuple : Type<IsTupleTypePred, "tuple">;
|
||||
|
||||
// TODO(b/132952417) Make this accept a list of types like the classes above.
|
||||
// A Tuple that only holds elements of a certain type. This cannot inherit from
|
||||
// ContainerType because tuples do not always have a single element type that
|
||||
// could be retrieved with elementTypeCall.
|
||||
class TypedTuple<Type t> :
|
||||
class TupleOf<Type t> :
|
||||
Type<And<[
|
||||
Tuple.predicate,
|
||||
IsTupleTypePred,
|
||||
Concat<
|
||||
[{
|
||||
llvm::all_of(
|
||||
|
@ -398,26 +421,6 @@ class TypedTuple<Type t> :
|
|||
"; })">
|
||||
]>, "tuple">;
|
||||
|
||||
// Memref type.
|
||||
|
||||
// Memrefs are blocks of data with fixed type and rank.
|
||||
class MemRef<Type t>
|
||||
: ContainerType<t, IsMemRefTypePred,
|
||||
"$_self.cast<MemRefType>().getElementType()", "memref">;
|
||||
|
||||
// Memref declarations handle any memref, independent of rank, size, (static or
|
||||
// dynamic), layout, or memory space.
|
||||
def I1MemRef : MemRef<I1>;
|
||||
def I8MemRef : MemRef<I8>;
|
||||
def I16MemRef : MemRef<I16>;
|
||||
def I32MemRef : MemRef<I32>;
|
||||
def I64MemRef : MemRef<I64>;
|
||||
|
||||
def BF16MemRef : MemRef<BF16>;
|
||||
def F16MemRef : MemRef<F16>;
|
||||
def F32MemRef : MemRef<F32>;
|
||||
def F64MemRef : MemRef<F64>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common type constraints
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -425,12 +428,12 @@ def F64MemRef : MemRef<F64>;
|
|||
// Type constraint for integer-like types: integers, indices, vectors of
|
||||
// integers, tensors of integers.
|
||||
def IntegerLike : TypeConstraint<Or<[Integer.predicate, Index.predicate,
|
||||
Vector<Integer>.predicate, Tensor<Integer>.predicate]>,
|
||||
VectorOf<[Integer]>.predicate, TensorOf<[Integer]>.predicate]>,
|
||||
"integer-like">;
|
||||
|
||||
// Type constraint for float-like types: floats, vectors or tensors thereof.
|
||||
def FloatLike : TypeConstraint<Or<[Float.predicate,
|
||||
Vector<Float>.predicate, Tensor<Float>.predicate]>,
|
||||
VectorOf<[Float]>.predicate, TensorOf<[Float]>.predicate]>,
|
||||
"floating-point-like">;
|
||||
|
||||
|
||||
|
|
|
@ -146,7 +146,7 @@ def AllocOp : Std_Op<"alloc"> {
|
|||
}];
|
||||
|
||||
let arguments = (ins Variadic<Index>:$value);
|
||||
let results = (outs MemRef<AnyType>);
|
||||
let results = (outs AnyMemRef);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, MemRefType memrefType", [{
|
||||
|
@ -303,7 +303,7 @@ def DeallocOp : Std_Op<"dealloc"> {
|
|||
dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
|
||||
}];
|
||||
|
||||
let arguments = (ins MemRef<AnyType>:$memref);
|
||||
let arguments = (ins AnyMemRef:$memref);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
@ -318,7 +318,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
|
|||
%1 = dim %0, 2 : tensor<?x?x?xf32>
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTypeOf<[MemRef<AnyType>, AnyTensor],
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor],
|
||||
"any tensor or memref type">:$memrefOrTensor,
|
||||
APIntAttr:$index);
|
||||
let results = (outs Index);
|
||||
|
@ -410,8 +410,8 @@ def MemRefCastOp : CastOp<"memref_cast"> {
|
|||
%3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
|
||||
}];
|
||||
|
||||
let arguments = (ins MemRef<AnyType>);
|
||||
let results = (outs MemRef<AnyType>);
|
||||
let arguments = (ins AnyMemRef);
|
||||
let results = (outs AnyMemRef);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return true if `a` and `b` are valid operand and result pairs for
|
||||
|
|
|
@ -112,7 +112,7 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
|
|||
// CHECK: return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type");
|
||||
|
||||
def OpK : NS_Op<"op_for_AnyTensorOf", []> {
|
||||
let arguments = (ins AnyTensorOf<[F32, I32]>:$x);
|
||||
let arguments = (ins TensorOf<[F32, I32]>:$x);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: OpK::verify
|
||||
|
|
Loading…
Reference in New Issue