From 090662c5f3572c573cf249844748ecbf11d10dbe Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Thu, 16 May 2019 00:12:45 -0700 Subject: [PATCH] Rename VectorOrTensorType to ShapedType This is in preparation for making it also support/be a parent class of MemRefType. MemRefs have similar shape/rank/element semantics and it would be useful to be able to use these same utilities for them. This CL should not change any semantics and only change variables, types, string literals, and comments. In follow-up CLs I will prepare all callers to handle MemRef types or remove their dependence on ShapedType. Discussion/Rationale in https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/cHLoyfGu8y8 -- PiperOrigin-RevId: 248476449 --- mlir/include/mlir/IR/Attributes.h | 27 +++++------ mlir/include/mlir/IR/Builders.h | 13 +++-- mlir/include/mlir/IR/Matchers.h | 2 +- mlir/include/mlir/IR/OpBase.td | 47 +++++++++---------- mlir/include/mlir/IR/OpDefinition.h | 1 - mlir/include/mlir/IR/StandardTypes.h | 18 +++---- mlir/include/mlir/StandardOps/Ops.td | 2 +- .../Transforms/UniformKernelUtils.h | 24 +++++----- mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp | 43 +++++++++-------- .../Dialect/QuantOps/Utils/QuantizeUtils.cpp | 12 ++--- .../Dialect/QuantOps/Utils/UniformSupport.cpp | 4 +- mlir/lib/Dialect/Traits.cpp | 8 ++-- mlir/lib/EDSC/Builders.cpp | 2 +- mlir/lib/IR/AsmPrinter.cpp | 2 +- mlir/lib/IR/Attributes.cpp | 25 +++++----- mlir/lib/IR/Builders.cpp | 16 +++---- mlir/lib/IR/Operation.cpp | 16 +++---- mlir/lib/IR/StandardTypes.cpp | 32 ++++++------- mlir/lib/IR/TypeDetail.h | 20 ++++---- mlir/lib/Parser/Parser.cpp | 27 +++++------ mlir/lib/Quantizer/Support/Statistics.cpp | 12 ++--- mlir/lib/Quantizer/Support/TypeUtils.cpp | 4 +- mlir/lib/StandardOps/Ops.cpp | 9 ++-- mlir/test/IR/invalid.mlir | 4 +- mlir/test/mlir-tblgen/predicate.td | 8 ++-- .../QuantOps/QuantizationUtilsTest.cpp | 4 +- 26 files changed, 183 insertions(+), 199 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 8305bbbc2452..a8e45dff3885 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -32,7 +32,7 @@ class IntegerSet; class Location; class MLIRContext; class Type; -class VectorOrTensorType; +class ShapedType; namespace detail { @@ -417,7 +417,7 @@ class ElementsAttr : public Attribute { public: using Attribute::Attribute; - VectorOrTensorType getType() const; + ShapedType getType() const; /// Return the value at the given index. If index does not refer to a valid /// element, then a null attribute is returned. @@ -439,7 +439,7 @@ public: using Base::Base; using ValueType = Attribute; - static SplatElementsAttr get(VectorOrTensorType type, Attribute elt); + static SplatElementsAttr get(ShapedType type, Attribute elt); Attribute getValue() const; /// Method for support type inquiry through isa, cast and dyn_cast. @@ -457,12 +457,11 @@ public: /// It assumes the elements in the input array have been truncated to the bits /// width specified by the element type. - static DenseElementsAttr get(VectorOrTensorType type, ArrayRef data); + static DenseElementsAttr get(ShapedType type, ArrayRef data); // Constructs a dense elements attribute from an array of element values. Each // element attribute value is expected to be an element of 'type'. - static DenseElementsAttr get(VectorOrTensorType type, - ArrayRef values); + static DenseElementsAttr get(ShapedType type, ArrayRef values); /// Returns the number of elements held by this attribute. size_t size() const; @@ -542,7 +541,7 @@ protected: // Constructs a dense elements attribute from an array of raw APInt values. // Each APInt value is expected to have the same bitwidth as the element type // of 'type'. - static DenseElementsAttr get(VectorOrTensorType type, ArrayRef values); + static DenseElementsAttr get(ShapedType type, ArrayRef values); }; /// An attribute that represents a reference to a dense integer vector or tensor @@ -562,14 +561,12 @@ public: /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. - static DenseIntElementsAttr get(VectorOrTensorType type, - ArrayRef values); + static DenseIntElementsAttr get(ShapedType type, ArrayRef values); /// Constructs a dense integer elements attribute from an array of integer /// values. Each value is expected to be within the bitwidth of the element /// type of 'type'. - static DenseIntElementsAttr get(VectorOrTensorType type, - ArrayRef values); + static DenseIntElementsAttr get(ShapedType type, ArrayRef values); /// Gets the integer value of each of the dense elements. void getValues(SmallVectorImpl &values) const; @@ -609,8 +606,7 @@ public: // Constructs a dense float elements attribute from an array of APFloat // values. Each APFloat value is expected to have the same bitwidth as the // element type of 'type'. - static DenseFPElementsAttr get(VectorOrTensorType type, - ArrayRef values); + static DenseFPElementsAttr get(ShapedType type, ArrayRef values); /// Gets the float value of each of the dense elements. void getValues(SmallVectorImpl &values) const; @@ -637,7 +633,7 @@ public: using Base::Base; using ValueType = StringRef; - static OpaqueElementsAttr get(Dialect *dialect, VectorOrTensorType type, + static OpaqueElementsAttr get(Dialect *dialect, ShapedType type, StringRef bytes); StringRef getValue() const; @@ -684,8 +680,7 @@ class SparseElementsAttr public: using Base::Base; - static SparseElementsAttr get(VectorOrTensorType type, - DenseIntElementsAttr indices, + static SparseElementsAttr get(ShapedType type, DenseIntElementsAttr indices, DenseElementsAttr values); DenseIntElementsAttr getIndices() const; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 7f182e882db0..d852a804723c 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -113,17 +113,16 @@ public: IntegerSetAttr getIntegerSetAttr(IntegerSet set); TypeAttr getTypeAttr(Type type); FunctionAttr getFunctionAttr(Function *value); - ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt); - ElementsAttr getDenseElementsAttr(VectorOrTensorType type, - ArrayRef data); - ElementsAttr getDenseElementsAttr(VectorOrTensorType type, + ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt); + ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef data); + ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef values); - ElementsAttr getDenseIntElementsAttr(VectorOrTensorType type, + ElementsAttr getDenseIntElementsAttr(ShapedType type, ArrayRef values); - ElementsAttr getSparseElementsAttr(VectorOrTensorType type, + ElementsAttr getSparseElementsAttr(ShapedType type, DenseIntElementsAttr indices, DenseElementsAttr values); - ElementsAttr getOpaqueElementsAttr(Dialect *dialect, VectorOrTensorType type, + ElementsAttr getOpaqueElementsAttr(Dialect *dialect, ShapedType type, StringRef bytes); // Returns a 0-valued attribute of the given `type`. This function only // supports boolean, integer, and 32-/64-bit float types, and vector or ranked diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 3e337b298192..9a7cb2dfb1a8 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -99,7 +99,7 @@ struct constant_int_op_binder { if (type.isa()) { return attr_value_binder(bind_value).match(attr); } - if (type.isa()) { + if (type.isa()) { if (auto splatAttr = attr.dyn_cast()) { return attr_value_binder(bind_value) .match(splatAttr.getValue()); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 1463da28f978..ecd61db0e1d1 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -189,15 +189,15 @@ def IsVectorTypePred : CPred<"$_self.isa()">; // Whether a type is a TensorType. def IsTensorTypePred : CPred<"$_self.isa()">; -// Whether a type is a VectorOrTensorType. -def IsVectorOrTensorTypePred : CPred<"$_self.isa()">; +// Whether a type is a MemRefType. +def IsMemRefTypePred : CPred<"$_self.isa()">; + +// Whether a type is a ShapedType. +def IsShapedTypePred : CPred<"$_self.isa()">; // Whether a type is a TupleType. def IsTupleTypePred : CPred<"$_self.isa()">; -// Whether a type is a MemRefType. -def IsMemRefTypePred : CPred<"$_self.isa()">; - // For a TensorType, verify that it is a statically shaped tensor. def IsStaticShapeTensorTypePred : CPred<"$_self.cast().hasStaticShape()">; @@ -345,7 +345,7 @@ class Vector dims> : ContainerType dimensions = dims; } -def VectorOrTensor : Type; +def VectorOrTensor : Type; // Tensor type. @@ -953,50 +953,49 @@ class Results { // Common op type constraints //===----------------------------------------------------------------------===// -// Type Constraint operand `idx`'s Vector or Tensor Element type is `type`. +// Type Constraint operand `idx`'s Element type is `type`. class TCopVTEtIs : AllOf<[ CPred<"$_op.getNumOperands() > " # idx>, SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, SubstLeaves<"$_self", "$_op.getOperand(" # idx # - ")->getType().cast().getElementType()", + ")->getType().cast().getElementType()", type.predicate>]>; // Predicate to verify that the i'th operand and the j'th operand have the same // elemental type. -// Type Constraint operand `i`'s Vector or Tensor Element type is Same As -// operand `j`'s element type. +// Type Constraint operand `i`'s Element type is Same As operand `j`'s Element +// type. class TCopVTEtIsSameAs : AllOf<[ CPred<"$_op.getNumOperands() > std::max(" # i # "," # j # ")">, SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, // TODO: This could be made into C++ function instead. - CPred<"$_op.getOperand(" # i # ")->getType().cast()." + CPred<"$_op.getOperand(" # i # ")->getType().cast()." "getElementType() == $_op.getOperand(" # j # ")->getType()." - "cast().getElementType()">]>; + "cast().getElementType()">]>; // Predicate to verify that the i'th result and the j'th operand have the same // elemental type. -// Type Constraint result`i`'s Vector or Tensor Element type is Same As -// Type Constraint Operand `j`'s Vector or Tensor Element type. +// Type Constraint result`i`'s Element type is Same As Operand `j`'s Element +// type. class TCresVTEtIsSameAsOp : AllOf<[ CPred<"$_op.getNumResults() > " # i>, CPred<"$_op.getNumOperands() > " # j>, SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", - IsVectorOrTensorTypePred>, + IsShapedTypePred>, // TODO: This could be made into C++ function instead. - CPred<"$_op.getResult(" # i # ")->getType().cast()." + CPred<"$_op.getResult(" # i # ")->getType().cast()." "getElementType() == $_op.getOperand(" # j # ")->getType()." - "cast().getElementType()">]>; + "cast().getElementType()">]>; // Predicate to verify that all the operands at the given `indices` // have the same element type. -// Type Constraint operands' Vector or Tensor Element type are all Same At -// the given `indices`. +// Type Constraint operands' Element type are all Same At the given `indices`. // We query the operands' types into a list and check they are all the same. // Precondition: // 1) all operands involved are of vector or tensor type and @@ -1004,7 +1003,7 @@ class TCresVTEtIsSameAsOp : AllOf<[ class TCopVTEtAreSameAt indices> : CPred<"llvm::is_splat(mlir::functional::map(" "[this](unsigned i) { return this->getOperand(i)->getType()" - ".cast().getElementType(); }, " + ".cast().getElementType(); }, " "llvm::ArrayRef({" # Stringify.result # "})))">; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 71e187b71948..788c01af5df5 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -591,7 +591,6 @@ public: /// This class provides verification for ops that are known to have the same /// operand and result element type. /// -/// TODO: This only works for VectorOrTensorType at the moment. template class SameOperandsAndResultElementType : public TraitBase { diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index 8043ac6809d5..8589fbd74f04 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -35,7 +35,7 @@ class MLIRContext; namespace detail { struct IntegerTypeStorage; -struct VectorOrTensorTypeStorage; +struct ShapedTypeStorage; struct VectorTypeStorage; struct RankedTensorTypeStorage; struct UnrankedTensorTypeStorage; @@ -180,11 +180,13 @@ public: const llvm::fltSemantics &getFloatSemantics(); }; +// TODO(b/132735995) Add support for MemRef /// This is a common base class between Vector, UnrankedTensor, and RankedTensor -/// types, because many operations work on values of these aggregate types. -class VectorOrTensorType : public Type { +/// types because they share behavior and semantics around shape, rank, and +/// fixed element type. +class ShapedType : public Type { public: - using ImplType = detail::VectorOrTensorTypeStorage; + using ImplType = detail::ShapedTypeStorage; using Type::Type; /// Return the element type. @@ -234,8 +236,8 @@ public: /// Vector types represent multi-dimensional SIMD vectors, and have a fixed /// known constant shape with one or more dimension. -class VectorType : public Type::TypeBase { +class VectorType + : public Type::TypeBase { public: using Base::Base; @@ -268,9 +270,9 @@ public: /// Tensor types represent multi-dimensional arrays, and have two variants: /// RankedTensorType and UnrankedTensorType. -class TensorType : public VectorOrTensorType { +class TensorType : public ShapedType { public: - using VectorOrTensorType::VectorOrTensorType; + using ShapedType::ShapedType; /// Return true if the specified element type is ok in a tensor. static bool isValidElementType(Type type) { diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 0497b8463f06..ffb2150b0534 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -375,7 +375,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> { let builders = [OpBuilder< "Builder *builder, OperationState *result, Value *aggregate," "ArrayRef indices = {}", [{ - auto resType = aggregate->getType().cast() + auto resType = aggregate->getType().cast() .getElementType(); build(builder, result, resType, aggregate, indices); }]>]; diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h index 325b0ca93ca7..2fdac8ef26a9 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h @@ -168,12 +168,12 @@ struct QuantizedMultiplierSmallerThanOneExp { /// Casts an integer or floating point based type to a new element type. inline Type castElementType(Type t, Type newElementType) { - if (auto vt = t.dyn_cast()) { - switch (vt.getKind()) { + if (auto st = t.dyn_cast()) { + switch (st.getKind()) { case StandardTypes::Kind::Vector: - return VectorType::get(vt.getShape(), newElementType); + return VectorType::get(st.getShape(), newElementType); case StandardTypes::Kind::RankedTensor: - return RankedTensorType::get(vt.getShape(), newElementType); + return RankedTensorType::get(st.getShape(), newElementType); case StandardTypes::Kind::UnrankedTensor: return UnrankedTensorType::get(newElementType); } @@ -185,10 +185,10 @@ inline Type castElementType(Type t, Type newElementType) { /// Creates an IntegerAttr with a type that matches the shape of 't' (which can /// be a primitive/vector/tensor). inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) { - if (auto vt = t.dyn_cast()) { - assert(vt.getElementType().isa()); - return SplatElementsAttr::get(vt, - IntegerAttr::get(vt.getElementType(), value)); + if (auto st = t.dyn_cast()) { + assert(st.getElementType().isa()); + return SplatElementsAttr::get(st, + IntegerAttr::get(st.getElementType(), value)); } auto integerType = t.cast(); @@ -211,13 +211,13 @@ inline APFloat convertFloatToType(FloatType ft, APFloat value) { /// Creates an IntegerAttr with a type that matches the shape of 't' (which can /// be a primitive/vector/tensor). inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) { - if (auto vt = t.dyn_cast()) { - FloatType floatElementType = vt.getElementType().dyn_cast(); + if (auto st = t.dyn_cast()) { + FloatType floatElementType = st.getElementType().dyn_cast(); assert(floatElementType && "float broadcast element type must be float like"); APFloat apValue = convertFloatToType(floatElementType, value); - return SplatElementsAttr::get(vt, - FloatAttr::get(vt.getElementType(), apValue)); + return SplatElementsAttr::get(st, + FloatAttr::get(st.getElementType(), apValue)); } else { auto floatType = t.dyn_cast(); assert(floatType && "float broadcast must be of float type"); diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp index 6ca3b92d0643..1b63b8f4f558 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp @@ -98,8 +98,8 @@ Type QuantizedType::getExpressedType() const { } bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { - if (candidateExpressedType.isa()) { - return candidateExpressedType.cast().getElementType() == + if (candidateExpressedType.isa()) { + return candidateExpressedType.cast().getElementType() == getExpressedType(); } return candidateExpressedType == getExpressedType(); @@ -107,9 +107,9 @@ bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { QuantizedType QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { - if (primitiveOrContainerType.isa()) { + if (primitiveOrContainerType.isa()) { Type elementType = - primitiveOrContainerType.cast().getElementType(); + primitiveOrContainerType.cast().getElementType(); return elementType.dyn_cast(); } return primitiveOrContainerType.dyn_cast(); @@ -139,20 +139,20 @@ Type QuantizedType::castToStorageType(Type quantizedType) { if (quantizedType.isa()) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 return quantizedType.cast().getStorageType(); - } else if (quantizedType.isa()) { + } else if (quantizedType.isa()) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - VectorOrTensorType vtType = quantizedType.cast(); - if (!vtType.getElementType().isa()) { + ShapedType sType = quantizedType.cast(); + if (!sType.getElementType().isa()) { return nullptr; } Type storageType = - vtType.getElementType().cast().getStorageType(); + sType.getElementType().cast().getStorageType(); if (quantizedType.isa()) { - return RankedTensorType::get(vtType.getShape(), storageType); + return RankedTensorType::get(sType.getShape(), storageType); } else if (quantizedType.isa()) { return UnrankedTensorType::get(storageType); } else if (quantizedType.isa()) { - return VectorType::get(vtType.getShape(), storageType); + return VectorType::get(sType.getShape(), storageType); } } @@ -163,22 +163,21 @@ Type QuantizedType::castFromExpressedType(Type candidateType) { if (candidateType == getExpressedType()) { // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> return *this; - } else if (candidateType.isa()) { - VectorOrTensorType candidateVtType = - candidateType.cast(); - if (candidateVtType.getElementType() != getExpressedType()) { + } else if (candidateType.isa()) { + ShapedType candidateShapedType = candidateType.cast(); + if (candidateShapedType.getElementType() != getExpressedType()) { return nullptr; } if (candidateType.isa()) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - return RankedTensorType::get(candidateVtType.getShape(), *this); + return RankedTensorType::get(candidateShapedType.getShape(), *this); } else if (candidateType.isa()) { // i.e. tensor -> tensor> return UnrankedTensorType::get(*this); } else if (candidateType.isa()) { // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - return VectorType::get(candidateVtType.getShape(), *this); + return VectorType::get(candidateShapedType.getShape(), *this); } } @@ -189,20 +188,20 @@ Type QuantizedType::castToExpressedType(Type quantizedType) { if (quantizedType.isa()) { // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 return quantizedType.cast().getExpressedType(); - } else if (quantizedType.isa()) { + } else if (quantizedType.isa()) { // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> - VectorOrTensorType vtType = quantizedType.cast(); - if (!vtType.getElementType().isa()) { + ShapedType sType = quantizedType.cast(); + if (!sType.getElementType().isa()) { return nullptr; } Type expressedType = - vtType.getElementType().cast().getExpressedType(); + sType.getElementType().cast().getExpressedType(); if (quantizedType.isa()) { - return RankedTensorType::get(vtType.getShape(), expressedType); + return RankedTensorType::get(sType.getShape(), expressedType); } else if (quantizedType.isa()) { return UnrankedTensorType::get(expressedType); } else if (quantizedType.isa()) { - return VectorType::get(vtType.getShape(), expressedType); + return VectorType::get(sType.getShape(), expressedType); } } diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp index 3685a65f2d82..c50b3075b690 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp @@ -56,10 +56,10 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, // Cast from an expressed-type-based type to storage-type-based type, // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>). - VectorOrTensorType newDenseType = + ShapedType newDenseType = quantizedElementType .castExpressedToStorageType(realFPElementsAttr.getType()) - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (!newDenseType) { return nullptr; } @@ -87,9 +87,9 @@ convertSplatElementsAttr(SplatElementsAttr realSplatAttr, // Cast from an expressed-type-based type to storage-type-based type, // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>). - VectorOrTensorType newSplatType = + ShapedType newSplatType = quantizedElementType.castExpressedToStorageType(realSplatAttr.getType()) - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (!newSplatType) { return nullptr; } @@ -116,9 +116,9 @@ convertSparseElementsAttr(SparseElementsAttr realSparseAttr, // Cast from an expressed-type-based type to storage-type-based type, // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). - VectorOrTensorType newSparseType = + ShapedType newSparseType = quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (!newSparseType) { return nullptr; } diff --git a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp index d791075f5dba..db8a58489815 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp @@ -38,13 +38,13 @@ ExpressedToUniformQuantizedConverter::forInputType(Type inputType) { case StandardTypes::RankedTensor: case StandardTypes::UnrankedTensor: case StandardTypes::Vector: { - Type elementType = inputType.cast().getElementType(); + Type elementType = inputType.cast().getElementType(); if (!isQuantizablePrimitiveType(elementType)) { // Unsupported. return ExpressedToUniformQuantizedConverter{inputType, nullptr}; } return ExpressedToUniformQuantizedConverter{ - inputType, inputType.cast().getElementType()}; + inputType, inputType.cast().getElementType()}; } } } diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp index 6762a60b0fc5..183ba73ba7aa 100644 --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -80,8 +80,8 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef shape1, /// Returns the shape of the given type. Scalars will be considered as having a /// shape with zero dimensions. static ArrayRef getShape(Type type) { - if (auto vtType = type.dyn_cast()) - return vtType.getShape(); + if (auto sType = type.dyn_cast()) + return sType.getShape(); return {}; } @@ -92,8 +92,8 @@ static ArrayRef getShape(Type type) { Type OpTrait::util::getBroadcastedType(Type type1, Type type2) { // Returns the scalar type out of the given type. auto getScalarType = [](Type type) -> Type { - if (auto vtType = type.dyn_cast()) - return vtType.getElementType(); + if (auto shapedType = type.dyn_cast()) + return shapedType.getElementType(); return type; }; diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp index a4691e89cda1..fa9f7397601e 100644 --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -315,7 +315,7 @@ static ValueHandle createBinaryHandle( return createBinaryHandle(lhs, rhs); } else if (thisType.isa()) { return createBinaryHandle(lhs, rhs); - } else if (auto aggregateType = thisType.dyn_cast()) { + } else if (auto aggregateType = thisType.dyn_cast()) { if (aggregateType.getElementType().isa()) return createBinaryHandle(lhs, rhs); else if (aggregateType.getElementType().isa()) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 23f1317908a5..f99c09fe5ce5 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -201,7 +201,7 @@ void ModuleState::visitType(Type type) { // Visit affine maps in memref type. for (auto map : memref.getAffineMaps()) recordAttributeReference(AffineMapAttr::get(map)); - } else if (auto vecOrTensor = type.dyn_cast()) { + } else if (auto vecOrTensor = type.dyn_cast()) { visitType(vecOrTensor.getElementType()); } } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 881060c90820..b958221a9578 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -313,8 +313,8 @@ FunctionType FunctionAttr::getType() const { // ElementsAttr //===----------------------------------------------------------------------===// -VectorOrTensorType ElementsAttr::getType() const { - return Attribute::getType().cast(); +ShapedType ElementsAttr::getType() const { + return Attribute::getType().cast(); } /// Return the value at the given index. If index does not refer to a valid @@ -339,8 +339,7 @@ Attribute ElementsAttr::getValue(ArrayRef index) const { // SplatElementsAttr //===----------------------------------------------------------------------===// -SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type, - Attribute elt) { +SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) { assert(elt.getType() == type.getElementType() && "value should be of the given element type"); return Base::get(type.getContext(), StandardAttributes::SplatElements, type, @@ -374,8 +373,7 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const { // DenseElementsAttr //===----------------------------------------------------------------------===// -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, - ArrayRef data) { +DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef data) { assert((static_cast(type.getSizeInBits()) <= data.size() * APInt::APINT_WORD_SIZE) && "Input data bit size should be larger than that type requires"); @@ -394,7 +392,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, } } -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, +DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(type.getElementType().isIntOrFloat() && "expected int or float element type"); @@ -516,7 +514,7 @@ ArrayRef DenseElementsAttr::getRawData() const { // Constructs a dense elements attribute from an array of raw APInt values. // Each APInt value is expected to have the same bitwidth as the element type // of 'type'. -DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, +DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(values.size() == type.getNumElements() && "expected 'values' to contain the same number of elements as 'type'"); @@ -586,7 +584,7 @@ APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos, /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. -DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type, +DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type, ArrayRef values) { return DenseElementsAttr::get(type, values).cast(); } @@ -594,7 +592,7 @@ DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type, /// Constructs a dense integer elements attribute from an array of integer /// values. Each value is expected to be within the bitwidth of the element /// type of 'type'. -DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type, +DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type, ArrayRef values) { auto eltType = type.getElementType(); size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); @@ -625,7 +623,7 @@ DenseFPElementsAttr::ElementIterator::ElementIterator( // Constructs a dense float elements attribute from an array of APFloat // values. Each APFloat value is expected to have the same bitwidth as the // element type of 'type'. -DenseFPElementsAttr DenseFPElementsAttr::get(VectorOrTensorType type, +DenseFPElementsAttr DenseFPElementsAttr::get(ShapedType type, ArrayRef values) { // Convert the APFloat values to APInt and create a dense elements attribute. std::vector intValues(values.size()); @@ -655,8 +653,7 @@ DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const { // OpaqueElementsAttr //===----------------------------------------------------------------------===// -OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, - VectorOrTensorType type, +OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, StringRef bytes) { assert(TensorType::isValidElementType(type.getElementType()) && "Input element type should be a valid tensor element type"); @@ -686,7 +683,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) { // SparseElementsAttr //===----------------------------------------------------------------------===// -SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type, +SparseElementsAttr SparseElementsAttr::get(ShapedType type, DenseIntElementsAttr indices, DenseElementsAttr values) { assert(indices.getType().getElementType().isInteger(64) && diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index a6036a9b1401..574102c82324 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -173,34 +173,32 @@ FunctionAttr Builder::getFunctionAttr(Function *value) { return FunctionAttr::get(value); } -ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type, - Attribute elt) { +ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) { return SplatElementsAttr::get(type, elt); } -ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, +ElementsAttr Builder::getDenseElementsAttr(ShapedType type, ArrayRef data) { return DenseElementsAttr::get(type, data); } -ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type, +ElementsAttr Builder::getDenseElementsAttr(ShapedType type, ArrayRef values) { return DenseElementsAttr::get(type, values); } -ElementsAttr Builder::getDenseIntElementsAttr(VectorOrTensorType type, +ElementsAttr Builder::getDenseIntElementsAttr(ShapedType type, ArrayRef values) { return DenseIntElementsAttr::get(type, values); } -ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type, +ElementsAttr Builder::getSparseElementsAttr(ShapedType type, DenseIntElementsAttr indices, DenseElementsAttr values) { return SparseElementsAttr::get(type, indices, values); } -ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, - VectorOrTensorType type, +ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, ShapedType type, StringRef bytes) { return OpaqueElementsAttr::get(dialect, type, bytes); } @@ -249,7 +247,7 @@ Attribute Builder::getZeroAttr(Type type) { } case StandardTypes::Vector: case StandardTypes::RankedTensor: { - auto vtType = type.cast(); + auto vtType = type.cast(); auto element = getZeroAttr(vtType.getElementType()); if (!element) return {}; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 992c66ea1d1c..2a67f5aa3265 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -775,8 +775,8 @@ LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, } /// Returns success if the given two types have the same shape. That is, -/// they are both scalars, or they are both vectors / ranked tensors with -/// the same dimension specifications. The element type does not matter. +/// they are both scalars, or they are both static shaped types with the same +/// dimension specifications. The element type does not matter. static LogicalResult verifyShapeMatch(Type type1, Type type2) { // Check scalar cases if (type1.isIntOrIndexOrFloat()) @@ -787,9 +787,9 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) { return failure(); // Check normal vector/tensor cases - if (auto vtType1 = type1.dyn_cast()) { - auto vtType2 = type2.dyn_cast(); - return success(vtType2 && vtType1.getShape() == vtType2.getShape()); + if (auto sType1 = type1.dyn_cast()) { + auto sType2 = type2.dyn_cast(); + return success(sType2 && sType1.getShape() == sType2.getShape()); } return success(); @@ -818,14 +818,14 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { if (op->getNumOperands() == 0 || op->getNumResults() == 0) return failure(); - auto type = op->getResult(0)->getType().dyn_cast(); + auto type = op->getResult(0)->getType().dyn_cast(); if (!type) return op->emitOpError("requires vector or tensor type results"); auto elementType = type.getElementType(); // Verify result element type matches first result's element type. for (auto result : drop_begin(op->getResults(), 1)) { - auto resultType = result->getType().dyn_cast(); + auto resultType = result->getType().dyn_cast(); if (!resultType) return op->emitOpError("requires vector or tensor type results"); if (resultType.getElementType() != elementType) @@ -835,7 +835,7 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { // Verify operand's element type matches first result's element type. for (auto operand : op->getOperands()) { - auto operandType = operand->getType().dyn_cast(); + auto operandType = operand->getType().dyn_cast(); if (!operandType) return op->emitOpError("requires vector or tensor type operands"); if (operandType.getElementType() != elementType) diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index ba7d1d3d6543..b279d19a71ec 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -104,18 +104,18 @@ unsigned Type::getIntOrFloatBitWidth() { } //===----------------------------------------------------------------------===// -// VectorOrTensorType +// ShapedType //===----------------------------------------------------------------------===// -Type VectorOrTensorType::getElementType() const { +Type ShapedType::getElementType() const { return static_cast(impl)->elementType; } -unsigned VectorOrTensorType::getElementTypeBitWidth() const { +unsigned ShapedType::getElementTypeBitWidth() const { return getElementType().getIntOrFloatBitWidth(); } -unsigned VectorOrTensorType::getNumElements() const { +unsigned ShapedType::getNumElements() const { switch (getKind()) { case StandardTypes::Vector: case StandardTypes::RankedTensor: { @@ -127,13 +127,13 @@ unsigned VectorOrTensorType::getNumElements() const { return num; } default: - llvm_unreachable("not a VectorOrTensorType or not ranked"); + llvm_unreachable("not a ShapedType or not ranked"); } } /// If this is ranked tensor or vector type, return the rank. If it is an /// unranked tensor, return -1. -int64_t VectorOrTensorType::getRank() const { +int64_t ShapedType::getRank() const { switch (getKind()) { case StandardTypes::Vector: case StandardTypes::RankedTensor: @@ -141,24 +141,24 @@ int64_t VectorOrTensorType::getRank() const { case StandardTypes::UnrankedTensor: return -1; default: - llvm_unreachable("not a VectorOrTensorType"); + llvm_unreachable("not a ShapedType"); } } -int64_t VectorOrTensorType::getDimSize(unsigned i) const { +int64_t ShapedType::getDimSize(unsigned i) const { switch (getKind()) { case StandardTypes::Vector: case StandardTypes::RankedTensor: return getShape()[i]; default: - llvm_unreachable("not a VectorOrTensorType or not ranked"); + llvm_unreachable("not a ShapedType or not ranked"); } } // Get the number of number of bits require to store a value of the given vector // or tensor types. Compute the value recursively since tensors are allowed to // have vectors as elements. -int64_t VectorOrTensorType::getSizeInBits() const { +int64_t ShapedType::getSizeInBits() const { assert(hasStaticShape() && "cannot get the bit size of an aggregate with a dynamic shape"); @@ -168,23 +168,23 @@ int64_t VectorOrTensorType::getSizeInBits() const { // Tensors can have vectors and other tensors as elements, vectors cannot. assert(!isa() && "unsupported vector element type"); - auto elementVectorOrTensorType = elementType.dyn_cast(); - assert(elementVectorOrTensorType && "unsupported tensor element type"); - return getNumElements() * elementVectorOrTensorType.getSizeInBits(); + auto elementShapedType = elementType.dyn_cast(); + assert(elementShapedType && "unsupported tensor element type"); + return getNumElements() * elementShapedType.getSizeInBits(); } -ArrayRef VectorOrTensorType::getShape() const { +ArrayRef ShapedType::getShape() const { switch (getKind()) { case StandardTypes::Vector: return cast().getShape(); case StandardTypes::RankedTensor: return cast().getShape(); default: - llvm_unreachable("not a VectorOrTensorType or not ranked"); + llvm_unreachable("not a ShapedType or not ranked"); } } -bool VectorOrTensorType::hasStaticShape() const { +bool ShapedType::hasStaticShape() const { if (isa()) return false; return llvm::none_of(getShape(), [](int64_t i) { return i < 0; }); diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h index 870dc7b3236a..541aacd9a722 100644 --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -118,8 +118,8 @@ struct FunctionTypeStorage : public TypeStorage { }; /// VectorOrTensor Type Storage. -struct VectorOrTensorTypeStorage : public TypeStorage { - VectorOrTensorTypeStorage(Type elementType, unsigned subclassData = 0) +struct ShapedTypeStorage : public TypeStorage { + ShapedTypeStorage(Type elementType, unsigned subclassData = 0) : TypeStorage(subclassData), elementType(elementType) {} /// The hash key used for uniquing. @@ -130,11 +130,10 @@ struct VectorOrTensorTypeStorage : public TypeStorage { }; /// Vector Type Storage and Uniquing. -struct VectorTypeStorage : public VectorOrTensorTypeStorage { +struct VectorTypeStorage : public ShapedTypeStorage { VectorTypeStorage(unsigned shapeSize, Type elementTy, const int64_t *shapeElements) - : VectorOrTensorTypeStorage(elementTy, shapeSize), - shapeElements(shapeElements) {} + : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {} /// The hash key used for uniquing. using KeyTy = std::pair, Type>; @@ -160,11 +159,10 @@ struct VectorTypeStorage : public VectorOrTensorTypeStorage { const int64_t *shapeElements; }; -struct RankedTensorTypeStorage : public VectorOrTensorTypeStorage { +struct RankedTensorTypeStorage : public ShapedTypeStorage { RankedTensorTypeStorage(unsigned shapeSize, Type elementTy, const int64_t *shapeElements) - : VectorOrTensorTypeStorage(elementTy, shapeSize), - shapeElements(shapeElements) {} + : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {} /// The hash key used for uniquing. using KeyTy = std::pair, Type>; @@ -190,9 +188,9 @@ struct RankedTensorTypeStorage : public VectorOrTensorTypeStorage { const int64_t *shapeElements; }; -struct UnrankedTensorTypeStorage : public VectorOrTensorTypeStorage { - using VectorOrTensorTypeStorage::KeyTy; - using VectorOrTensorTypeStorage::VectorOrTensorTypeStorage; +struct UnrankedTensorTypeStorage : public ShapedTypeStorage { + using ShapedTypeStorage::KeyTy; + using ShapedTypeStorage::ShapedTypeStorage; /// Construction. static UnrankedTensorTypeStorage *construct(TypeStorageAllocator &allocator, diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 03b3dd1dbde5..01feec367d81 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -200,9 +200,9 @@ public: // Polyhedral structures. ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, IntegerSet &set); - DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type); + DenseElementsAttr parseDenseElementsAttr(ShapedType type); DenseElementsAttr parseDenseElementsAttrAsTensor(Type eltType); - VectorOrTensorType parseVectorOrTensorType(); + ShapedType parseShapedType(); // Location Parsing. @@ -1255,7 +1255,7 @@ Attribute Parser::parseAttribute(Type type) { if (parseToken(Token::comma, "expected ','")) return nullptr; - auto type = parseVectorOrTensorType(); + auto type = parseShapedType(); if (!type) return nullptr; @@ -1279,7 +1279,7 @@ Attribute Parser::parseAttribute(Type type) { if (parseToken(Token::less, "expected '<' after 'splat'")) return nullptr; - auto type = parseVectorOrTensorType(); + auto type = parseShapedType(); if (!type) return nullptr; switch (getToken().getKind()) { @@ -1305,7 +1305,7 @@ Attribute Parser::parseAttribute(Type type) { if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; - auto type = parseVectorOrTensorType(); + auto type = parseShapedType(); if (!type) return nullptr; @@ -1323,7 +1323,7 @@ Attribute Parser::parseAttribute(Type type) { if (parseToken(Token::less, "Expected '<' after 'sparse'")) return nullptr; - auto type = parseVectorOrTensorType(); + auto type = parseShapedType(); if (!type) return nullptr; @@ -1414,7 +1414,7 @@ DenseElementsAttr Parser::parseDenseElementsAttrAsTensor(Type eltType) { /// This method compares the shapes from the parsing result and that from the /// input argument. It returns a constructed dense elements attribute if both /// match. -DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) { +DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) { auto eltTy = type.getElementType(); TensorLiteralParser literalParser(*this, eltTy); if (literalParser.parse()) @@ -1431,27 +1431,26 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) { .cast(); } -/// Vector or tensor type for elements attribute. +/// Shaped type for elements attribute. /// -/// vector-or-tensor-type ::= vector-type | tensor-type +/// shaped-type ::= vector-type | tensor-type /// /// This method also checks the type has static shape and ranked. -VectorOrTensorType Parser::parseVectorOrTensorType() { +ShapedType Parser::parseShapedType() { auto elementType = parseType(); if (!elementType) return nullptr; - auto type = elementType.dyn_cast(); + auto type = elementType.dyn_cast(); if (!type) { - return (emitError("expected elements literal has a tensor or vector type"), - nullptr); + return (emitError("elements literal must be a shaped type"), nullptr); } if (parseToken(Token::comma, "expected ','")) return nullptr; if (!type.hasStaticShape() || type.getRank() == -1) { - return (emitError("tensor literals must be ranked and have static shape"), + return (emitError("shaped literal must be ranked and have static shape"), nullptr); } return type; diff --git a/mlir/lib/Quantizer/Support/Statistics.cpp b/mlir/lib/Quantizer/Support/Statistics.cpp index ce3913ae34f9..3ec07e40f2b7 100644 --- a/mlir/lib/Quantizer/Support/Statistics.cpp +++ b/mlir/lib/Quantizer/Support/Statistics.cpp @@ -66,18 +66,18 @@ static bool getElementsStatistics(ElementsAttr attr, statistics.minValue = std::numeric_limits::infinity(); statistics.maxValue = -std::numeric_limits::infinity(); - VectorOrTensorType vtType = attr.getType(); - if (!vtType.hasStaticShape()) + ShapedType sType = attr.getType(); + if (!sType.hasStaticShape()) return false; - Type elementTy = vtType.getElementType(); + Type elementTy = sType.getElementType(); if (!elementTy.isa()) return false; llvm::SmallVector indices; - indices.resize(vtType.getRank()); - ArrayRef shape = vtType.getShape(); + indices.resize(sType.getRank()); + ArrayRef shape = sType.getShape(); - auto numElements = vtType.getNumElements(); + auto numElements = sType.getNumElements(); collectElementsStatisticsDim(attr, numElements, shape, indices, 0, statistics); statistics.sampleSize = numElements; diff --git a/mlir/lib/Quantizer/Support/TypeUtils.cpp b/mlir/lib/Quantizer/Support/TypeUtils.cpp index 444322e31e37..fab4e565308e 100644 --- a/mlir/lib/Quantizer/Support/TypeUtils.cpp +++ b/mlir/lib/Quantizer/Support/TypeUtils.cpp @@ -23,8 +23,8 @@ using namespace mlir; using namespace mlir::quantizer; Type mlir::quantizer::getElementOrPrimitiveType(Type t) { - if (auto vtType = t.dyn_cast()) { - return vtType.getElementType(); + if (auto sType = t.dyn_cast()) { + return sType.getElementType(); } else { return t; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 682af6ea99ad..490f6e4e4c80 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1163,9 +1163,9 @@ static LogicalResult verify(ConstantOp &op) { return success(); } - if (type.isa()) { + if (type.isa()) { if (!value.isa()) - return op.emitOpError("requires 'value' to be a vector/tensor constant"); + return op.emitOpError("requires 'value' to be a shaped constant"); return success(); } @@ -1639,7 +1639,7 @@ static ParseResult parseExtractElementOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType aggregateInfo; SmallVector indexInfo; - VectorOrTensorType type; + ShapedType type; auto affineIntTy = parser->getBuilder().getIndexType(); return failure( @@ -1656,8 +1656,7 @@ static LogicalResult verify(ExtractElementOp op) { if (op.getNumOperands() == 0) return op.emitOpError("expected an aggregate to index into"); - auto aggregateType = - op.getAggregate()->getType().dyn_cast(); + auto aggregateType = op.getAggregate()->getType().dyn_cast(); if (!aggregateType) return op.emitOpError("first operand must be a vector or tensor"); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index b8f24fe030f8..0fe5029a99de 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -599,14 +599,14 @@ func@n(){^b( func @elementsattr_non_tensor_type() -> () { ^bb0: - "foo"(){bar: dense} : () -> () // expected-error {{expected elements literal has a tensor or vector type}} + "foo"(){bar: dense} : () -> () // expected-error {{elements literal must be a shaped type}} } // ----- func @elementsattr_non_ranked() -> () { ^bb0: - "foo"(){bar: dense, [4]>} : () -> () // expected-error {{tensor literals must be ranked and have static shape}} + "foo"(){bar: dense, [4]>} : () -> () // expected-error {{shaped literal must be ranked and have static shape}} } // ----- diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index 1ebe3b1cf9bc..4e0da207efa1 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -33,7 +33,7 @@ def OpC : NS_Op<"op_for_TCopVTEtIs", [ } // CHECK-LABEL: OpC::verify -// CHECK: if (!((((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType().isInteger(32))))) +// CHECK: if (!((((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType().isInteger(32))))) def OpD : NS_Op<"op_for_TCOpVTEtIsSameAs", [ @@ -44,7 +44,7 @@ def OpD : NS_Op<"op_for_TCOpVTEtIsSameAs", [ } // CHECK-LABEL: OpD::verify -// CHECK: if (!((((*this->getOperation()).getNumOperands() > std::max(0,0))) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast().getElementType())))) +// CHECK: if (!((((*this->getOperation()).getNumOperands() > std::max(0,0))) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().cast().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast().getElementType())))) // CHECK-NEXT: return emitOpError("failed to verify that first operand is a vector or tensor with the same elemental type as itself"); @@ -57,7 +57,7 @@ def OpE : NS_Op<"op_for_TCresVTEtIsSameAsOp", [ } // CHECK-LABEL: OpE::verify -// CHECK: if (!((((*this->getOperation()).getNumResults() > 0)) && (((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getResult(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getResult(0)->getType().cast().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast().getElementType())))) +// CHECK: if (!((((*this->getOperation()).getNumResults() > 0)) && (((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getResult(0)->getType().isa())) && (((*this->getOperation()).getOperand(0)->getType().isa())) && (((*this->getOperation()).getResult(0)->getType().cast().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast().getElementType())))) // CHECK-NEXT: return emitOpError("failed to verify that first operand is a vector or tensor with the same elemental type as first result"); @@ -107,7 +107,7 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [ // CHECK-LABEL: OpJ::verify() // CHECK: llvm::is_splat(mlir::functional::map( -// CHECK-SAME: [this](unsigned i) { return this->getOperand(i)->getType().cast().getElementType(); }, +// CHECK-SAME: [this](unsigned i) { return this->getOperand(i)->getType().cast().getElementType(); }, // CHECK-SAME: llvm::ArrayRef({0, 2, 3}))) // CHECK: return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type"); diff --git a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp index fd2efb205138..3c8642ef429e 100644 --- a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp @@ -49,7 +49,7 @@ template ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef shape, Arg... value) { auto eleType = FloatType::getF32(ctx); - VectorOrTensorType tensorType; + ShapedType tensorType; if (shape.size() == 1 && shape[0] == -1) { tensorType = UnrankedTensorType::get(eleType); } else { @@ -61,7 +61,7 @@ ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef shape, ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx, ArrayRef shape) { auto eleType = FloatType::getF32(ctx); - VectorOrTensorType tensorType; + ShapedType tensorType; if (shape.size() == 1 && shape[0] == -1) { tensorType = UnrankedTensorType::get(eleType); } else {