Clean up tablegen vector and tensor types

There was a weird mix of names, styles, and inheritance here. I think this makes it cleaner and more consistent. We can also have a more principled and far-reaching refactor of some of this naming, but this seems like a good improvement regardless

--

PiperOrigin-RevId: 248827005
This commit is contained in:
Geoffrey Martin-Noble 2019-05-17 19:45:45 -07:00 committed by Mehdi Amini
parent 8e5bfb85c4
commit b5ecbb7fd6
7 changed files with 64 additions and 86 deletions

View File

@ -28,8 +28,8 @@
class quant_TypedPrimitiveOrContainer<Type etype> :
Type<AnyOf<[etype.predicate,
TypedTensor<etype>.predicate,
TypedVector<etype>.predicate]>,
Tensor<etype>.predicate,
Vector<etype>.predicate]>,
"primitive/tensor/vector of " # etype.description>;
// An implementation of QuantizedType.

View File

@ -195,13 +195,12 @@ def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
// Whether a type is a ShapedType.
def IsShapedTypePred : CPred<"$_self.isa<ShapedType>()">;
// For a ShapedType, verify that it has a static shape.
def HasStaticShapePred : CPred<"$_self.cast<ShapedType>().hasStaticShape()">;
// Whether a type is a TupleType.
def IsTupleTypePred : CPred<"$_self.isa<TupleType>()">;
// For a TensorType, verify that it is a statically shaped tensor.
def IsStaticShapeTensorTypePred :
CPred<"$_self.cast<TensorType>().hasStaticShape()">;
//===----------------------------------------------------------------------===//
// Dialect definitions
//===----------------------------------------------------------------------===//
@ -230,7 +229,7 @@ class Type<Pred condition, string descr = ""> :
// more than one variadic operand/result, and that operand/result must be the
// last one in the operand/result list.
class Variadic<Type type, string descr = "">
// TODO: support variadic type conditions
// TODO(b/132908002): support variadic type conditions
: TypeConstraint<CPred<"true">, descr> {
Type baseType = type;
}
@ -254,7 +253,7 @@ def AnyType : Type<CPred<"true">, "any type">;
def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">;
// Any type from the given list
class AnyTypeOf<list<Type> allowedTypes, string description> : Type<
class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
// Satisfy any of the allowed type's condition
AnyOf<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
!if(!eq(description, ""),
@ -330,70 +329,49 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
code getElementTypeCall = elementTypeCall;
}
class ShapedContainerType<Type etype, Pred containerPred, string descr> :
ContainerType<etype, containerPred,
"$_self.cast<ShapedType>().getElementType()", descr>;
// Vector types.
class TypedVector<Type t> : ContainerType<t, IsVectorTypePred,
"$_self.cast<VectorType>().getElementType()", "vector">;
class Vector<Type t, list<int> dims> : ContainerType<t, AllOf<[
IsVectorTypePred,
// Match dims. Construct an ArrayRef with the elements of `dims` by folding
// over the list.
CPred<"$_self.cast<VectorType>().getShape() == ArrayRef{{" #
Stringify<dims>.result # "}">]>,
"$_self.cast<VectorType>().getElementType()",
"vector"> {
list<int> dimensions = dims;
}
class Vector<Type t> : ShapedContainerType<t, IsVectorTypePred, "vector">;
// Tensor type.
def AnyVector : Vector<AnyType>;
// This represents a generic tensor without constraints on elemental type,
// rank, size. As there is no constraint on elemental type, derive from Type
// directly instead of ContainerType.
def Tensor : Type<IsTensorTypePred, "tensor">;
// Tensor types.
// A tensor with static shape but no other constraints. Note: as
// Tensor is a def this doesn't derive from it, but reuses the predicate
// that must hold for it to be a tensor.
def StaticShapeTensor
: Type<AllOf<[Tensor.predicate, IsStaticShapeTensorTypePred]>,
"statically shaped tensor">;
class Tensor<Type t> : ShapedContainerType<t, IsTensorTypePred, "tensor">;
// For typed tensors.
class TypedTensor<Type t>
: ContainerType<t, Tensor.predicate,
"$_self.cast<TensorType>().getElementType()",
"tensor">;
def AnyTensor : Tensor<AnyType>;
class TypedStaticShapeTensor<Type t>
: Type<AllOf<[ TypedTensor<t>.predicate, IsStaticShapeTensorTypePred ]>,
"statically shaped tensor">;
def I1Tensor : TypedTensor<I1>;
def I8Tensor : TypedTensor<I8>;
def I16Tensor : TypedTensor<I16>;
def I32Tensor : TypedTensor<I32>;
def I64Tensor : TypedTensor<I64>;
def BF16Tensor : TypedTensor<BF16>;
def F16Tensor : TypedTensor<F16>;
def F32Tensor : TypedTensor<F32>;
def F64Tensor : TypedTensor<F64>;
// Any tensor type whose element type is from the given
// `allowedTypes` list
// Any tensor type whose element type is from the given `allowedTypes` list
class AnyTensorOf<list<Type> allowedTypes, string elementDescription = ""> :
TypedTensor<AnyTypeOf<allowedTypes, elementDescription>>;
Tensor<AnyTypeOf<allowedTypes, elementDescription>>;
def VectorOrTensor :
AnyTypeOf<[TypedVector<AnyType>, Tensor], "vector or tensor">;
// TODO(b/130807343) Fix description to contain element information.
class StaticShapeTensor<Type t>
: Type<AllOf<[ Tensor<t>.predicate, HasStaticShapePred ]>,
"statically shaped tensor">;
def AnyStaticShapeTensor : StaticShapeTensor<AnyType>;
def I1Tensor : Tensor<I1>;
def I8Tensor : Tensor<I8>;
def I16Tensor : Tensor<I16>;
def I32Tensor : Tensor<I32>;
def I64Tensor : Tensor<I64>;
def BF16Tensor : Tensor<BF16>;
def F16Tensor : Tensor<F16>;
def F32Tensor : Tensor<F32>;
def F64Tensor : Tensor<F64>;
// This represents a generic tuple without any constraints on elemental type,
// ranks, or size. As Tuples can contain tensors, vectors, or scalar values
// there is not only a single elemental type.
def Tuple : Type<IsTupleTypePred, "tuple">;
// A Tuple that only holds elements of a certain type. This cannot inherit from
// ContainerType because tuples do not always have a single element type that
// could be retrieved with elementTypeCall.
@ -438,12 +416,12 @@ def F64MemRef : MemRef<F64>;
// Type constraint for integer-like types: integers, indices, vectors of
// integers, tensors of integers.
def IntegerLike : TypeConstraint<AnyOf<[Integer.predicate, Index.predicate,
TypedVector<Integer>.predicate, TypedTensor<Integer>.predicate]>,
Vector<Integer>.predicate, Tensor<Integer>.predicate]>,
"integer-like">;
// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
TypedVector<Float>.predicate, TypedTensor<Float>.predicate]>,
Vector<Float>.predicate, Tensor<Float>.predicate]>,
"floating-point-like">;
@ -1037,7 +1015,7 @@ def addBenefit;
// def : Pattern<(OneResultOp1:$op1 $arg0, $arg1),
// [(OneResultOp2:$op2 $arg0, $arg1),
// (OneResultOp3 $op2 (OneResultOp4))],
// [(IsStaticShapeTensorTypePred $op1)]>;
// [(HasStaticShapePred $op1)]>;
// ```
//
// `$argN` is bound to the `OneResultOp1`'s N-th argument and used later to

