Make it clear that ElementsAttr is only for static shaped vectors or tensors.

This is in preparation for making MemRef a subclass of ShapedType, but also UnrankedTensor should already be excluded.

--

PiperOrigin-RevId: 250580197
This commit is contained in:
Geoffrey Martin-Noble 2019-05-29 15:34:50 -07:00 committed by Mehdi Amini
parent 5a91b9896c
commit 66e84bf88c
2 changed files with 30 additions and 12 deletions

View File

@ -390,11 +390,14 @@ public:
}
};
/// A base attribute that represents a reference to a vector or tensor constant.
/// A base attribute that represents a reference to a static shaped tensor or
/// vector constant.
class ElementsAttr : public Attribute {
public:
using Attribute::Attribute;
/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
/// with static shape.
ShapedType getType() const;
/// Return the value at the given index. If index does not refer to a valid
@ -431,6 +434,7 @@ public:
using Base::Base;
using ValueType = Attribute;
/// 'type' must be a vector or tensor with static shape.
static SplatElementsAttr get(ShapedType type, Attribute elt);
Attribute getValue() const;
@ -462,11 +466,13 @@ public:
using ImplType = detail::DenseElementsAttributeStorage;
/// It assumes the elements in the input array have been truncated to the bits
/// width specified by the element type.
/// width specified by the element type. 'type' must be a vector or tensor
/// with static shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<char> data);
// Constructs a dense elements attribute from an array of element values. Each
// element attribute value is expected to be an element of 'type'.
/// Constructs a dense elements attribute from an array of element values.
/// Each element attribute value is expected to be an element of 'type'.
/// 'type' must be a vector or tensor with static shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
/// Returns the number of elements held by this attribute.
@ -558,9 +564,9 @@ protected:
return RawElementIterator(*this, size());
}
// Constructs a dense elements attribute from an array of raw APInt values.
// Each APInt value is expected to have the same bitwidth as the element type
// of 'type'.
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'. 'type' must be a vector or tensor with static shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
};
@ -580,12 +586,13 @@ public:
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'.
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
/// Constructs a dense integer elements attribute from an array of integer
/// values. Each value is expected to be within the bitwidth of the element
/// type of 'type'.
/// type of 'type'. 'type' must be a vector or tensor with static shape.
static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
/// Generates a new DenseElementsAttr by mapping each value attribute, and
@ -629,9 +636,10 @@ public:
using DenseElementsAttr::get;
using DenseElementsAttr::getValues;
// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
// element type of 'type'.
/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseFPElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
/// Gets the float value of each of the dense elements.
@ -712,6 +720,7 @@ class SparseElementsAttr
public:
using Base::Base;
/// 'type' must be a vector or tensor with static shape.
static SparseElementsAttr get(ShapedType type, DenseIntElementsAttr indices,
DenseElementsAttr values);

View File

@ -322,6 +322,9 @@ ElementsAttr ElementsAttr::mapValues(
SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) {
assert(elt.getType() == type.getElementType() &&
"value should be of the given element type");
assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::SplatElements, type,
elt);
}
@ -424,6 +427,9 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
assert((static_cast<uint64_t>(type.getSizeInBits()) <=
data.size() * APInt::APINT_WORD_SIZE) &&
"Input data bit size should be larger than that type requires");
assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
switch (type.getElementType().getKind()) {
case StandardTypes::BF16:
case StandardTypes::F16:
@ -797,6 +803,9 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
DenseElementsAttr values) {
assert(indices.getType().getElementType().isInteger(64) &&
"expected sparse indices to be 64-bit integer values");
assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
indices, values);
}