forked from OSchip/llvm-project

Clean up tablegen vector and tensor types

There was a weird mix of names, styles, and inheritance here. I think this
makes it cleaner and more consistent. We can also have a more principled and
far-reaching refactor of some of this naming, but this seems like a good
improvement regardless.

PiperOrigin-RevId: 248827005

parent 8e5bfb85c4
commit b5ecbb7fd6
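
For orientation, the renames this change applies, collected from the hunks
below (the mapping is taken from the diff itself; only the side-by-side layout
is added here):

    // Old (removed)                  New (added)
    // TypedTensor<T>                 Tensor<T>
    // TypedVector<T>                 Vector<T>
    // def Tensor                     def AnyTensor
    // (no def)                       def AnyVector
    // def VectorOrTensor             AnyTypeOf<[AnyVector, AnyTensor]>, inlined at use sites
    // TypedStaticShapeTensor<T>      StaticShapeTensor<T>
    // def StaticShapeTensor          def AnyStaticShapeTensor
    // IsStaticShapeTensorTypePred    HasStaticShapePred (generalized to ShapedType)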
@@ -28,8 +28,8 @@
 class quant_TypedPrimitiveOrContainer<Type etype> :
     Type<AnyOf<[etype.predicate,
-                TypedTensor<etype>.predicate,
-                TypedVector<etype>.predicate]>,
+                Tensor<etype>.predicate,
+                Vector<etype>.predicate]>,
          "primitive/tensor/vector of " # etype.description>;
 
 // An implementation of QuantizedType.
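
quant_TypedPrimitiveOrContainer<etype> accepts the element type itself or any
tensor/vector of it. A minimal sketch of a def built on it (the def name here
is illustrative, not from this commit):

    // Matches i8 as well as tensor<...xi8> and vector<...xi8>.
    def quant_I8PrimitiveOrContainer : quant_TypedPrimitiveOrContainer<I8>;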

@@ -195,13 +195,12 @@ def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
 // Whether a type is a ShapedType.
 def IsShapedTypePred : CPred<"$_self.isa<ShapedType>()">;
 
+// For a ShapedType, verify that it has a static shape.
+def HasStaticShapePred : CPred<"$_self.cast<ShapedType>().hasStaticShape()">;
+
 // Whether a type is a TupleType.
 def IsTupleTypePred : CPred<"$_self.isa<TupleType>()">;
 
-// For a TensorType, verify that it is a statically shaped tensor.
-def IsStaticShapeTensorTypePred :
-    CPred<"$_self.cast<TensorType>().hasStaticShape()">;
-
 //===----------------------------------------------------------------------===//
 // Dialect definitions
 //===----------------------------------------------------------------------===//
@@ -230,7 +229,7 @@ class Type<Pred condition, string descr = ""> :
 // more than one variadic operand/result, and that operand/result must be the
 // last one in the operand/result list.
 class Variadic<Type type, string descr = "">
-    // TODO: support variadic type conditions
+    // TODO(b/132908002): support variadic type conditions
     : TypeConstraint<CPred<"true">, descr> {
   Type baseType = type;
 }
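
Variadic wraps another type constraint; after this change the tests spell it
with the new names, as in the test-file hunks further down:

    // One variadic operand of any tensor type (taken from the test changes below).
    let arguments = (ins Variadic<AnyTensor>:$input);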
@@ -254,7 +253,7 @@ def AnyType : Type<CPred<"true">, "any type">;
 def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">;
 
 // Any type from the given list
-class AnyTypeOf<list<Type> allowedTypes, string description> : Type<
+class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
     // Satisfy any of the allowed type's condition
     AnyOf<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
     !if(!eq(description, ""),
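
Giving description a default lets call sites omit it; the ExtractElementOp
hunk further down relies on exactly this:

    // Description omitted; the !if above picks a fallback when it is empty.
    AnyTypeOf<[AnyVector, AnyTensor]>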
@@ -330,70 +329,49 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
   code getElementTypeCall = elementTypeCall;
 }
 
+class ShapedContainerType<Type etype, Pred containerPred, string descr> :
+    ContainerType<etype, containerPred,
+                  "$_self.cast<ShapedType>().getElementType()", descr>;
+
 // Vector types.
-class TypedVector<Type t> : ContainerType<t, IsVectorTypePred,
-    "$_self.cast<VectorType>().getElementType()", "vector">;
-
-class Vector<Type t, list<int> dims> : ContainerType<t, AllOf<[
-    IsVectorTypePred,
-    // Match dims. Construct an ArrayRef with the elements of `dims` by folding
-    // over the list.
-    CPred<"$_self.cast<VectorType>().getShape() == ArrayRef{{" #
-    Stringify<dims>.result # "}">]>,
-    "$_self.cast<VectorType>().getElementType()",
-    "vector"> {
-  list<int> dimensions = dims;
-}
+class Vector<Type t> : ShapedContainerType<t, IsVectorTypePred, "vector">;
 
-// Tensor type.
+def AnyVector : Vector<AnyType>;
 
-// This represents a generic tensor without constraints on elemental type,
-// rank, size. As there is no constraint on elemental type, derive from Type
-// directly instead of ContainerType.
-def Tensor : Type<IsTensorTypePred, "tensor">;
+// Tensor types.
 
-// A tensor with static shape but no other constraints. Note: as
-// Tensor is a def this doesn't derive from it, but reuses the predicate
-// that must hold for it to be a tensor.
-def StaticShapeTensor
-    : Type<AllOf<[Tensor.predicate, IsStaticShapeTensorTypePred]>,
-           "statically shaped tensor">;
+class Tensor<Type t> : ShapedContainerType<t, IsTensorTypePred, "tensor">;
 
-// For typed tensors.
-class TypedTensor<Type t>
-    : ContainerType<t, Tensor.predicate,
-                    "$_self.cast<TensorType>().getElementType()",
-                    "tensor">;
+def AnyTensor : Tensor<AnyType>;
 
-class TypedStaticShapeTensor<Type t>
-    : Type<AllOf<[ TypedTensor<t>.predicate, IsStaticShapeTensorTypePred ]>,
-           "statically shaped tensor">;
-
-def I1Tensor : TypedTensor<I1>;
-def I8Tensor : TypedTensor<I8>;
-def I16Tensor : TypedTensor<I16>;
-def I32Tensor : TypedTensor<I32>;
-def I64Tensor : TypedTensor<I64>;
-
-def BF16Tensor : TypedTensor<BF16>;
-def F16Tensor : TypedTensor<F16>;
-def F32Tensor : TypedTensor<F32>;
-def F64Tensor : TypedTensor<F64>;
-
-// Any tensor type whose element type is from the given
-// `allowedTypes` list
+// Any tensor type whose element type is from the given `allowedTypes` list
 class AnyTensorOf<list<Type> allowedTypes, string elementDescription = ""> :
-  TypedTensor<AnyTypeOf<allowedTypes, elementDescription>>;
+  Tensor<AnyTypeOf<allowedTypes, elementDescription>>;
 
-def VectorOrTensor :
-    AnyTypeOf<[TypedVector<AnyType>, Tensor], "vector or tensor">;
+// TODO(b/130807343) Fix description to contain element information.
+class StaticShapeTensor<Type t>
+    : Type<AllOf<[ Tensor<t>.predicate, HasStaticShapePred ]>,
+           "statically shaped tensor">;
+
+def AnyStaticShapeTensor : StaticShapeTensor<AnyType>;
+
+def I1Tensor : Tensor<I1>;
+def I8Tensor : Tensor<I8>;
+def I16Tensor : Tensor<I16>;
+def I32Tensor : Tensor<I32>;
+def I64Tensor : Tensor<I64>;
+
+def BF16Tensor : Tensor<BF16>;
+def F16Tensor : Tensor<F16>;
+def F32Tensor : Tensor<F32>;
+def F64Tensor : Tensor<F64>;
 
 // This represents a generic tuple without any constraints on elemental type,
 // ranks, or size. As Tuples can contain tensors, vectors, or scalar values
 // there is not only a single elemental type.
 def Tuple : Type<IsTupleTypePred, "tuple">;
 
 
 // A Tuple that only holds elements of a certain type. This cannot inherit from
 // ContainerType because tuples do not always have a single element type that
 // could be retrieved with elementTypeCall.
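
A short usage sketch of the resulting hierarchy, written against the names this
hunk introduces (the op def is a placeholder in the style of the test files
below, not part of the commit):

    def Example_Op : NS_Op<"example", []> {
      let arguments = (ins F32Tensor:$typed,             // Tensor<F32>: tensor of f32
                           AnyTensor:$untyped,           // any tensor at all
                           AnyVector:$vec,               // any vector
                           AnyStaticShapeTensor:$fixed); // statically shaped tensor
    }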
@@ -438,12 +416,12 @@ def F64MemRef : MemRef<F64>;
 // Type constraint for integer-like types: integers, indices, vectors of
 // integers, tensors of integers.
 def IntegerLike : TypeConstraint<AnyOf<[Integer.predicate, Index.predicate,
-        TypedVector<Integer>.predicate, TypedTensor<Integer>.predicate]>,
+        Vector<Integer>.predicate, Tensor<Integer>.predicate]>,
     "integer-like">;
 
 // Type constraint for float-like types: floats, vectors or tensors thereof.
 def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
-        TypedVector<Float>.predicate, TypedTensor<Float>.predicate]>,
+        Vector<Float>.predicate, Tensor<Float>.predicate]>,
     "floating-point-like">;
 
 
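After the rename these constraints read directly off the new classes: FloatLike
accepts a float scalar or any vector/tensor of floats. A hypothetical op using
it (mirroring the test-file style, not from this commit):

    def FL_Op : NS_Op<"float_like", []> {
      let arguments = (ins FloatLike:$value);  // f32, vector<...xf32>, tensor<...xf32>, ...
    }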
@@ -1037,7 +1015,7 @@ def addBenefit;
 // def : Pattern<(OneResultOp1:$op1 $arg0, $arg1),
 //               [(OneResultOp2:$op2 $arg0, $arg1),
 //                (OneResultOp3 $op2 (OneResultOp4))],
-//               [(IsStaticShapeTensorTypePred $op1)]>;
+//               [(HasStaticShapePred $op1)]>;
 // ```
 //
 // `$argN` is bound to the `OneResultOp1`'s N-th argument and used later to

@@ -317,7 +317,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
     %1 = dim %0, 2 : tensor<?x?x?xf32>
   }];
 
-  let arguments = (ins AnyTypeOf<[MemRef<AnyType>, Tensor],
+  let arguments = (ins AnyTypeOf<[MemRef<AnyType>, AnyTensor],
                                  "any tensor or memref type">:$memrefOrTensor,
                        APIntAttr:$index);
   let results = (outs Index);
@@ -366,7 +366,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
     %0 = extract_element %0[%1, %2] : vector<4x4xi32>
   }];
 
-  let arguments = (ins VectorOrTensor:$aggregate,
+  let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
                        Variadic<Index>:$indices);
   let results = (outs AnyType);
 
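
Since the VectorOrTensor def is gone, the same constraint is now spelled
inline, which also exercises the new AnyTypeOf default description:

    AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate  // was VectorOrTensor:$aggregate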
@@ -498,8 +498,8 @@ def TensorCastOp : CastOp<"tensor_cast"> {
     %2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
   }];
 
-  let arguments = (ins Tensor);
-  let results = (outs Tensor);
+  let arguments = (ins AnyTensor);
+  let results = (outs AnyTensor);
 
   let extraClassDeclaration = [{
     /// Return true if `a` and `b` are valid operand and result pairs for

@@ -37,7 +37,7 @@ def OpB : NS_Op<"one_variadic_operand_op", []> {
 // CHECK: tblgen_state->addOperands(input);
 
 def OpC : NS_Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
-  let arguments = (ins Variadic<Tensor>:$input1, Variadic<Tensor>:$input2);
+  let arguments = (ins Variadic<AnyTensor>:$input1, Variadic<AnyTensor>:$input2);
 }
 
 // CHECK-LABEL: Operation::operand_range OpC::input1()
@@ -55,7 +55,7 @@ def OpC : NS_Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
 // CHECK-NEXT: tblgen_state->addOperands(input2);
 
 def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
-  let arguments = (ins Variadic<Tensor>:$input1, Tensor:$input2, Variadic<Tensor>:$input3);
+  let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
 }
 
 // CHECK-LABEL: Operation::operand_range OpD::input1()
@@ -79,7 +79,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
 // CHECK-NEXT: tblgen_state->addOperands(input3);
 
 def OpE : NS_Op<"one_variadic_among_multi_normal_inputs_op", []> {
-  let arguments = (ins Tensor:$input1, Tensor:$input2, Variadic<Tensor>:$input3, Tensor:$input4, Tensor:$input5);
+  let arguments = (ins AnyTensor:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3, AnyTensor:$input4, AnyTensor:$input5);
 }
 
 // CHECK-LABEL: Value *OpE::input1()

@@ -48,7 +48,7 @@ def OpC : NS_Op<"three_normal_result_op", []> {
 def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
 def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
   let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
-  let results = (outs Tensor:$y);
+  let results = (outs AnyTensor:$y);
 }
 
 // CHECK-LABEL: OpD definitions
@@ -57,7 +57,7 @@ def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
 
 def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
   let arguments = (ins I32:$x, F32Attr:$attr);
-  let results = (outs Tensor:$y);
+  let results = (outs AnyTensor:$y);
 }
 
 // CHECK-LABEL: OpE definitions
@@ -94,7 +94,7 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
 
 
 def OpH : NS_Op<"all_variadic_results_op", [SameVariadicResultSize]> {
-  let results = (outs Variadic<Tensor>:$output1, Variadic<Tensor>:$output2);
+  let results = (outs Variadic<AnyTensor>:$output1, Variadic<AnyTensor>:$output2);
 }
 
 // CHECK-LABEL: Operation::result_range OpH::output1()
@@ -113,7 +113,7 @@ def OpH : NS_Op<"all_variadic_results_op", [SameVariadicResultSize]> {
 // CHECK-NEXT: tblgen_state->addTypes(output2);
 
 def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
-  let results = (outs Variadic<Tensor>:$output1, Tensor:$output2, Variadic<Tensor>:$output3);
+  let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
 }
 
 // CHECK-LABEL: Operation::result_range OpI::output1()
@@ -137,7 +137,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
 // CHECK-NEXT: tblgen_state->addTypes(output3);
 
 def OpJ : NS_Op<"one_variadic_among_multi_normal_results_op", []> {
-  let results = (outs Tensor:$output1, Tensor:$output2, Variadic<Tensor>:$output3, Tensor:$output4, Tensor:$output5);
+  let results = (outs AnyTensor:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3, AnyTensor:$output4, AnyTensor:$output5);
 }
 
 // CHECK-LABEL: Value *OpJ::output1()
@@ -159,8 +159,8 @@ def OpJ : NS_Op<"one_variadic_among_multi_normal_results_op", []> {
 // pack to set result type
 // ---
 def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameValueType]> {
-  let arguments = (ins Variadic<Tensor>:$input);
-  let results = (outs Tensor:$result);
+  let arguments = (ins Variadic<AnyTensor>:$input);
+  let results = (outs AnyTensor:$result);
 }
 
 // CHECK-LABEL: OpK::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)

@@ -29,7 +29,7 @@ def OpB : NS_Op<"op_for_AllOf_PredOpTrait", [
 def OpC : NS_Op<"op_for_TCopVTEtIs", [
     PredOpTrait<"first operand has i32 element type",
                 TCopVTEtIs<0, I32>>]> {
-  let arguments = (ins Tensor:$x);
+  let arguments = (ins AnyTensor:$x);
 }
 
 // CHECK-LABEL: OpC::verify
@@ -40,7 +40,7 @@ def OpD : NS_Op<"op_for_TCOpVTEtIsSameAs", [
     PredOpTrait<"first operand is a vector or tensor with the same "
                 "elemental type as itself",
                 TCopVTEtIsSameAs<0, 0>>]> {
-  let arguments = (ins Tensor:$x);
+  let arguments = (ins AnyTensor:$x);
 }
 
 // CHECK-LABEL: OpD::verify
@@ -52,8 +52,8 @@ def OpE : NS_Op<"op_for_TCresVTEtIsSameAsOp", [
     PredOpTrait<"first operand is a vector or tensor with the same "
                 "elemental type as first result",
                 TCresVTEtIsSameAsOp<0, 0>>]> {
-  let arguments = (ins Tensor:$x);
-  let results = (outs Tensor:$y);
+  let arguments = (ins AnyTensor:$x);
+  let results = (outs AnyTensor:$y);
 }
 
 // CHECK-LABEL: OpE::verify
@@ -97,11 +97,11 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
     PredOpTrait<"operands indexed at 0, 2, 3 should all have "
                 "the same type", TCopVTEtAreSameAt<[0, 2, 3]>>]> {
   let arguments = (ins
-    Tensor:$a,
-    Tensor:$b,
-    Tensor:$c,
-    Tensor:$d,
-    Tensor:$e
+    AnyTensor:$a,
+    AnyTensor:$b,
+    AnyTensor:$c,
+    AnyTensor:$d,
+    AnyTensor:$e
   );
 }
 
@@ -116,4 +116,4 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
 }
 
 // CHECK-LABEL: OpK::verify
-// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<TensorType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<TensorType>().getElementType().isInteger(32))))))
+// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isInteger(32))))))

@@ -12,8 +12,8 @@ class X_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<X_Dialect, mnemonic, traits>;
 
 def X_AddOp : X_Op<"add">,
-    Arguments<(ins Tensor:$A, Tensor:$B)>,
-    Results<(outs Tensor: $C)> {
+    Arguments<(ins AnyTensor:$A, AnyTensor:$B)>,
+    Results<(outs AnyTensor: $C)> {
   // TODO: extract referenceImplementation to Op.
   code referenceImplementation = [{
     auto ivs = IndexHandle::makeIndexHandles(view_A.rank());