View File

@ -317,7 +317,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
%1 = dim %0, 2 : tensor<?x?x?xf32>
}];
let arguments = (ins AnyTypeOf<[MemRef<AnyType>, Tensor],
let arguments = (ins AnyTypeOf<[MemRef<AnyType>, AnyTensor],
"any tensor or memref type">:$memrefOrTensor,
APIntAttr:$index);
let results = (outs Index);
@ -366,7 +366,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
%0 = extract_element %0[%1, %2] : vector<4x4xi32>
}];
let arguments = (ins VectorOrTensor:$aggregate,
let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate,
Variadic<Index>:$indices);
let results = (outs AnyType);
@ -498,8 +498,8 @@ def TensorCastOp : CastOp<"tensor_cast"> {
%2 = tensor_cast %1 : tensor<??f32> to tensor<?x?xf32>
}];
let arguments = (ins Tensor);
let results = (outs Tensor);
let arguments = (ins AnyTensor);
let results = (outs AnyTensor);
let extraClassDeclaration = [{
/// Return true if `a` and `b` are valid operand and result pairs for

View File

@ -37,7 +37,7 @@ def OpB : NS_Op<"one_variadic_operand_op", []> {
// CHECK: tblgen_state->addOperands(input);
def OpC : NS_Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
let arguments = (ins Variadic<Tensor>:$input1, Variadic<Tensor>:$input2);
let arguments = (ins Variadic<AnyTensor>:$input1, Variadic<AnyTensor>:$input2);
}
// CHECK-LABEL: Operation::operand_range OpC::input1()
@ -55,7 +55,7 @@ def OpC : NS_Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
// CHECK-NEXT: tblgen_state->addOperands(input2);
def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
let arguments = (ins Variadic<Tensor>:$input1, Tensor:$input2, Variadic<Tensor>:$input3);
let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
}
// CHECK-LABEL: Operation::operand_range OpD::input1()
@ -79,7 +79,7 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
// CHECK-NEXT: tblgen_state->addOperands(input3);
def OpE : NS_Op<"one_variadic_among_multi_normal_inputs_op", []> {
let arguments = (ins Tensor:$input1, Tensor:$input2, Variadic<Tensor>:$input3, Tensor:$input4, Tensor:$input5);
let arguments = (ins AnyTensor:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3, AnyTensor:$input4, AnyTensor:$input5);
}
// CHECK-LABEL: Value *OpE::input1()

View File

@ -48,7 +48,7 @@ def OpC : NS_Op<"three_normal_result_op", []> {
def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
let results = (outs Tensor:$y);
let results = (outs AnyTensor:$y);
}
// CHECK-LABEL: OpD definitions
@ -57,7 +57,7 @@ def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, F32Attr:$attr);
let results = (outs Tensor:$y);
let results = (outs AnyTensor:$y);
}
// CHECK-LABEL: OpE definitions
@ -94,7 +94,7 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
def OpH : NS_Op<"all_variadic_results_op", [SameVariadicResultSize]> {
let results = (outs Variadic<Tensor>:$output1, Variadic<Tensor>:$output2);
let results = (outs Variadic<AnyTensor>:$output1, Variadic<AnyTensor>:$output2);
}
// CHECK-LABEL: Operation::result_range OpH::output1()
@ -113,7 +113,7 @@ def OpH : NS_Op<"all_variadic_results_op", [SameVariadicResultSize]> {
// CHECK-NEXT: tblgen_state->addTypes(output2);
def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
let results = (outs Variadic<Tensor>:$output1, Tensor:$output2, Variadic<Tensor>:$output3);
let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
}
// CHECK-LABEL: Operation::result_range OpI::output1()
@ -137,7 +137,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
// CHECK-NEXT: tblgen_state->addTypes(output3);
def OpJ : NS_Op<"one_variadic_among_multi_normal_results_op", []> {
let results = (outs Tensor:$output1, Tensor:$output2, Variadic<Tensor>:$output3, Tensor:$output4, Tensor:$output5);
let results = (outs AnyTensor:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3, AnyTensor:$output4, AnyTensor:$output5);
}
// CHECK-LABEL: Value *OpJ::output1()
@ -159,8 +159,8 @@ def OpJ : NS_Op<"one_variadic_among_multi_normal_results_op", []> {
// pack to set result type
// ---
def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameValueType]> {
let arguments = (ins Variadic<Tensor>:$input);
let results = (outs Tensor:$result);
let arguments = (ins Variadic<AnyTensor>:$input);
let results = (outs AnyTensor:$result);
}
// CHECK-LABEL: OpK::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)

View File

@ -29,7 +29,7 @@ def OpB : NS_Op<"op_for_AllOf_PredOpTrait", [
def OpC : NS_Op<"op_for_TCopVTEtIs", [
PredOpTrait<"first operand has i32 element type",
TCopVTEtIs<0, I32>>]> {
let arguments = (ins Tensor:$x);
let arguments = (ins AnyTensor:$x);
}
// CHECK-LABEL: OpC::verify
@ -40,7 +40,7 @@ def OpD : NS_Op<"op_for_TCOpVTEtIsSameAs", [
PredOpTrait<"first operand is a vector or tensor with the same "
"elemental type as itself",
TCopVTEtIsSameAs<0, 0>>]> {
let arguments = (ins Tensor:$x);
let arguments = (ins AnyTensor:$x);
}
// CHECK-LABEL: OpD::verify
@ -52,8 +52,8 @@ def OpE : NS_Op<"op_for_TCresVTEtIsSameAsOp", [
PredOpTrait<"first operand is a vector or tensor with the same "
"elemental type as first result",
TCresVTEtIsSameAsOp<0, 0>>]> {
let arguments = (ins Tensor:$x);
let results = (outs Tensor:$y);
let arguments = (ins AnyTensor:$x);
let results = (outs AnyTensor:$y);
}
// CHECK-LABEL: OpE::verify
@ -97,11 +97,11 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
PredOpTrait<"operands indexed at 0, 2, 3 should all have "
"the same type", TCopVTEtAreSameAt<[0, 2, 3]>>]> {
let arguments = (ins
Tensor:$a,
Tensor:$b,
Tensor:$c,
Tensor:$d,
Tensor:$e
AnyTensor:$a,
AnyTensor:$b,
AnyTensor:$c,
AnyTensor:$d,
AnyTensor:$e
);
}
@ -116,4 +116,4 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
}
// CHECK-LABEL: OpK::verify
// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<TensorType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<TensorType>().getElementType().isInteger(32))))))
// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isInteger(32))))))

View File

@ -12,8 +12,8 @@ class X_Op<string mnemonic, list<OpTrait> traits = []> :
Op<X_Dialect, mnemonic, traits>;
def X_AddOp : X_Op<"add">,
Arguments<(ins Tensor:$A, Tensor:$B)>,
Results<(outs Tensor: $C)> {
Arguments<(ins AnyTensor:$A, AnyTensor:$B)>,
Results<(outs AnyTensor: $C)> {
// TODO: extract referenceImplementation to Op.
code referenceImplementation = [{
auto ivs = IndexHandle::makeIndexHandles(view_A.rank());