forked from OSchip/llvm-project
Rename VectorOrTensorType to ShapedType
This is in preparation for making it also support/be a parent class of MemRefType. MemRefs have similar shape/rank/element semantics and it would be useful to be able to use these same utilities for them. This CL should not change any semantics and only change variables, types, string literals, and comments. In follow-up CLs I will prepare all callers to handle MemRef types or remove their dependence on ShapedType. Discussion/Rationale in https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/cHLoyfGu8y8 -- PiperOrigin-RevId: 248476449
This commit is contained in:
parent
b3888fa9cc
commit
090662c5f3
|
@ -32,7 +32,7 @@ class IntegerSet;
|
|||
class Location;
|
||||
class MLIRContext;
|
||||
class Type;
|
||||
class VectorOrTensorType;
|
||||
class ShapedType;
|
||||
|
||||
namespace detail {
|
||||
|
||||
|
@ -417,7 +417,7 @@ class ElementsAttr : public Attribute {
|
|||
public:
|
||||
using Attribute::Attribute;
|
||||
|
||||
VectorOrTensorType getType() const;
|
||||
ShapedType getType() const;
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
/// element, then a null attribute is returned.
|
||||
|
@ -439,7 +439,7 @@ public:
|
|||
using Base::Base;
|
||||
using ValueType = Attribute;
|
||||
|
||||
static SplatElementsAttr get(VectorOrTensorType type, Attribute elt);
|
||||
static SplatElementsAttr get(ShapedType type, Attribute elt);
|
||||
Attribute getValue() const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
|
@ -457,12 +457,11 @@ public:
|
|||
|
||||
/// It assumes the elements in the input array have been truncated to the bits
|
||||
/// width specified by the element type.
|
||||
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<char> data);
|
||||
|
||||
// Constructs a dense elements attribute from an array of element values. Each
|
||||
// element attribute value is expected to be an element of 'type'.
|
||||
static DenseElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<Attribute> values);
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
|
||||
|
||||
/// Returns the number of elements held by this attribute.
|
||||
size_t size() const;
|
||||
|
@ -542,7 +541,7 @@ protected:
|
|||
// Constructs a dense elements attribute from an array of raw APInt values.
|
||||
// Each APInt value is expected to have the same bitwidth as the element type
|
||||
// of 'type'.
|
||||
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<APInt> values);
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
|
||||
};
|
||||
|
||||
/// An attribute that represents a reference to a dense integer vector or tensor
|
||||
|
@ -562,14 +561,12 @@ public:
|
|||
/// Constructs a dense integer elements attribute from an array of APInt
|
||||
/// values. Each APInt value is expected to have the same bitwidth as the
|
||||
/// element type of 'type'.
|
||||
static DenseIntElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<APInt> values);
|
||||
static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
|
||||
|
||||
/// Constructs a dense integer elements attribute from an array of integer
|
||||
/// values. Each value is expected to be within the bitwidth of the element
|
||||
/// type of 'type'.
|
||||
static DenseIntElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<int64_t> values);
|
||||
static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
|
||||
|
||||
/// Gets the integer value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APInt> &values) const;
|
||||
|
@ -609,8 +606,7 @@ public:
|
|||
// Constructs a dense float elements attribute from an array of APFloat
|
||||
// values. Each APFloat value is expected to have the same bitwidth as the
|
||||
// element type of 'type'.
|
||||
static DenseFPElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<APFloat> values);
|
||||
static DenseFPElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
|
||||
|
||||
/// Gets the float value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APFloat> &values) const;
|
||||
|
@ -637,7 +633,7 @@ public:
|
|||
using Base::Base;
|
||||
using ValueType = StringRef;
|
||||
|
||||
static OpaqueElementsAttr get(Dialect *dialect, VectorOrTensorType type,
|
||||
static OpaqueElementsAttr get(Dialect *dialect, ShapedType type,
|
||||
StringRef bytes);
|
||||
|
||||
StringRef getValue() const;
|
||||
|
@ -684,8 +680,7 @@ class SparseElementsAttr
|
|||
public:
|
||||
using Base::Base;
|
||||
|
||||
static SparseElementsAttr get(VectorOrTensorType type,
|
||||
DenseIntElementsAttr indices,
|
||||
static SparseElementsAttr get(ShapedType type, DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values);
|
||||
|
||||
DenseIntElementsAttr getIndices() const;
|
||||
|
|
|
@ -113,17 +113,16 @@ public:
|
|||
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
|
||||
TypeAttr getTypeAttr(Type type);
|
||||
FunctionAttr getFunctionAttr(Function *value);
|
||||
ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt);
|
||||
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
|
||||
ArrayRef<char> data);
|
||||
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
|
||||
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
|
||||
ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<char> data);
|
||||
ElementsAttr getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<Attribute> values);
|
||||
ElementsAttr getDenseIntElementsAttr(VectorOrTensorType type,
|
||||
ElementsAttr getDenseIntElementsAttr(ShapedType type,
|
||||
ArrayRef<int64_t> values);
|
||||
ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
|
||||
ElementsAttr getSparseElementsAttr(ShapedType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values);
|
||||
ElementsAttr getOpaqueElementsAttr(Dialect *dialect, VectorOrTensorType type,
|
||||
ElementsAttr getOpaqueElementsAttr(Dialect *dialect, ShapedType type,
|
||||
StringRef bytes);
|
||||
// Returns a 0-valued attribute of the given `type`. This function only
|
||||
// supports boolean, integer, and 32-/64-bit float types, and vector or ranked
|
||||
|
|
|
@ -99,7 +99,7 @@ struct constant_int_op_binder {
|
|||
if (type.isa<IntegerType>()) {
|
||||
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
|
||||
}
|
||||
if (type.isa<VectorOrTensorType>()) {
|
||||
if (type.isa<ShapedType>()) {
|
||||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
return attr_value_binder<IntegerAttr>(bind_value)
|
||||
.match(splatAttr.getValue());
|
||||
|
|
|
@ -189,15 +189,15 @@ def IsVectorTypePred : CPred<"$_self.isa<VectorType>()">;
|
|||
// Whether a type is a TensorType.
|
||||
def IsTensorTypePred : CPred<"$_self.isa<TensorType>()">;
|
||||
|
||||
// Whether a type is a VectorOrTensorType.
|
||||
def IsVectorOrTensorTypePred : CPred<"$_self.isa<VectorOrTensorType>()">;
|
||||
// Whether a type is a MemRefType.
|
||||
def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
|
||||
|
||||
// Whether a type is a ShapedType.
|
||||
def IsShapedTypePred : CPred<"$_self.isa<ShapedType>()">;
|
||||
|
||||
// Whether a type is a TupleType.
|
||||
def IsTupleTypePred : CPred<"$_self.isa<TupleType>()">;
|
||||
|
||||
// Whether a type is a MemRefType.
|
||||
def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
|
||||
|
||||
// For a TensorType, verify that it is a statically shaped tensor.
|
||||
def IsStaticShapeTensorTypePred :
|
||||
CPred<"$_self.cast<TensorType>().hasStaticShape()">;
|
||||
|
@ -345,7 +345,7 @@ class Vector<Type t, list<int> dims> : ContainerType<t, AllOf<[
|
|||
list<int> dimensions = dims;
|
||||
}
|
||||
|
||||
def VectorOrTensor : Type<IsVectorOrTensorTypePred, "vector or tensor">;
|
||||
def VectorOrTensor : Type<IsShapedTypePred, "vector or tensor">;
|
||||
|
||||
// Tensor type.
|
||||
|
||||
|
@ -953,50 +953,49 @@ class Results<dag rets> {
|
|||
// Common op type constraints
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Type Constraint operand `idx`'s Vector or Tensor Element type is `type`.
|
||||
// Type Constraint operand `idx`'s Element type is `type`.
|
||||
class TCopVTEtIs<int idx, Type type> : AllOf<[
|
||||
CPred<"$_op.getNumOperands() > " # idx>,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
IsShapedTypePred>,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # idx #
|
||||
")->getType().cast<VectorOrTensorType>().getElementType()",
|
||||
")->getType().cast<ShapedType>().getElementType()",
|
||||
type.predicate>]>;
|
||||
|
||||
// Predicate to verify that the i'th operand and the j'th operand have the same
|
||||
// elemental type.
|
||||
// Type Constraint operand `i`'s Vector or Tensor Element type is Same As
|
||||
// operand `j`'s element type.
|
||||
// Type Constraint operand `i`'s Element type is Same As operand `j`'s Element
|
||||
// type.
|
||||
class TCopVTEtIsSameAs<int i, int j> : AllOf<[
|
||||
CPred<"$_op.getNumOperands() > std::max(" # i # "," # j # ")">,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
IsShapedTypePred>,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
IsShapedTypePred>,
|
||||
// TODO: This could be made into C++ function instead.
|
||||
CPred<"$_op.getOperand(" # i # ")->getType().cast<VectorOrTensorType>()."
|
||||
CPred<"$_op.getOperand(" # i # ")->getType().cast<ShapedType>()."
|
||||
"getElementType() == $_op.getOperand(" # j # ")->getType()."
|
||||
"cast<VectorOrTensorType>().getElementType()">]>;
|
||||
"cast<ShapedType>().getElementType()">]>;
|
||||
|
||||
// Predicate to verify that the i'th result and the j'th operand have the same
|
||||
// elemental type.
|
||||
// Type Constraint result`i`'s Vector or Tensor Element type is Same As
|
||||
// Type Constraint Operand `j`'s Vector or Tensor Element type.
|
||||
// Type Constraint result`i`'s Element type is Same As Operand `j`'s Element
|
||||
// type.
|
||||
class TCresVTEtIsSameAsOp<int i, int j> : AllOf<[
|
||||
CPred<"$_op.getNumResults() > " # i>,
|
||||
CPred<"$_op.getNumOperands() > " # j>,
|
||||
SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
IsShapedTypePred>,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
IsShapedTypePred>,
|
||||
// TODO: This could be made into C++ function instead.
|
||||
CPred<"$_op.getResult(" # i # ")->getType().cast<VectorOrTensorType>()."
|
||||
CPred<"$_op.getResult(" # i # ")->getType().cast<ShapedType>()."
|
||||
"getElementType() == $_op.getOperand(" # j # ")->getType()."
|
||||
"cast<VectorOrTensorType>().getElementType()">]>;
|
||||
"cast<ShapedType>().getElementType()">]>;
|
||||
|
||||
// Predicate to verify that all the operands at the given `indices`
|
||||
// have the same element type.
|
||||
// Type Constraint operands' Vector or Tensor Element type are all Same At
|
||||
// the given `indices`.
|
||||
// Type Constraint operands' Element type are all Same At the given `indices`.
|
||||
// We query the operands' types into a list and check they are all the same.
|
||||
// Precondition:
|
||||
// 1) all operands involved are of vector or tensor type and
|
||||
|
@ -1004,7 +1003,7 @@ class TCresVTEtIsSameAsOp<int i, int j> : AllOf<[
|
|||
class TCopVTEtAreSameAt<list<int> indices> :
|
||||
CPred<"llvm::is_splat(mlir::functional::map("
|
||||
"[this](unsigned i) { return this->getOperand(i)->getType()"
|
||||
".cast<VectorOrTensorType>().getElementType(); }, "
|
||||
".cast<ShapedType>().getElementType(); }, "
|
||||
"llvm::ArrayRef<unsigned>({" # Stringify<indices>.result # "})))">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -591,7 +591,6 @@ public:
|
|||
/// This class provides verification for ops that are known to have the same
|
||||
/// operand and result element type.
|
||||
///
|
||||
/// TODO: This only works for VectorOrTensorType at the moment.
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultElementType
|
||||
: public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
|
||||
|
|
|
@ -35,7 +35,7 @@ class MLIRContext;
|
|||
namespace detail {
|
||||
|
||||
struct IntegerTypeStorage;
|
||||
struct VectorOrTensorTypeStorage;
|
||||
struct ShapedTypeStorage;
|
||||
struct VectorTypeStorage;
|
||||
struct RankedTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage;
|
||||
|
@ -180,11 +180,13 @@ public:
|
|||
const llvm::fltSemantics &getFloatSemantics();
|
||||
};
|
||||
|
||||
// TODO(b/132735995) Add support for MemRef
|
||||
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
|
||||
/// types, because many operations work on values of these aggregate types.
|
||||
class VectorOrTensorType : public Type {
|
||||
/// types because they share behavior and semantics around shape, rank, and
|
||||
/// fixed element type.
|
||||
class ShapedType : public Type {
|
||||
public:
|
||||
using ImplType = detail::VectorOrTensorTypeStorage;
|
||||
using ImplType = detail::ShapedTypeStorage;
|
||||
using Type::Type;
|
||||
|
||||
/// Return the element type.
|
||||
|
@ -234,8 +236,8 @@ public:
|
|||
|
||||
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
|
||||
/// known constant shape with one or more dimension.
|
||||
class VectorType : public Type::TypeBase<VectorType, VectorOrTensorType,
|
||||
detail::VectorTypeStorage> {
|
||||
class VectorType
|
||||
: public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
|
@ -268,9 +270,9 @@ public:
|
|||
|
||||
/// Tensor types represent multi-dimensional arrays, and have two variants:
|
||||
/// RankedTensorType and UnrankedTensorType.
|
||||
class TensorType : public VectorOrTensorType {
|
||||
class TensorType : public ShapedType {
|
||||
public:
|
||||
using VectorOrTensorType::VectorOrTensorType;
|
||||
using ShapedType::ShapedType;
|
||||
|
||||
/// Return true if the specified element type is ok in a tensor.
|
||||
static bool isValidElementType(Type type) {
|
||||
|
|
|
@ -375,7 +375,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
|
|||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, Value *aggregate,"
|
||||
"ArrayRef<Value *> indices = {}", [{
|
||||
auto resType = aggregate->getType().cast<VectorOrTensorType>()
|
||||
auto resType = aggregate->getType().cast<ShapedType>()
|
||||
.getElementType();
|
||||
build(builder, result, resType, aggregate, indices);
|
||||
}]>];
|
||||
|
|
|
@ -168,12 +168,12 @@ struct QuantizedMultiplierSmallerThanOneExp {
|
|||
|
||||
/// Casts an integer or floating point based type to a new element type.
|
||||
inline Type castElementType(Type t, Type newElementType) {
|
||||
if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
|
||||
switch (vt.getKind()) {
|
||||
if (auto st = t.dyn_cast<ShapedType>()) {
|
||||
switch (st.getKind()) {
|
||||
case StandardTypes::Kind::Vector:
|
||||
return VectorType::get(vt.getShape(), newElementType);
|
||||
return VectorType::get(st.getShape(), newElementType);
|
||||
case StandardTypes::Kind::RankedTensor:
|
||||
return RankedTensorType::get(vt.getShape(), newElementType);
|
||||
return RankedTensorType::get(st.getShape(), newElementType);
|
||||
case StandardTypes::Kind::UnrankedTensor:
|
||||
return UnrankedTensorType::get(newElementType);
|
||||
}
|
||||
|
@ -185,10 +185,10 @@ inline Type castElementType(Type t, Type newElementType) {
|
|||
/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
|
||||
/// be a primitive/vector/tensor).
|
||||
inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
|
||||
if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
|
||||
assert(vt.getElementType().isa<IntegerType>());
|
||||
return SplatElementsAttr::get(vt,
|
||||
IntegerAttr::get(vt.getElementType(), value));
|
||||
if (auto st = t.dyn_cast<ShapedType>()) {
|
||||
assert(st.getElementType().isa<IntegerType>());
|
||||
return SplatElementsAttr::get(st,
|
||||
IntegerAttr::get(st.getElementType(), value));
|
||||
}
|
||||
|
||||
auto integerType = t.cast<IntegerType>();
|
||||
|
@ -211,13 +211,13 @@ inline APFloat convertFloatToType(FloatType ft, APFloat value) {
|
|||
/// Creates an IntegerAttr with a type that matches the shape of 't' (which can
|
||||
/// be a primitive/vector/tensor).
|
||||
inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
|
||||
if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
|
||||
FloatType floatElementType = vt.getElementType().dyn_cast<FloatType>();
|
||||
if (auto st = t.dyn_cast<ShapedType>()) {
|
||||
FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
|
||||
assert(floatElementType &&
|
||||
"float broadcast element type must be float like");
|
||||
APFloat apValue = convertFloatToType(floatElementType, value);
|
||||
return SplatElementsAttr::get(vt,
|
||||
FloatAttr::get(vt.getElementType(), apValue));
|
||||
return SplatElementsAttr::get(st,
|
||||
FloatAttr::get(st.getElementType(), apValue));
|
||||
} else {
|
||||
auto floatType = t.dyn_cast<FloatType>();
|
||||
assert(floatType && "float broadcast must be of float type");
|
||||
|
|
|
@ -98,8 +98,8 @@ Type QuantizedType::getExpressedType() const {
|
|||
}
|
||||
|
||||
bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
|
||||
if (candidateExpressedType.isa<VectorOrTensorType>()) {
|
||||
return candidateExpressedType.cast<VectorOrTensorType>().getElementType() ==
|
||||
if (candidateExpressedType.isa<ShapedType>()) {
|
||||
return candidateExpressedType.cast<ShapedType>().getElementType() ==
|
||||
getExpressedType();
|
||||
}
|
||||
return candidateExpressedType == getExpressedType();
|
||||
|
@ -107,9 +107,9 @@ bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
|
|||
|
||||
QuantizedType
|
||||
QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
|
||||
if (primitiveOrContainerType.isa<VectorOrTensorType>()) {
|
||||
if (primitiveOrContainerType.isa<ShapedType>()) {
|
||||
Type elementType =
|
||||
primitiveOrContainerType.cast<VectorOrTensorType>().getElementType();
|
||||
primitiveOrContainerType.cast<ShapedType>().getElementType();
|
||||
return elementType.dyn_cast<QuantizedType>();
|
||||
}
|
||||
return primitiveOrContainerType.dyn_cast<QuantizedType>();
|
||||
|
@ -139,20 +139,20 @@ Type QuantizedType::castToStorageType(Type quantizedType) {
|
|||
if (quantizedType.isa<QuantizedType>()) {
|
||||
// i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
|
||||
return quantizedType.cast<QuantizedType>().getStorageType();
|
||||
} else if (quantizedType.isa<VectorOrTensorType>()) {
|
||||
} else if (quantizedType.isa<ShapedType>()) {
|
||||
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
||||
VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
|
||||
if (!vtType.getElementType().isa<QuantizedType>()) {
|
||||
ShapedType sType = quantizedType.cast<ShapedType>();
|
||||
if (!sType.getElementType().isa<QuantizedType>()) {
|
||||
return nullptr;
|
||||
}
|
||||
Type storageType =
|
||||
vtType.getElementType().cast<QuantizedType>().getStorageType();
|
||||
sType.getElementType().cast<QuantizedType>().getStorageType();
|
||||
if (quantizedType.isa<RankedTensorType>()) {
|
||||
return RankedTensorType::get(vtType.getShape(), storageType);
|
||||
return RankedTensorType::get(sType.getShape(), storageType);
|
||||
} else if (quantizedType.isa<UnrankedTensorType>()) {
|
||||
return UnrankedTensorType::get(storageType);
|
||||
} else if (quantizedType.isa<VectorType>()) {
|
||||
return VectorType::get(vtType.getShape(), storageType);
|
||||
return VectorType::get(sType.getShape(), storageType);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,22 +163,21 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
|
|||
if (candidateType == getExpressedType()) {
|
||||
// i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
|
||||
return *this;
|
||||
} else if (candidateType.isa<VectorOrTensorType>()) {
|
||||
VectorOrTensorType candidateVtType =
|
||||
candidateType.cast<VectorOrTensorType>();
|
||||
if (candidateVtType.getElementType() != getExpressedType()) {
|
||||
} else if (candidateType.isa<ShapedType>()) {
|
||||
ShapedType candidateShapedType = candidateType.cast<ShapedType>();
|
||||
if (candidateShapedType.getElementType() != getExpressedType()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (candidateType.isa<RankedTensorType>()) {
|
||||
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
||||
return RankedTensorType::get(candidateVtType.getShape(), *this);
|
||||
return RankedTensorType::get(candidateShapedType.getShape(), *this);
|
||||
} else if (candidateType.isa<UnrankedTensorType>()) {
|
||||
// i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
|
||||
return UnrankedTensorType::get(*this);
|
||||
} else if (candidateType.isa<VectorType>()) {
|
||||
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
||||
return VectorType::get(candidateVtType.getShape(), *this);
|
||||
return VectorType::get(candidateShapedType.getShape(), *this);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -189,20 +188,20 @@ Type QuantizedType::castToExpressedType(Type quantizedType) {
|
|||
if (quantizedType.isa<QuantizedType>()) {
|
||||
// i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
|
||||
return quantizedType.cast<QuantizedType>().getExpressedType();
|
||||
} else if (quantizedType.isa<VectorOrTensorType>()) {
|
||||
} else if (quantizedType.isa<ShapedType>()) {
|
||||
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
||||
VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
|
||||
if (!vtType.getElementType().isa<QuantizedType>()) {
|
||||
ShapedType sType = quantizedType.cast<ShapedType>();
|
||||
if (!sType.getElementType().isa<QuantizedType>()) {
|
||||
return nullptr;
|
||||
}
|
||||
Type expressedType =
|
||||
vtType.getElementType().cast<QuantizedType>().getExpressedType();
|
||||
sType.getElementType().cast<QuantizedType>().getExpressedType();
|
||||
if (quantizedType.isa<RankedTensorType>()) {
|
||||
return RankedTensorType::get(vtType.getShape(), expressedType);
|
||||
return RankedTensorType::get(sType.getShape(), expressedType);
|
||||
} else if (quantizedType.isa<UnrankedTensorType>()) {
|
||||
return UnrankedTensorType::get(expressedType);
|
||||
} else if (quantizedType.isa<VectorType>()) {
|
||||
return VectorType::get(vtType.getShape(), expressedType);
|
||||
return VectorType::get(sType.getShape(), expressedType);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -56,10 +56,10 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
|
|||
|
||||
// Cast from an expressed-type-based type to storage-type-based type,
|
||||
// preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
|
||||
VectorOrTensorType newDenseType =
|
||||
ShapedType newDenseType =
|
||||
quantizedElementType
|
||||
.castExpressedToStorageType(realFPElementsAttr.getType())
|
||||
.dyn_cast_or_null<VectorOrTensorType>();
|
||||
.dyn_cast_or_null<ShapedType>();
|
||||
if (!newDenseType) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -87,9 +87,9 @@ convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
|
|||
|
||||
// Cast from an expressed-type-based type to storage-type-based type,
|
||||
// preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>).
|
||||
VectorOrTensorType newSplatType =
|
||||
ShapedType newSplatType =
|
||||
quantizedElementType.castExpressedToStorageType(realSplatAttr.getType())
|
||||
.dyn_cast_or_null<VectorOrTensorType>();
|
||||
.dyn_cast_or_null<ShapedType>();
|
||||
if (!newSplatType) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -116,9 +116,9 @@ convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
|
|||
|
||||
// Cast from an expressed-type-based type to storage-type-based type,
|
||||
// preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
|
||||
VectorOrTensorType newSparseType =
|
||||
ShapedType newSparseType =
|
||||
quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
|
||||
.dyn_cast_or_null<VectorOrTensorType>();
|
||||
.dyn_cast_or_null<ShapedType>();
|
||||
if (!newSparseType) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -38,13 +38,13 @@ ExpressedToUniformQuantizedConverter::forInputType(Type inputType) {
|
|||
case StandardTypes::RankedTensor:
|
||||
case StandardTypes::UnrankedTensor:
|
||||
case StandardTypes::Vector: {
|
||||
Type elementType = inputType.cast<VectorOrTensorType>().getElementType();
|
||||
Type elementType = inputType.cast<ShapedType>().getElementType();
|
||||
if (!isQuantizablePrimitiveType(elementType)) {
|
||||
// Unsupported.
|
||||
return ExpressedToUniformQuantizedConverter{inputType, nullptr};
|
||||
}
|
||||
return ExpressedToUniformQuantizedConverter{
|
||||
inputType, inputType.cast<VectorOrTensorType>().getElementType()};
|
||||
inputType, inputType.cast<ShapedType>().getElementType()};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -80,8 +80,8 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
|
|||
/// Returns the shape of the given type. Scalars will be considered as having a
|
||||
/// shape with zero dimensions.
|
||||
static ArrayRef<int64_t> getShape(Type type) {
|
||||
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
||||
return vtType.getShape();
|
||||
if (auto sType = type.dyn_cast<ShapedType>())
|
||||
return sType.getShape();
|
||||
return {};
|
||||
}
|
||||
|
||||
|
@ -92,8 +92,8 @@ static ArrayRef<int64_t> getShape(Type type) {
|
|||
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
||||
// Returns the scalar type out of the given type.
|
||||
auto getScalarType = [](Type type) -> Type {
|
||||
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
||||
return vtType.getElementType();
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>())
|
||||
return shapedType.getElementType();
|
||||
return type;
|
||||
};
|
||||
|
||||
|
|
|
@ -315,7 +315,7 @@ static ValueHandle createBinaryHandle(
|
|||
return createBinaryHandle<IOp>(lhs, rhs);
|
||||
} else if (thisType.isa<FloatType>()) {
|
||||
return createBinaryHandle<FOp>(lhs, rhs);
|
||||
} else if (auto aggregateType = thisType.dyn_cast<VectorOrTensorType>()) {
|
||||
} else if (auto aggregateType = thisType.dyn_cast<ShapedType>()) {
|
||||
if (aggregateType.getElementType().isa<IntegerType>())
|
||||
return createBinaryHandle<IOp>(lhs, rhs);
|
||||
else if (aggregateType.getElementType().isa<FloatType>())
|
||||
|
|
|
@ -201,7 +201,7 @@ void ModuleState::visitType(Type type) {
|
|||
// Visit affine maps in memref type.
|
||||
for (auto map : memref.getAffineMaps())
|
||||
recordAttributeReference(AffineMapAttr::get(map));
|
||||
} else if (auto vecOrTensor = type.dyn_cast<VectorOrTensorType>()) {
|
||||
} else if (auto vecOrTensor = type.dyn_cast<ShapedType>()) {
|
||||
visitType(vecOrTensor.getElementType());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -313,8 +313,8 @@ FunctionType FunctionAttr::getType() const {
|
|||
// ElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
VectorOrTensorType ElementsAttr::getType() const {
|
||||
return Attribute::getType().cast<VectorOrTensorType>();
|
||||
ShapedType ElementsAttr::getType() const {
|
||||
return Attribute::getType().cast<ShapedType>();
|
||||
}
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
|
@ -339,8 +339,7 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
// SplatElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
|
||||
Attribute elt) {
|
||||
SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) {
|
||||
assert(elt.getType() == type.getElementType() &&
|
||||
"value should be of the given element type");
|
||||
return Base::get(type.getContext(), StandardAttributes::SplatElements, type,
|
||||
|
@ -374,8 +373,7 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const {
|
|||
// DenseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
||||
ArrayRef<char> data) {
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
|
||||
assert((static_cast<uint64_t>(type.getSizeInBits()) <=
|
||||
data.size() * APInt::APINT_WORD_SIZE) &&
|
||||
"Input data bit size should be larger than that type requires");
|
||||
|
@ -394,7 +392,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
}
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<Attribute> values) {
|
||||
assert(type.getElementType().isIntOrFloat() &&
|
||||
"expected int or float element type");
|
||||
|
@ -516,7 +514,7 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
|
|||
// Constructs a dense elements attribute from an array of raw APInt values.
|
||||
// Each APInt value is expected to have the same bitwidth as the element type
|
||||
// of 'type'.
|
||||
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<APInt> values) {
|
||||
assert(values.size() == type.getNumElements() &&
|
||||
"expected 'values' to contain the same number of elements as 'type'");
|
||||
|
@ -586,7 +584,7 @@ APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
|
|||
/// Constructs a dense integer elements attribute from an array of APInt
|
||||
/// values. Each APInt value is expected to have the same bitwidth as the
|
||||
/// element type of 'type'.
|
||||
DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
|
||||
DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
|
||||
ArrayRef<APInt> values) {
|
||||
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
|
||||
}
|
||||
|
@ -594,7 +592,7 @@ DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
|
|||
/// Constructs a dense integer elements attribute from an array of integer
|
||||
/// values. Each value is expected to be within the bitwidth of the element
|
||||
/// type of 'type'.
|
||||
DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
|
||||
DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
|
||||
ArrayRef<int64_t> values) {
|
||||
auto eltType = type.getElementType();
|
||||
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||
|
@ -625,7 +623,7 @@ DenseFPElementsAttr::ElementIterator::ElementIterator(
|
|||
// Constructs a dense float elements attribute from an array of APFloat
|
||||
// values. Each APFloat value is expected to have the same bitwidth as the
|
||||
// element type of 'type'.
|
||||
DenseFPElementsAttr DenseFPElementsAttr::get(VectorOrTensorType type,
|
||||
DenseFPElementsAttr DenseFPElementsAttr::get(ShapedType type,
|
||||
ArrayRef<APFloat> values) {
|
||||
// Convert the APFloat values to APInt and create a dense elements attribute.
|
||||
std::vector<APInt> intValues(values.size());
|
||||
|
@ -655,8 +653,7 @@ DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const {
|
|||
// OpaqueElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect,
|
||||
VectorOrTensorType type,
|
||||
OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
|
||||
StringRef bytes) {
|
||||
assert(TensorType::isValidElementType(type.getElementType()) &&
|
||||
"Input element type should be a valid tensor element type");
|
||||
|
@ -686,7 +683,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
|
|||
// SparseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
|
||||
SparseElementsAttr SparseElementsAttr::get(ShapedType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values) {
|
||||
assert(indices.getType().getElementType().isInteger(64) &&
|
||||
|
|
|
@ -173,34 +173,32 @@ FunctionAttr Builder::getFunctionAttr(Function *value) {
|
|||
return FunctionAttr::get(value);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
|
||||
Attribute elt) {
|
||||
ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
|
||||
return SplatElementsAttr::get(type, elt);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
|
||||
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<char> data) {
|
||||
return DenseElementsAttr::get(type, data);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
|
||||
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<Attribute> values) {
|
||||
return DenseElementsAttr::get(type, values);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getDenseIntElementsAttr(VectorOrTensorType type,
|
||||
ElementsAttr Builder::getDenseIntElementsAttr(ShapedType type,
|
||||
ArrayRef<int64_t> values) {
|
||||
return DenseIntElementsAttr::get(type, values);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
|
||||
ElementsAttr Builder::getSparseElementsAttr(ShapedType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values) {
|
||||
return SparseElementsAttr::get(type, indices, values);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect,
|
||||
VectorOrTensorType type,
|
||||
ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, ShapedType type,
|
||||
StringRef bytes) {
|
||||
return OpaqueElementsAttr::get(dialect, type, bytes);
|
||||
}
|
||||
|
@ -249,7 +247,7 @@ Attribute Builder::getZeroAttr(Type type) {
|
|||
}
|
||||
case StandardTypes::Vector:
|
||||
case StandardTypes::RankedTensor: {
|
||||
auto vtType = type.cast<VectorOrTensorType>();
|
||||
auto vtType = type.cast<ShapedType>();
|
||||
auto element = getZeroAttr(vtType.getElementType());
|
||||
if (!element)
|
||||
return {};
|
||||
|
|
|
@ -775,8 +775,8 @@ LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
|
|||
}
|
||||
|
||||
/// Returns success if the given two types have the same shape. That is,
|
||||
/// they are both scalars, or they are both vectors / ranked tensors with
|
||||
/// the same dimension specifications. The element type does not matter.
|
||||
/// they are both scalars, or they are both static shaped types with the same
|
||||
/// dimension specifications. The element type does not matter.
|
||||
static LogicalResult verifyShapeMatch(Type type1, Type type2) {
|
||||
// Check scalar cases
|
||||
if (type1.isIntOrIndexOrFloat())
|
||||
|
@ -787,9 +787,9 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) {
|
|||
return failure();
|
||||
|
||||
// Check normal vector/tensor cases
|
||||
if (auto vtType1 = type1.dyn_cast<VectorOrTensorType>()) {
|
||||
auto vtType2 = type2.dyn_cast<VectorOrTensorType>();
|
||||
return success(vtType2 && vtType1.getShape() == vtType2.getShape());
|
||||
if (auto sType1 = type1.dyn_cast<ShapedType>()) {
|
||||
auto sType2 = type2.dyn_cast<ShapedType>();
|
||||
return success(sType2 && sType1.getShape() == sType2.getShape());
|
||||
}
|
||||
|
||||
return success();
|
||||
|
@ -818,14 +818,14 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
|
|||
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
|
||||
return failure();
|
||||
|
||||
auto type = op->getResult(0)->getType().dyn_cast<VectorOrTensorType>();
|
||||
auto type = op->getResult(0)->getType().dyn_cast<ShapedType>();
|
||||
if (!type)
|
||||
return op->emitOpError("requires vector or tensor type results");
|
||||
auto elementType = type.getElementType();
|
||||
|
||||
// Verify result element type matches first result's element type.
|
||||
for (auto result : drop_begin(op->getResults(), 1)) {
|
||||
auto resultType = result->getType().dyn_cast<VectorOrTensorType>();
|
||||
auto resultType = result->getType().dyn_cast<ShapedType>();
|
||||
if (!resultType)
|
||||
return op->emitOpError("requires vector or tensor type results");
|
||||
if (resultType.getElementType() != elementType)
|
||||
|
@ -835,7 +835,7 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
|
|||
|
||||
// Verify operand's element type matches first result's element type.
|
||||
for (auto operand : op->getOperands()) {
|
||||
auto operandType = operand->getType().dyn_cast<VectorOrTensorType>();
|
||||
auto operandType = operand->getType().dyn_cast<ShapedType>();
|
||||
if (!operandType)
|
||||
return op->emitOpError("requires vector or tensor type operands");
|
||||
if (operandType.getElementType() != elementType)
|
||||
|
|
|
@ -104,18 +104,18 @@ unsigned Type::getIntOrFloatBitWidth() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorOrTensorType
|
||||
// ShapedType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Type VectorOrTensorType::getElementType() const {
|
||||
Type ShapedType::getElementType() const {
|
||||
return static_cast<ImplType *>(impl)->elementType;
|
||||
}
|
||||
|
||||
unsigned VectorOrTensorType::getElementTypeBitWidth() const {
|
||||
unsigned ShapedType::getElementTypeBitWidth() const {
|
||||
return getElementType().getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
unsigned VectorOrTensorType::getNumElements() const {
|
||||
unsigned ShapedType::getNumElements() const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::Vector:
|
||||
case StandardTypes::RankedTensor: {
|
||||
|
@ -127,13 +127,13 @@ unsigned VectorOrTensorType::getNumElements() const {
|
|||
return num;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType or not ranked");
|
||||
llvm_unreachable("not a ShapedType or not ranked");
|
||||
}
|
||||
}
|
||||
|
||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||
/// unranked tensor, return -1.
|
||||
int64_t VectorOrTensorType::getRank() const {
|
||||
int64_t ShapedType::getRank() const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::Vector:
|
||||
case StandardTypes::RankedTensor:
|
||||
|
@ -141,24 +141,24 @@ int64_t VectorOrTensorType::getRank() const {
|
|||
case StandardTypes::UnrankedTensor:
|
||||
return -1;
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType");
|
||||
llvm_unreachable("not a ShapedType");
|
||||
}
|
||||
}
|
||||
|
||||
int64_t VectorOrTensorType::getDimSize(unsigned i) const {
|
||||
int64_t ShapedType::getDimSize(unsigned i) const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::Vector:
|
||||
case StandardTypes::RankedTensor:
|
||||
return getShape()[i];
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType or not ranked");
|
||||
llvm_unreachable("not a ShapedType or not ranked");
|
||||
}
|
||||
}
|
||||
|
||||
// Get the number of number of bits require to store a value of the given vector
|
||||
// or tensor types. Compute the value recursively since tensors are allowed to
|
||||
// have vectors as elements.
|
||||
int64_t VectorOrTensorType::getSizeInBits() const {
|
||||
int64_t ShapedType::getSizeInBits() const {
|
||||
assert(hasStaticShape() &&
|
||||
"cannot get the bit size of an aggregate with a dynamic shape");
|
||||
|
||||
|
@ -168,23 +168,23 @@ int64_t VectorOrTensorType::getSizeInBits() const {
|
|||
|
||||
// Tensors can have vectors and other tensors as elements, vectors cannot.
|
||||
assert(!isa<VectorType>() && "unsupported vector element type");
|
||||
auto elementVectorOrTensorType = elementType.dyn_cast<VectorOrTensorType>();
|
||||
assert(elementVectorOrTensorType && "unsupported tensor element type");
|
||||
return getNumElements() * elementVectorOrTensorType.getSizeInBits();
|
||||
auto elementShapedType = elementType.dyn_cast<ShapedType>();
|
||||
assert(elementShapedType && "unsupported tensor element type");
|
||||
return getNumElements() * elementShapedType.getSizeInBits();
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> VectorOrTensorType::getShape() const {
|
||||
ArrayRef<int64_t> ShapedType::getShape() const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::Vector:
|
||||
return cast<VectorType>().getShape();
|
||||
case StandardTypes::RankedTensor:
|
||||
return cast<RankedTensorType>().getShape();
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType or not ranked");
|
||||
llvm_unreachable("not a ShapedType or not ranked");
|
||||
}
|
||||
}
|
||||
|
||||
bool VectorOrTensorType::hasStaticShape() const {
|
||||
bool ShapedType::hasStaticShape() const {
|
||||
if (isa<UnrankedTensorType>())
|
||||
return false;
|
||||
return llvm::none_of(getShape(), [](int64_t i) { return i < 0; });
|
||||
|
|
|
@ -118,8 +118,8 @@ struct FunctionTypeStorage : public TypeStorage {
|
|||
};
|
||||
|
||||
/// VectorOrTensor Type Storage.
|
||||
struct VectorOrTensorTypeStorage : public TypeStorage {
|
||||
VectorOrTensorTypeStorage(Type elementType, unsigned subclassData = 0)
|
||||
struct ShapedTypeStorage : public TypeStorage {
|
||||
ShapedTypeStorage(Type elementType, unsigned subclassData = 0)
|
||||
: TypeStorage(subclassData), elementType(elementType) {}
|
||||
|
||||
/// The hash key used for uniquing.
|
||||
|
@ -130,11 +130,10 @@ struct VectorOrTensorTypeStorage : public TypeStorage {
|
|||
};
|
||||
|
||||
/// Vector Type Storage and Uniquing.
|
||||
struct VectorTypeStorage : public VectorOrTensorTypeStorage {
|
||||
struct VectorTypeStorage : public ShapedTypeStorage {
|
||||
VectorTypeStorage(unsigned shapeSize, Type elementTy,
|
||||
const int64_t *shapeElements)
|
||||
: VectorOrTensorTypeStorage(elementTy, shapeSize),
|
||||
shapeElements(shapeElements) {}
|
||||
: ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {}
|
||||
|
||||
/// The hash key used for uniquing.
|
||||
using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
|
||||
|
@ -160,11 +159,10 @@ struct VectorTypeStorage : public VectorOrTensorTypeStorage {
|
|||
const int64_t *shapeElements;
|
||||
};
|
||||
|
||||
struct RankedTensorTypeStorage : public VectorOrTensorTypeStorage {
|
||||
struct RankedTensorTypeStorage : public ShapedTypeStorage {
|
||||
RankedTensorTypeStorage(unsigned shapeSize, Type elementTy,
|
||||
const int64_t *shapeElements)
|
||||
: VectorOrTensorTypeStorage(elementTy, shapeSize),
|
||||
shapeElements(shapeElements) {}
|
||||
: ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {}
|
||||
|
||||
/// The hash key used for uniquing.
|
||||
using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
|
||||
|
@ -190,9 +188,9 @@ struct RankedTensorTypeStorage : public VectorOrTensorTypeStorage {
|
|||
const int64_t *shapeElements;
|
||||
};
|
||||
|
||||
struct UnrankedTensorTypeStorage : public VectorOrTensorTypeStorage {
|
||||
using VectorOrTensorTypeStorage::KeyTy;
|
||||
using VectorOrTensorTypeStorage::VectorOrTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage : public ShapedTypeStorage {
|
||||
using ShapedTypeStorage::KeyTy;
|
||||
using ShapedTypeStorage::ShapedTypeStorage;
|
||||
|
||||
/// Construction.
|
||||
static UnrankedTensorTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
|
|
|
@ -200,9 +200,9 @@ public:
|
|||
// Polyhedral structures.
|
||||
ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
|
||||
IntegerSet &set);
|
||||
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type);
|
||||
DenseElementsAttr parseDenseElementsAttr(ShapedType type);
|
||||
DenseElementsAttr parseDenseElementsAttrAsTensor(Type eltType);
|
||||
VectorOrTensorType parseVectorOrTensorType();
|
||||
ShapedType parseShapedType();
|
||||
|
||||
// Location Parsing.
|
||||
|
||||
|
@ -1255,7 +1255,7 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
if (parseToken(Token::comma, "expected ','"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseVectorOrTensorType();
|
||||
auto type = parseShapedType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
|
@ -1279,7 +1279,7 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
if (parseToken(Token::less, "expected '<' after 'splat'"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseVectorOrTensorType();
|
||||
auto type = parseShapedType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
switch (getToken().getKind()) {
|
||||
|
@ -1305,7 +1305,7 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
if (parseToken(Token::less, "expected '<' after 'dense'"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseVectorOrTensorType();
|
||||
auto type = parseShapedType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
|
@ -1323,7 +1323,7 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
|
||||
return nullptr;
|
||||
|
||||
auto type = parseVectorOrTensorType();
|
||||
auto type = parseShapedType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
|
||||
|
@ -1414,7 +1414,7 @@ DenseElementsAttr Parser::parseDenseElementsAttrAsTensor(Type eltType) {
|
|||
/// This method compares the shapes from the parsing result and that from the
|
||||
/// input argument. It returns a constructed dense elements attribute if both
|
||||
/// match.
|
||||
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
|
||||
DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) {
|
||||
auto eltTy = type.getElementType();
|
||||
TensorLiteralParser literalParser(*this, eltTy);
|
||||
if (literalParser.parse())
|
||||
|
@ -1431,27 +1431,26 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
|
|||
.cast<DenseElementsAttr>();
|
||||
}
|
||||
|
||||
/// Vector or tensor type for elements attribute.
|
||||
/// Shaped type for elements attribute.
|
||||
///
|
||||
/// vector-or-tensor-type ::= vector-type | tensor-type
|
||||
/// shaped-type ::= vector-type | tensor-type
|
||||
///
|
||||
/// This method also checks the type has static shape and ranked.
|
||||
VectorOrTensorType Parser::parseVectorOrTensorType() {
|
||||
ShapedType Parser::parseShapedType() {
|
||||
auto elementType = parseType();
|
||||
if (!elementType)
|
||||
return nullptr;
|
||||
|
||||
auto type = elementType.dyn_cast<VectorOrTensorType>();
|
||||
auto type = elementType.dyn_cast<ShapedType>();
|
||||
if (!type) {
|
||||
return (emitError("expected elements literal has a tensor or vector type"),
|
||||
nullptr);
|
||||
return (emitError("elements literal must be a shaped type"), nullptr);
|
||||
}
|
||||
|
||||
if (parseToken(Token::comma, "expected ','"))
|
||||
return nullptr;
|
||||
|
||||
if (!type.hasStaticShape() || type.getRank() == -1) {
|
||||
return (emitError("tensor literals must be ranked and have static shape"),
|
||||
return (emitError("shaped literal must be ranked and have static shape"),
|
||||
nullptr);
|
||||
}
|
||||
return type;
|
||||
|
|
|
@ -66,18 +66,18 @@ static bool getElementsStatistics(ElementsAttr attr,
|
|||
statistics.minValue = std::numeric_limits<double>::infinity();
|
||||
statistics.maxValue = -std::numeric_limits<double>::infinity();
|
||||
|
||||
VectorOrTensorType vtType = attr.getType();
|
||||
if (!vtType.hasStaticShape())
|
||||
ShapedType sType = attr.getType();
|
||||
if (!sType.hasStaticShape())
|
||||
return false;
|
||||
Type elementTy = vtType.getElementType();
|
||||
Type elementTy = sType.getElementType();
|
||||
if (!elementTy.isa<FloatType>())
|
||||
return false;
|
||||
|
||||
llvm::SmallVector<uint64_t, 4> indices;
|
||||
indices.resize(vtType.getRank());
|
||||
ArrayRef<int64_t> shape = vtType.getShape();
|
||||
indices.resize(sType.getRank());
|
||||
ArrayRef<int64_t> shape = sType.getShape();
|
||||
|
||||
auto numElements = vtType.getNumElements();
|
||||
auto numElements = sType.getNumElements();
|
||||
collectElementsStatisticsDim(attr, numElements, shape, indices, 0,
|
||||
statistics);
|
||||
statistics.sampleSize = numElements;
|
||||
|
|
|
@ -23,8 +23,8 @@ using namespace mlir;
|
|||
using namespace mlir::quantizer;
|
||||
|
||||
Type mlir::quantizer::getElementOrPrimitiveType(Type t) {
|
||||
if (auto vtType = t.dyn_cast<VectorOrTensorType>()) {
|
||||
return vtType.getElementType();
|
||||
if (auto sType = t.dyn_cast<ShapedType>()) {
|
||||
return sType.getElementType();
|
||||
} else {
|
||||
return t;
|
||||
}
|
||||
|
|
|
@ -1163,9 +1163,9 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
if (type.isa<VectorOrTensorType>()) {
|
||||
if (type.isa<ShapedType>()) {
|
||||
if (!value.isa<ElementsAttr>())
|
||||
return op.emitOpError("requires 'value' to be a vector/tensor constant");
|
||||
return op.emitOpError("requires 'value' to be a shaped constant");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1639,7 +1639,7 @@ static ParseResult parseExtractElementOp(OpAsmParser *parser,
|
|||
OperationState *result) {
|
||||
OpAsmParser::OperandType aggregateInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||
VectorOrTensorType type;
|
||||
ShapedType type;
|
||||
|
||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||
return failure(
|
||||
|
@ -1656,8 +1656,7 @@ static LogicalResult verify(ExtractElementOp op) {
|
|||
if (op.getNumOperands() == 0)
|
||||
return op.emitOpError("expected an aggregate to index into");
|
||||
|
||||
auto aggregateType =
|
||||
op.getAggregate()->getType().dyn_cast<VectorOrTensorType>();
|
||||
auto aggregateType = op.getAggregate()->getType().dyn_cast<ShapedType>();
|
||||
if (!aggregateType)
|
||||
return op.emitOpError("first operand must be a vector or tensor");
|
||||
|
||||
|
|
|
@ -599,14 +599,14 @@ func@n(){^b(
|
|||
|
||||
func @elementsattr_non_tensor_type() -> () {
|
||||
^bb0:
|
||||
"foo"(){bar: dense<i32, [4]>} : () -> () // expected-error {{expected elements literal has a tensor or vector type}}
|
||||
"foo"(){bar: dense<i32, [4]>} : () -> () // expected-error {{elements literal must be a shaped type}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @elementsattr_non_ranked() -> () {
|
||||
^bb0:
|
||||
"foo"(){bar: dense<tensor<?xi32>, [4]>} : () -> () // expected-error {{tensor literals must be ranked and have static shape}}
|
||||
"foo"(){bar: dense<tensor<?xi32>, [4]>} : () -> () // expected-error {{shaped literal must be ranked and have static shape}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -33,7 +33,7 @@ def OpC : NS_Op<"op_for_TCopVTEtIs", [
|
|||
}
|
||||
|
||||
// CHECK-LABEL: OpC::verify
|
||||
// CHECK: if (!((((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getOperand(0)->getType().isa<VectorOrTensorType>())) && (((*this->getOperation()).getOperand(0)->getType().cast<VectorOrTensorType>().getElementType().isInteger(32)))))
|
||||
// CHECK: if (!((((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getOperand(0)->getType().isa<ShapedType>())) && (((*this->getOperation()).getOperand(0)->getType().cast<ShapedType>().getElementType().isInteger(32)))))
|
||||
|
||||
|
||||
def OpD : NS_Op<"op_for_TCOpVTEtIsSameAs", [
|
||||
|
@ -44,7 +44,7 @@ def OpD : NS_Op<"op_for_TCOpVTEtIsSameAs", [
|
|||
}
|
||||
|
||||
// CHECK-LABEL: OpD::verify
|
||||
// CHECK: if (!((((*this->getOperation()).getNumOperands() > std::max(0,0))) && (((*this->getOperation()).getOperand(0)->getType().isa<VectorOrTensorType>())) && (((*this->getOperation()).getOperand(0)->getType().isa<VectorOrTensorType>())) && (((*this->getOperation()).getOperand(0)->getType().cast<VectorOrTensorType>().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast<VectorOrTensorType>().getElementType()))))
|
||||
// CHECK: if (!((((*this->getOperation()).getNumOperands() > std::max(0,0))) && (((*this->getOperation()).getOperand(0)->getType().isa<ShapedType>())) && (((*this->getOperation()).getOperand(0)->getType().isa<ShapedType>())) && (((*this->getOperation()).getOperand(0)->getType().cast<ShapedType>().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast<ShapedType>().getElementType()))))
|
||||
// CHECK-NEXT: return emitOpError("failed to verify that first operand is a vector or tensor with the same elemental type as itself");
|
||||
|
||||
|
||||
|
@ -57,7 +57,7 @@ def OpE : NS_Op<"op_for_TCresVTEtIsSameAsOp", [
|
|||
}
|
||||
|
||||
// CHECK-LABEL: OpE::verify
|
||||
// CHECK: if (!((((*this->getOperation()).getNumResults() > 0)) && (((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getResult(0)->getType().isa<VectorOrTensorType>())) && (((*this->getOperation()).getOperand(0)->getType().isa<VectorOrTensorType>())) && (((*this->getOperation()).getResult(0)->getType().cast<VectorOrTensorType>().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast<VectorOrTensorType>().getElementType()))))
|
||||
// CHECK: if (!((((*this->getOperation()).getNumResults() > 0)) && (((*this->getOperation()).getNumOperands() > 0)) && (((*this->getOperation()).getResult(0)->getType().isa<ShapedType>())) && (((*this->getOperation()).getOperand(0)->getType().isa<ShapedType>())) && (((*this->getOperation()).getResult(0)->getType().cast<ShapedType>().getElementType() == (*this->getOperation()).getOperand(0)->getType().cast<ShapedType>().getElementType()))))
|
||||
// CHECK-NEXT: return emitOpError("failed to verify that first operand is a vector or tensor with the same elemental type as first result");
|
||||
|
||||
|
||||
|
@ -107,7 +107,7 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
|
|||
|
||||
// CHECK-LABEL: OpJ::verify()
|
||||
// CHECK: llvm::is_splat(mlir::functional::map(
|
||||
// CHECK-SAME: [this](unsigned i) { return this->getOperand(i)->getType().cast<VectorOrTensorType>().getElementType(); },
|
||||
// CHECK-SAME: [this](unsigned i) { return this->getOperand(i)->getType().cast<ShapedType>().getElementType(); },
|
||||
// CHECK-SAME: llvm::ArrayRef<unsigned>({0, 2, 3})))
|
||||
// CHECK: return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type");
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ template <typename ConcreteAttrClass, typename... Arg>
|
|||
ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
|
||||
Arg... value) {
|
||||
auto eleType = FloatType::getF32(ctx);
|
||||
VectorOrTensorType tensorType;
|
||||
ShapedType tensorType;
|
||||
if (shape.size() == 1 && shape[0] == -1) {
|
||||
tensorType = UnrankedTensorType::get(eleType);
|
||||
} else {
|
||||
|
@ -61,7 +61,7 @@ ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
|
|||
ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
|
||||
ArrayRef<int64_t> shape) {
|
||||
auto eleType = FloatType::getF32(ctx);
|
||||
VectorOrTensorType tensorType;
|
||||
ShapedType tensorType;
|
||||
if (shape.size() == 1 && shape[0] == -1) {
|
||||
tensorType = UnrankedTensorType::get(eleType);
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue