diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 8305bbbc2452..a8e45dff3885 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -32,7 +32,7 @@ class IntegerSet; class Location; class MLIRContext; class Type; -class VectorOrTensorType; +class ShapedType; namespace detail { @@ -417,7 +417,7 @@ class ElementsAttr : public Attribute { public: using Attribute::Attribute; - VectorOrTensorType getType() const; + ShapedType getType() const; /// Return the value at the given index. If index does not refer to a valid /// element, then a null attribute is returned. @@ -439,7 +439,7 @@ public: using Base::Base; using ValueType = Attribute; - static SplatElementsAttr get(VectorOrTensorType type, Attribute elt); + static SplatElementsAttr get(ShapedType type, Attribute elt); Attribute getValue() const; /// Method for support type inquiry through isa, cast and dyn_cast. @@ -457,12 +457,11 @@ public: /// It assumes the elements in the input array have been truncated to the bits /// width specified by the element type. - static DenseElementsAttr get(VectorOrTensorType type, ArrayRef data); + static DenseElementsAttr get(ShapedType type, ArrayRef data); // Constructs a dense elements attribute from an array of element values. Each // element attribute value is expected to be an element of 'type'. - static DenseElementsAttr get(VectorOrTensorType type, - ArrayRef values); + static DenseElementsAttr get(ShapedType type, ArrayRef values); /// Returns the number of elements held by this attribute. size_t size() const; @@ -542,7 +541,7 @@ protected: // Constructs a dense elements attribute from an array of raw APInt values. // Each APInt value is expected to have the same bitwidth as the element type // of 'type'. - static DenseElementsAttr get(VectorOrTensorType type, ArrayRef values); + static DenseElementsAttr get(ShapedType type, ArrayRef values); }; /// An attribute that represents a reference to a dense integer vector or tensor @@ -562,14 +561,12 @@ public: /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. - static DenseIntElementsAttr get(VectorOrTensorType type, - ArrayRef values); + static DenseIntElementsAttr get(ShapedType type, ArrayRef values); /// Constructs a dense integer elements attribute from an array of integer /// values. Each value is expected to be within the bitwidth of the element /// type of 'type'. - static DenseIntElementsAttr get(VectorOrTensorType type, - ArrayRef values); + static DenseIntElementsAttr get(ShapedType type, ArrayRef values); /// Gets the integer value of each of the dense elements. void getValues(SmallVectorImpl &values) const; @@ -609,8 +606,7 @@ public: // Constructs a dense float elements attribute from an array of APFloat // values. Each APFloat value is expected to have the same bitwidth as the // element type of 'type'. - static DenseFPElementsAttr get(VectorOrTensorType type, - ArrayRef values); + static DenseFPElementsAttr get(ShapedType type, ArrayRef values); /// Gets the float value of each of the dense elements. void getValues(SmallVectorImpl &values) const; @@ -637,7 +633,7 @@ public: using Base::Base; using ValueType = StringRef; - static OpaqueElementsAttr get(Dialect *dialect, VectorOrTensorType type, + static OpaqueElementsAttr get(Dialect *dialect, ShapedType type, StringRef bytes); StringRef getValue() const; @@ -684,8 +680,7 @@ class SparseElementsAttr public: using Base::Base; - static SparseElementsAttr get(VectorOrTensorType type, - DenseIntElementsAttr indices, + static SparseElementsAttr get(ShapedType type, DenseIntElementsAttr indices, DenseElementsAttr values); DenseIntElementsAttr getIndices() const; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 7f182e882db0..d852a804723c 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -113,17 +113,16 @@ public: IntegerSetAttr getIntegerSetAttr(IntegerSet set); TypeAttr getTypeAttr(Type type); FunctionAttr getFunctionAttr(Function *value); - ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt); - ElementsAttr getDenseElementsAttr(VectorOrTensorType type, - ArrayRef data); - ElementsAttr getDenseElementsAttr(VectorOrTensorType type, + ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt); + ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef data); + ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef values); - ElementsAttr getDenseIntElementsAttr(VectorOrTensorType type, + ElementsAttr getDenseIntElementsAttr(ShapedType type, ArrayRef values); - ElementsAttr getSparseElementsAttr(VectorOrTensorType type, + ElementsAttr getSparseElementsAttr(ShapedType type, DenseIntElementsAttr indices, DenseElementsAttr values); - ElementsAttr getOpaqueElementsAttr(Dialect *dialect, VectorOrTensorType type, + ElementsAttr getOpaqueElementsAttr(Dialect *dialect, ShapedType type, StringRef bytes); // Returns a 0-valued attribute of the given `type`. This function only // supports boolean, integer, and 32-/64-bit float types, and vector or ranked diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 3e337b298192..9a7cb2dfb1a8 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -99,7 +99,7 @@ struct constant_int_op_binder { if (type.isa()) { return attr_value_binder(bind_value).match(attr); } - if (type.isa()) { + if (type.isa()) { if (auto splatAttr = attr.dyn_cast()) { return attr_value_binder(bind_value) .match(splatAttr.getValue()); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 1463da28f978..ecd61db0e1d1 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -189,15 +189,15 @@ def IsVectorTypePred : CPred<"$_self.isa()">; // Whether a type is a TensorType. def IsTensorTypePred : CPred<"$_self.isa()">; -// Whether a type is a VectorOrTensorType. -def IsVectorOrTensorTypePred : CPred<"$_self.isa()">; +// Whether a type is a MemRefType. +def IsMemRefTypePred : CPred<"$_self.isa()">; + +// Whether a type is a ShapedType. +def IsShapedTypePred : CPred<"$_self.isa()">; // Whether a type is a TupleType. def IsTupleTypePred : CPred<"$_self.isa()">; -// Whether a type is a MemRefType. -def IsMemRefTypePred : CPred<"$_self.isa()">; - // For a TensorType, verify that it is a statically shaped tensor. def IsStaticShapeTensorTypePred : CPred<"$_self.cast().hasStaticShape()">; @@ -345,7 +345,7 @@ class Vector dims> : ContainerType dimensions = dims; } -def VectorOrTensor : Type; +def VectorOrTensor : Type; // Tensor type. @@ -953,50 +953,49 @@ class Results { // Common op type constraints //===----------------------------------------------------------------------===// -// Type Constraint operand `idx`'s Vector or Tensor Element type is `type`. +// Type Constraint operand `idx`'s Element type is `type`. class TCopVTEtIs : AllOf<[ CPred<"$_op.getNumOperands() > " # idx>, SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, SubstLeaves<"$_self", "$_op.getOperand(" # idx # - ")->getType().cast().getElementType()", + ")->getType().cast().getElementType()", type.predicate>]>; // Predicate to verify that the i'th operand and the j'th operand have the same // elemental type. -// Type Constraint operand `i`'s Vector or Tensor Element type is Same As -// operand `j`'s element type. +// Type Constraint operand `i`'s Element type is Same As operand `j`'s Element +// type. class TCopVTEtIsSameAs : AllOf<[ CPred<"$_op.getNumOperands() > std::max(" # i # "," # j # ")">, SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, // TODO: This could be made into C++ function instead. - CPred<"$_op.getOperand(" # i # ")->getType().cast()." + CPred<"$_op.getOperand(" # i # ")->getType().cast()." "getElementType() == $_op.getOperand(" # j # ")->getType()." - "cast().getElementType()">]>; + "cast().getElementType()">]>; // Predicate to verify that the i'th result and the j'th operand have the same // elemental type. -// Type Constraint result`i`'s Vector or Tensor Element type is Same As -// Type Constraint Operand `j`'s Vector or Tensor Element type. +// Type Constraint result`i`'s Element type is Same As Operand `j`'s Element +// type. class TCresVTEtIsSameAsOp : AllOf<[ CPred<"$_op.getNumResults() > " # i>, CPred<"$_op.getNumOperands() > " # j>, SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, // TODO: This could be made into C++ function instead. - CPred<"$_op.getResult(" # i # ")->getType().cast()." + CPred<"$_op.getResult(" # i # ")->getType().cast()." "getElementType() == $_op.getOperand(" # j # ")->getType()." - "cast().getElementType()">]>; + "cast().getElementType()">]>; // Predicate to verify that all the operands at the given `indices` // have the same element type. -// Type Constraint operands' Vector or Tensor Element type are all Same At -// the given `indices`. +// Type Constraint operands' Element type are all Same At the given `indices`. // We query the operands' types into a list and check they are all the same. // Precondition: // 1) all operands involved are of vector or tensor type and @@ -1004,7 +1003,7 @@ class TCresVTEtIsSameAsOp : AllOf<[ class TCopVTEtAreSameAt indices> : CPred<"llvm::is_splat(mlir::functional::map(" "[this](unsigned i) { return this->getOperand(i)->getType()" - ".cast().getElementType(); }, " + ".cast().getElementType(); }, " "llvm::ArrayRef({" # Stringify.result # "})))">; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 71e187b71948..788c01af5df5 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -591,7 +591,6 @@ public: /// This class provides verification for ops that are known to have the same /// operand and result element type. /// -/// TODO: This only works for VectorOrTensorType at the moment. template class SameOperandsAndResultElementType : public TraitBase { diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index 8043ac6809d5..8589fbd74f04 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -35,7 +35,7 @@ class MLIRContext; namespace detail { struct IntegerTypeStorage; -struct VectorOrTensorTypeStorage; +struct ShapedTypeStorage; struct VectorTypeStorage; struct RankedTensorTypeStorage; struct UnrankedTensorTypeStorage; @@ -180,11 +180,13 @@ public: const llvm::fltSemantics &getFloatSemantics(); }; +// TODO(b/132735995) Add support for MemRef /// This is a common base class between Vector, UnrankedTensor, and RankedTensor -/// types, because many operations work on values of these aggregate types. -class VectorOrTensorType : public Type { +/// types because they share behavior and semantics around shape, rank, and +/// fixed element type. +class ShapedType : public Type { public: - using ImplType = detail::VectorOrTensorTypeStorage; + using ImplType = detail::ShapedTypeStorage; using Type::Type; /// Return the element type. @@ -234,8 +236,8 @@ public: /// Vector types represent multi-dimensional SIMD vectors, and have a fixed /// known constant shape with one or more dimension. -class VectorType : public Type::TypeBase { +class VectorType + : public Type::TypeBase { public: using Base::Base; @@ -268,9 +270,9 @@ public: /// Tensor types represent multi-dimensional arrays, and have two variants: /// RankedTensorType and UnrankedTensorType. -class TensorType : public VectorOrTensorType { +class TensorType : public ShapedType { public: - using VectorOrTensorType::VectorOrTensorType; + using ShapedType::ShapedType; /// Return true if the specified element type is ok in a tensor. static bool isValidElementType(Type type) { diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 0497b8463f06..ffb2150b0534 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -375,7 +375,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { let builders = [OpBuilder< "Builder *builder, OperationState *result, Value *aggregate," "ArrayRef indices = {}", [{ - auto resType = aggregate->getType().cast() + auto resType = aggregate->getType().cast() .getElementType(); build(builder, result, resType, aggregate, indices); }]>]; diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h index 325b0ca93ca7..2fdac8ef26a9 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h @@ -168,12 +168,12 @@ struct QuantizedMultiplierSmallerThanOneExp { /// Casts an integer or floating point based type to a new element type. inline Type castElementType(Type t, Type newElementType) { - if (auto vt = t.dyn_cast()) { - switch (vt.getKind()) { + if (auto st = t.dyn_cast()) { + switch (st.getKind()) { case StandardTypes::Kind::Vector: - return VectorType::get(vt.getShape(), newElementType); + return VectorType::get(st.getShape(), newElementType); case StandardTypes::Kind::RankedTensor: - return RankedTensorType::get(vt.getShape(), newElementType); + return RankedTensorType::get(st.getShape(), newElementType); case StandardTypes::Kind::UnrankedTensor: return UnrankedTensorType::get(newElementType); } @@ -185,10 +185,10 @@ inline Type castElementType(Type t, Type newElementType) { /// Creates an IntegerAttr with a type that matches the shape of 't' (which can /// be a primitive/vector/tensor). inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) { - if (auto vt = t.dyn_cast()) { - assert(vt.getElementType().isa()); - return SplatElementsAttr::get(vt, - IntegerAttr::get(vt.getElementType(), value)); + if (auto st = t.dyn_cast()) { + assert(st.getElementType().isa()); + return SplatElementsAttr::get(st, + IntegerAttr::get(st.getElementType(), value)); } auto integerType = t.cast(); @@ -211,13 +211,13 @@ inline APFloat convertFloatToType(FloatType ft, APFloat value) { /// Creates an IntegerAttr with a type that matches the shape of 't' (which can /// be a primitive/vector/tensor). inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) { - if (auto vt = t.dyn_cast()) { - FloatType floatElementType = vt.getElementType().dyn_cast(); + if (auto st = t.dyn_cast()) { + FloatType floatElementType = st.getElementType().dyn_cast(); assert(floatElementType && "float broadcast element type must be float like"); APFloat apValue = convertFloatToType(floatElementType, value); - return SplatElementsAttr::get(vt, - FloatAttr::get(vt.getElementType(), apValue)); + return SplatElementsAttr::get(st, + FloatAttr::get(st.getElementType(), apValue)); } else { auto floatType = t.dyn_cast(); assert(floatType && "float broadcast must be of float type"); diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp index 6ca3b92d0643..1b63b8f4f558 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp @@ -98,8 +98,8 @@ Type QuantizedType::getExpressedType() const { } bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { - if (candidateExpressedType.isa()) { - return candidateExpressedType.cast().getElementType() == + if (candidateExpressedType.isa()) { + return candidateExpressedType.cast().getElementType() == getExpressedType(); } return candidateExpressedType == getExpressedType(); @@ -107,9 +107,9 @@ bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { QuantizedType QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { - if (primitiveOrContainerType.isa()) { + if (primitiveOrContainerType.isa()) { Type elementType = - primitiveOrContainerType.cast().getElementType(); + primitiveOrContainerType.cast().getElementType(); return elementType.dyn_cast(); } return primitiveOrContainerType.dyn_cast(); @@ -139,20 +139,20 @@ Type QuantizedType::castToStorageType(Type quantizedType) { if (quantizedType.isa()) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 return quantizedType.cast().getStorageType(); - } else if (quantizedType.isa()) { + } else if (quantizedType.isa()) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - VectorOrTensorType vtType = quantizedType.cast(); - if (!vtType.getElementType().isa()) { + ShapedType sType = quantizedType.cast(); + if (!sType.getElementType().isa()) { return nullptr; } Type storageType = - vtType.getElementType().cast().getStorageType(); + sType.getElementType().cast().getStorageType(); if (quantizedType.isa()) { - return RankedTensorType::get(vtType.getShape(), storageType); + return RankedTensorType::get(sType.getShape(), storageType); } else if (quantizedType.isa()) { return UnrankedTensorType::get(storageType); } else if (quantizedType.isa()) { - return VectorType::get(vtType.getShape(), storageType); + return VectorType::get(sType.getShape(), storageType); } } @@ -163,22 +163,21 @@ Type QuantizedType::castFromExpressedType(Type candidateType) { if (candidateType == getExpressedType()) { // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> return *this; - } else if (candidateType.isa()) { - VectorOrTensorType candidateVtType = - candidateType.cast(); - if (candidateVtType.getElementType() != getExpressedType()) { + } else if (candidateType.isa()) { + ShapedType candidateShapedType = candidateType.cast(); + if (candidateShapedType.getElementType() != getExpressedType()) { return nullptr; } if (candidateType.isa()) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - return RankedTensorType::get(candidateVtType.getShape(), *this); + return RankedTensorType::get(candidateShapedType.getShape(), *this); } else if (candidateType.isa()) { // i.e. tensor -> tensor> return UnrankedTensorType::get(*this); } else if (candidateType.isa()) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - return VectorType::get(candidateVtType.getShape(), *this); + return VectorType::get(candidateShapedType.getShape(), *this); } } @@ -189,20 +188,20 @@ Type QuantizedType::castToExpressedType(Type quantizedType) { if (quantizedType.isa()) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 return quantizedType.cast().getExpressedType(); - } else if (quantizedType.isa()) { + } else if (quantizedType.isa()) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - VectorOrTensorType vtType = quantizedType.cast(); - if (!vtType.getElementType().isa()) { + ShapedType sType = quantizedType.cast(); + if (!sType.getElementType().isa()) { return nullptr; } Type expressedType = - vtType.getElementType().cast().getExpressedType(); + sType.getElementType().cast().getExpressedType(); if (quantizedType.isa()) { - return RankedTensorType::get(vtType.getShape(), expressedType); + return RankedTensorType::get(sType.getShape(), expressedType); } else if (quantizedType.isa()) { return UnrankedTensorType::get(expressedType); } else if (quantizedType.isa()) { - return VectorType::get(vtType.getShape(), expressedType); + return VectorType::get(sType.getShape(), expressedType); } } diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp index 3685a65f2d82..c50b3075b690 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp @@ -56,10 +56,10 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, // Cast from an expressed-type-based type to storage-type-based type, // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>). - VectorOrTensorType newDenseType = + ShapedType newDenseType = quantizedElementType .castExpressedToStorageType(realFPElementsAttr.getType()) - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (!newDenseType) { return nullptr; } @@ -87,9 +87,9 @@ convertSplatElementsAttr(SplatElementsAttr realSplatAttr, // Cast from an expressed-type-based type to storage-type-based type, // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>). - VectorOrTensorType newSplatType = + ShapedType newSplatType = quantizedElementType.castExpressedToStorageType(realSplatAttr.getType()) - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (!newSplatType) { return nullptr; } @@ -116,9 +116,9 @@ convertSparseElementsAttr(SparseElementsAttr realSparseAttr, // Cast from an expressed-type-based type to storage-type-based type, // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). - VectorOrTensorType newSparseType = + ShapedType newSparseType = quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (!newSparseType) { return nullptr; } diff --git a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp index d791075f5dba..db8a58489815 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp @@ -38,13 +38,13 @@ ExpressedToUniformQuantizedConverter::forInputType(Type inputType) { case StandardTypes::RankedTensor: case StandardTypes::UnrankedTensor: case StandardTypes::Vector: { - Type elementType = inputType.cast().getElementType(); + Type elementType = inputType.cast().getElementType(); if (!isQuantizablePrimitiveType(elementType)) { // Unsupported. return ExpressedToUniformQuantizedConverter{inputType, nullptr}; } return ExpressedToUniformQuantizedConverter{ - inputType, inputType.cast().getElementType()}; + inputType, inputType.cast().getElementType()}; } } } diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp index 6762a60b0fc5..183ba73ba7aa 100644 --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -80,8 +80,8 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef shape1, /// Returns the shape of the given type. Scalars will be considered as having a /// shape with zero dimensions. static ArrayRef getShape(Type type) { - if (auto vtType = type.dyn_cast()) - return vtType.getShape(); + if (auto sType = type.dyn_cast()) + return sType.getShape(); return {}; } @@ -92,8 +92,8 @@ static ArrayRef getShape(Type type) { Type OpTrait::util::getBroadcastedType(Type type1, Type type2) { // Returns the scalar type out of the given type. auto getScalarType = [](Type type) -> Type { - if (auto vtType = type.dyn_cast()) - return vtType.getElementType(); + if (auto shapedType = type.dyn_cast()) + return shapedType.getElementType(); return type; }; diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index a4691e89cda1..fa9f7397601e 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -315,7 +315,7 @@ static ValueHandle createBinaryHandle( return createBinaryHandle(lhs, rhs); } else if (thisType.isa()) { return createBinaryHandle(lhs, rhs); - } else if (auto aggregateType = thisType.dyn_cast()) { + } else if (auto aggregateType = thisType.dyn_cast()) { if (aggregateType.getElementType().isa()) return createBinaryHandle(lhs, rhs); else if (aggregateType.getElementType().isa()) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 23f1317908a5..f99c09fe5ce5 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -201,7 +201,7 @@ void ModuleState::visitType(Type type) { // Visit affine maps in memref type. for (auto map : memref.getAffineMaps()) recordAttributeReference(AffineMapAttr::get(map)); - } else if (auto vecOrTensor = type.dyn_cast()) { + } else if (auto vecOrTensor = type.dyn_cast()) { visitType(vecOrTensor.getElementType()); } } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 881060c90820..b958221a9578 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -313,8 +313,8 @@ FunctionType FunctionAttr::getType() const { // ElementsAttr //===----------------------------------------------------------------------===// -VectorOrTensorType ElementsAttr::getType() const { - return Attribute::getType().cast(); +ShapedType ElementsAttr::getType() const { + return Attribute::getType().cast(); } /// Return the value at the given index. If index does not refer to a valid @@ -339,8 +339,7 @@ Attribute ElementsAttr::getValue(ArrayRef index) const { // SplatElementsAttr //===----------------------------------------------------------------------===// -SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type, - Attribute elt) { +SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) { assert(elt.getType() == type.getElementType() && "value should be of the given element type"); return Base::get(type.getContext(), StandardAttributes::SplatElements, type, @@ -374,8 +373,7 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const { // DenseElementsAttr //===----------------------------------------------------------------------===// -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, - ArrayRef data) { +DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef data) { assert((static_cast(type.getSizeInBits()) <= data.size() * APInt::APINT_WORD_SIZE) && "Input data bit size should be larger than that type requires"); @@ -394,7 +392,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, } } -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, +DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(type.getElementType().isIntOrFloat() && "expected int or float element type"); @@ -516,7 +514,7 @@ ArrayRef DenseElementsAttr::getRawData() const { // Constructs a dense elements attribute from an array of raw APInt values. // Each APInt value is expected to have the same bitwidth as the element type // of 'type'. -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, +DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(values.size() == type.getNumElements() && "expected 'values' to contain the same number of elements as 'type'"); @@ -586,7 +584,7 @@ APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos, /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. -DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type, +DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type, ArrayRef values) { return DenseElementsAttr::get(type, values).cast(); } @@ -594,7 +592,7 @@ DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type, /// Constructs a dense integer elements attribute from an array of integer /// values. Each value is expected to be within the bitwidth of the element /// type of 'type'. -DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type, +DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type, ArrayRef values) { auto eltType = type.getElementType(); size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); @@ -625,7 +623,7 @@ DenseFPElementsAttr::ElementIterator::ElementIterator( // Constructs a dense float elements attribute from an array of APFloat // values. Each APFloat value is expected to have the same bitwidth as the // element type of 'type'. -DenseFPElementsAttr DenseFPElementsAttr::get(VectorOrTensorType type, +DenseFPElementsAttr DenseFPElementsAttr::get(ShapedType type, ArrayRef values) { // Convert the APFloat values to APInt and create a dense elements attribute. std::vector intValues(values.size()); @@ -655,8 +653,7 @@ DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const { // OpaqueElementsAttr //===----------------------------------------------------------------------===// -OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, - VectorOrTensorType type, +OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, StringRef bytes) { assert(TensorType::isValidElementType(type.getElementType()) && "Input element type should be a valid tensor element type"); @@ -686,7 +683,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) { // SparseElementsAttr //===----------------------------------------------------------------------===// -SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type, +SparseElementsAttr SparseElementsAttr::get(ShapedType type, DenseIntElementsAttr indices, DenseElementsAttr values) { assert(indices.getType().getElementType().isInteger(64) && diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index a6036a9b1401..574102c82324 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -173,34 +173,32 @@ FunctionAttr Builder::getFunctionAttr(Function *value) { return FunctionAttr::get(value); } -ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type, - Attribute elt) { +ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) { return SplatElementsAttr::get(type, elt); } -ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, +ElementsAttr Builder::getDenseElementsAttr(ShapedType type, ArrayRef data) { return DenseElementsAttr::get(type, data); } -ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, +ElementsAttr Builder::getDenseElementsAttr(ShapedType type, ArrayRef values) { return DenseElementsAttr::get(type, values); } -ElementsAttr Builder::getDenseIntElementsAttr(VectorOrTensorType type, +ElementsAttr Builder::getDenseIntElementsAttr(ShapedType type, ArrayRef values) { return DenseIntElementsAttr::get(type, values); } -ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type, +ElementsAttr Builder::getSparseElementsAttr(ShapedType type, DenseIntElementsAttr indices, DenseElementsAttr values) { return SparseElementsAttr::get(type, indices, values); } -ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, - VectorOrTensorType type, +ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, ShapedType type, StringRef bytes) { return OpaqueElementsAttr::get(dialect, type, bytes); } @@ -249,7 +247,7 @@ Attribute Builder::getZeroAttr(Type type) { } case StandardTypes::Vector: case StandardTypes::RankedTensor: { - auto vtType = type.cast(); + auto vtType = type.cast(); auto element = getZeroAttr(vtType.getElementType()); if (!element) return {}; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 992c66ea1d1c..2a67f5aa3265 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -775,8 +775,8 @@ LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, } /// Returns success if the given two types have the same shape. That is, -/// they are both scalars, or they are both vectors / ranked tensors with -/// the same dimension specifications. The element type does not matter. +/// they are both scalars, or they are both static shaped types with the same +/// dimension specifications. The element type does not matter. static LogicalResult verifyShapeMatch(Type type1, Type type2) { // Check scalar cases if (type1.isIntOrIndexOrFloat()) @@ -787,9 +787,9 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) { return failure(); // Check normal vector/tensor cases - if (auto vtType1 = type1.dyn_cast()) { - auto vtType2 = type2.dyn_cast(); - return success(vtType2 && vtType1.getShape() == vtType2.getShape()); + if (auto sType1 = type1.dyn_cast()) { + auto sType2 = type2.dyn_cast(); + return success(sType2 && sType1.getShape() == sType2.getShape()); } return success(); @@ -818,14 +818,14 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return failure(); - auto type = op->getResult(0)->getType().dyn_cast(); + auto type = op->getResult(0)->getType().dyn_cast(); if (!type) return op->emitOpError("requires vector or tensor type results"); auto elementType = type.getElementType(); // Verify result element type matches first result's element type. for (auto result : drop_begin(op->getResults(), 1)) { - auto resultType = result->getType().dyn_cast(); + auto resultType = result->getType().dyn_cast(); if (!resultType) return op->emitOpError("requires vector or tensor type results"); if (resultType.getElementType() != elementType) @@ -835,7 +835,7 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { // Verify operand's element type matches first result's element type. for (auto operand : op->getOperands()) { - auto operandType = operand->getType().dyn_cast(); + auto operandType = operand->getType().dyn_cast(); if (!operandType) return op->emitOpError("requires vector or tensor type operands"); if (operandType.getElementType() != elementType) diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index ba7d1d3d6543..b279d19a71ec 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -104,18 +104,18 @@ unsigned Type::getIntOrFloatBitWidth() { } //===----------------------------------------------------------------------===// -// VectorOrTensorType +// ShapedType //===----------------------------------------------------------------------===// -Type VectorOrTensorType::getElementType() const { +Type ShapedType::getElementType() const { return static_cast(impl)->elementType; } -unsigned VectorOrTensorType::getElementTypeBitWidth() const { +unsigned ShapedType::getElementTypeBitWidth() const { return getElementType().getIntOrFloatBitWidth(); } -unsigned VectorOrTensorType::getNumElements() const { +unsigned ShapedType::getNumElements() const { switch (getKind()) { case StandardTypes::Vector: case StandardTypes::RankedTensor: { @@ -127,13 +127,13 @@ unsigned VectorOrTensorType::getNumElements() const { return num; } default: - llvm_unreachable("not a VectorOrTensorType or not ranked"); + llvm_unreachable("not a ShapedType or not ranked"); } } /// If this is ranked tensor or vector type, return the rank. If it is an /// unranked tensor, return -1. -int64_t VectorOrTensorType::getRank() const { +int64_t ShapedType::getRank() const { switch (getKind()) { case StandardTypes::Vector: case StandardTypes::RankedTensor: @@ -141,24 +141,24 @@ int64_t VectorOrTensorType::getRank() const { case StandardTypes::UnrankedTensor: return -1; default: - llvm_unreachable("not a VectorOrTensorType"); + llvm_unreachable("not a ShapedType"); } } -int64_t VectorOrTensorType::getDimSize(unsigned i) const { +int64_t ShapedType::getDimSize(unsigned i) const { switch (getKind()) { case StandardTypes::Vector: case StandardTypes::RankedTensor: return getShape()[i]; default: - llvm_unreachable("not a VectorOrTensorType or not ranked"); + llvm_unreachable("not a ShapedType or not ranked"); } } // Get the number of number of bits require to store a value of the given vector // or tensor types. Compute the value recursively since tensors are allowed to // have vectors as elements. -int64_t VectorOrTensorType::getSizeInBits() const { +int64_t ShapedType::getSizeInBits() const { assert(hasStaticShape() && "cannot get the bit size of an aggregate with a dynamic shape"); @@ -168,23 +168,23 @@ int64_t VectorOrTensorType::getSizeInBits() const { // Tensors can have vectors and other tensors as elements, vectors cannot. assert(!isa() && "unsupported vector element type"); - auto elementVectorOrTensorType = elementType.dyn_cast(); - assert(elementVectorOrTensorType && "unsupported tensor element type"); - return getNumElements() * elementVectorOrTensorType.getSizeInBits(); + auto elementShapedType = elementType.dyn_cast(); + assert(elementShapedType && "unsupported tensor element type"); + return getNumElements() * elementShapedType.getSizeInBits(); } -ArrayRef VectorOrTensorType::getShape() const { +ArrayRef ShapedType::getShape() const { switch (getKind()) { case StandardTypes::Vector: return cast().getShape(); case StandardTypes::RankedTensor: return cast().getShape(); default: - llvm_unreachable("not a VectorOrTensorType or not ranked"); + llvm_unreachable("not a ShapedType or not ranked"); } } -bool VectorOrTensorType::hasStaticShape() const { +bool ShapedType::hasStaticShape() const { if (isa()) return false; return llvm::none_of(getShape(), [](int64_t i) { return i < 0; }); diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h index 870dc7b3236a..541aacd9a722 100644 --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -118,8 +118,8 @@ struct FunctionTypeStorage : public TypeStorage { }; /// VectorOrTensor Type Storage. -struct VectorOrTensorTypeStorage : public TypeStorage { - VectorOrTensorTypeStorage(Type elementType, unsigned subclassData = 0) +struct ShapedTypeStorage : public TypeStorage { + ShapedTypeStorage(Type elementType, unsigned subclassData = 0) : TypeStorage(subclassData), elementType(elementType) {} /// The hash key used for uniquing. @@ -130,11 +130,10 @@ struct VectorOrTensorTypeStorage : public TypeStorage { }; /// Vector Type Storage and Uniquing. -struct VectorTypeStorage : public VectorOrTensorTypeStorage { +struct VectorTypeStorage : public ShapedTypeStorage { VectorTypeStorage(unsigned shapeSize, Type elementTy, const int64_t *shapeElements) - : VectorOrTensorTypeStorage(elementTy, shapeSize), - shapeElements(shapeElements) {} + : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {} /// The hash key used for uniquing. using KeyTy = std::pair, Type>; @@ -160,11 +159,10 @@ struct VectorTypeStorage : public VectorOrTensorTypeStorage { const int64_t *shapeElements; }; -struct RankedTensorTypeStorage : public VectorOrTensorTypeStorage { +struct RankedTensorTypeStorage : public ShapedTypeStorage { RankedTensorTypeStorage(unsigned shapeSize, Type elementTy, const int64_t *shapeElements) - : VectorOrTensorTypeStorage(elementTy, shapeSize), - shapeElements(shapeElements) {} + : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {} /// The hash key used for uniquing. using KeyTy = std::pair, Type>; @@ -190,9 +188,9 @@ struct RankedTensorTypeStorage : public VectorOrTensorTypeStorage { const int64_t *shapeElements; }; -struct UnrankedTensorTypeStorage : public VectorOrTensorTypeStorage { - using VectorOrTensorTypeStorage::KeyTy; - using VectorOrTensorTypeStorage::VectorOrTensorTypeStorage; +struct UnrankedTensorTypeStorage : public ShapedTypeStorage { + using ShapedTypeStorage::KeyTy; + using ShapedTypeStorage::ShapedTypeStorage; /// Construction. static UnrankedTensorTypeStorage *construct(TypeStorageAllocator &allocator, diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 03b3dd1dbde5..01feec367d81 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -200,9 +200,9 @@ public: // Polyhedral structures. ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, IntegerSet &set); - DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type); + DenseElementsAttr parseDenseElementsAttr(ShapedType type); DenseElementsAttr parseDenseElementsAttrAsTensor(Type eltType); - VectorOrTensorType parseVectorOrTensorType(); + ShapedType parseShapedType(); // Location Parsing. @@ -1255,7 +1255,7 @@ Attribute Parser::parseAttribute(Type type) { if (parseToken(Token::comma, "expected ','")) return nullptr; - auto type = parseVectorOrTensorType(); + auto type = parseShapedType(); if (!type) return nullptr; @@ -1279,7 +1279,7 @@ Attribute Parser::parseAttribute(Type type) { if (parseToken(Token::less, "expected '<' after 'splat'")) return nullptr; - auto type = parseVectorOrTensorType(); + auto type = parseShapedType(); if (!type) return nullptr; switch (getToken().getKind()) { @@ -1305,7 +1305,7 @@ Attribute Parser::parseAttribute(Type type) { if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; - auto type = parseVectorOrTensorType(); + auto type = parseShapedType(); if (!type) return nullptr; @@ -1323,7 +1323,7 @@ Attribute Parser::parseAttribute(Type type) { if (parseToken(Token::less, "Expected '<' after 'sparse'")) return nullptr; - auto type = parseVectorOrTensorType(); + auto type = parseShapedType(); if (!type) return nullptr; @@ -1414,7 +1414,7 @@ DenseElementsAttr Parser::parseDenseElementsAttrAsTensor(Type eltType) { /// This method compares the shapes from the parsing result and that from the /// input argument. It returns a constructed dense elements attribute if both /// match. -DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) { +DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) { auto eltTy = type.getElementType(); TensorLiteralParser literalParser(*this, eltTy); if (literalParser.parse()) @@ -1431,27 +1431,26 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) { .cast(); } -/// Vector or tensor type for elements attribute. +/// Shaped type for elements attribute. /// -/// vector-or-tensor-type ::= vector-type | tensor-type +/// shaped-type ::= vector-type | tensor-type /// /// This method also checks the type has static shape and ranked. -VectorOrTensorType Parser::parseVectorOrTensorType() { +ShapedType Parser::parseShapedType() { auto elementType = parseType(); if (!elementType) return nullptr; - auto type = elementType.dyn_cast(); + auto type = elementType.dyn_cast(); if (!type) { - return (emitError("expected elements literal has a tensor or vector type"), - nullptr); + return (emitError("elements literal must be a shaped type"), nullptr); } if (parseToken(Token::comma, "expected ','")) return nullptr; if (!type.hasStaticShape() || type.getRank() == -1) { - return (emitError("tensor literals must be ranked and have static shape"), + return (emitError("shaped literal must be ranked and have static shape"), nullptr); } return type; diff --git a/mlir/lib/Quantizer/Support/Statistics.cpp b/mlir/lib/Quantizer/Support/Statistics.cpp index ce3913ae34f9..3ec07e40f2b7 100644 --- a/mlir/lib/Quantizer/Support/Statistics.cpp +++ b/mlir/lib/Quantizer/Support/Statistics.cpp @@ -66,18 +66,18 @@ static bool getElementsStatistics(ElementsAttr attr, statistics.minValue = std::numeric_limits::infinity(); statistics.maxValue = -std::numeric_limits::infinity(); - VectorOrTensorType vtType = attr.getType(); - if (!vtType.hasStaticShape()) + ShapedType sType = attr.getType(); + if (!sType.hasStaticShape()) return false; - Type elementTy = vtType.getElementType(); + Type elementTy = sType.getElementType(); if (!elementTy.isa()) return false; llvm::SmallVector indices; - indices.resize(vtType.getRank()); - ArrayRef shape = vtType.getShape(); + indices.resize(sType.getRank()); + ArrayRef shape = sType.getShape(); - auto numElements = vtType.getNumElements(); + auto numElements = sType.getNumElements(); collectElementsStatisticsDim(attr, numElements, shape, indices, 0, statistics); statistics.sampleSize = numElements; diff --git a/mlir/lib/Quantizer/Support/TypeUtils.cpp b/mlir/lib/Quantizer/Support/TypeUtils.cpp index 444322e31e37..fab4e565308e 100644 --- a/mlir/lib/Quantizer/Support/TypeUtils.cpp +++ b/mlir/lib/Quantizer/Support/TypeUtils.cpp @@ -23,8 +23,8 @@ using namespace mlir; using namespace mlir::quantizer; Type mlir::quantizer::getElementOrPrimitiveType(Type t) { - if (auto vtType = t.dyn_cast()) { - return vtType.getElementType(); + if (auto sType = t.dyn_cast()) { + return sType.getElementType(); } else { return t; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 682af6ea99ad..490f6e4e4c80 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1163,9 +1163,9 @@ static LogicalResult verify(ConstantOp &op) { return success(); } - if (type.isa()) { + if (type.isa()) { if (!value.isa()) - return op.emitOpError("requires 'value' to be a vector/tensor constant"); + return op.emitOpError("requires 'value' to be a shaped constant"); return success(); } @@ -1639,7 +1639,7 @@ static ParseResult parseExtractElementOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType aggregateInfo; SmallVector indexInfo; - VectorOrTensorType type; + ShapedType type; auto affineIntTy = parser->getBuilder().getIndexType(); return failure( @@ -1656,8 +1656,7 @@ static LogicalResult verify(ExtractElementOp op) { if (op.getNumOperands() == 0) return op.emitOpError("expected an aggregate to index into"); - auto aggregateType = - op.getAggregate()->getType().dyn_cast(); + auto aggregateType = op.getAggregate()->getType().dyn_cast(); if (!aggregateType) return op.emitOpError("first operand must be a vector or tensor"); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index b8f24fe030f8..0fe5029a99de 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -599,14 +599,14 @@ func@n(){^b( func @elementsattr_non_tensor_type() -> () { ^bb0: - "foo"(){bar: dense} : () -> () // expected-error {{expected elements literal has a tensor or vector type}} + "foo"(){bar: dense} : () -> () // expected-error {{elements literal must be a shaped type}} } // ----- func @elementsattr_non_ranked() -> () { ^bb0: - "foo"(){bar: dense, [4]>} : () -> () // expected-error {{tensor literals must be ranked and have static shape}} + "foo"(){bar: dense, [4]>} : () -> () // expected-error {{shaped literal must be ranked and have static shape}} } // ----- diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index 1ebe3b1cf9bc..4e0da207efa1 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -33,7 +33,7 @@ def OpC : NS_Op<"op_for_TCopVTEtIs", [ } // CHECK-LABEL: OpC::verify -// CHECK: if (!((((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType().isInteger(32))))) +// CHECK: if (!((((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType().isInteger(32))))) def OpD : NS_Op<"op_for_TCOpVTEtIsSameAs", [ @@ -44,7 +44,7 @@ def OpD : NS_Op<"op_for_TCOpVTEtIsSameAs", [ } // CHECK-LABEL: OpD::verify -// CHECK: if (!((((*this->getOperation()).getNumOperands() > std::max(0,0))) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast().getElementType())))) +// CHECK: if (!((((*this->getOperation()).getNumOperands() > std::max(0,0))) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast().getElementType())))) // CHECK-NEXT: return emitOpError("failed to verify that first operand is a vector or tensor with the same elemental type as itself"); @@ -57,7 +57,7 @@ def OpE : NS_Op<"op_for_TCresVTEtIsSameAsOp", [ } // CHECK-LABEL: OpE::verify -// CHECK: if (!((((*this->getOperation()).getNumResults() > 0)) && (((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getResult(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getResult(0)->getType().cast().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast().getElementType())))) +// CHECK: if (!((((*this->getOperation()).getNumResults() > 0)) && (((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getResult(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getResult(0)->getType().cast().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast().getElementType())))) // CHECK-NEXT: return emitOpError("failed to verify that first operand is a vector or tensor with the same elemental type as first result"); @@ -107,7 +107,7 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [ // CHECK-LABEL: OpJ::verify() // CHECK: llvm::is_splat(mlir::functional::map( -// CHECK-SAME: [this](unsigned i) { return this->getOperand(i)->getType().cast().getElementType(); }, +// CHECK-SAME: [this](unsigned i) { return this->getOperand(i)->getType().cast().getElementType(); }, // CHECK-SAME: llvm::ArrayRef({0, 2, 3}))) // CHECK: return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type"); diff --git a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp index fd2efb205138..3c8642ef429e 100644 --- a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp @@ -49,7 +49,7 @@ template ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef shape, Arg... value) { auto eleType = FloatType::getF32(ctx); - VectorOrTensorType tensorType; + ShapedType tensorType; if (shape.size() == 1 && shape[0] == -1) { tensorType = UnrankedTensorType::get(eleType); } else { @@ -61,7 +61,7 @@ ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef shape, ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx, ArrayRef shape) { auto eleType = FloatType::getF32(ctx); - VectorOrTensorType tensorType; + ShapedType tensorType; if (shape.size() == 1 && shape[0] == -1) { tensorType = UnrankedTensorType::get(eleType); } else {