Add a new AttributeElementIterator to DenseElementsAttr.

This allows for iterating over the internal elements via an iterator_range of Attribute, and also allows for removing the final SmallVectorImpl based 'getValues' method.

PiperOrigin-RevId: 255309555
This commit is contained in:
River Riddle 2019-06-26 18:49:50 -07:00 committed by A. Unique TensorFlower
parent 260d3e39ad
commit 6ebd6df69f
3 changed files with 85 additions and 35 deletions

View File

@ -547,6 +547,22 @@ public:
// Iterators
//===--------------------------------------------------------------------===//
/// A utility iterator that allows walking over the internal Attribute values
/// of a DenseElementsAttr.
class AttributeElementIterator
: public indexed_accessor_iterator<AttributeElementIterator, const void *,
Attribute, Attribute, Attribute> {
public:
/// Accesses the Attribute value at this iterator position.
Attribute operator*() const;
private:
friend DenseElementsAttr;
/// Constructs a new iterator.
AttributeElementIterator(DenseElementsAttr attr, size_t index);
};
/// A utility iterator that allows walking over the internal raw APInt values.
class IntElementIterator
: public indexed_accessor_iterator<IntElementIterator, const char *,
@ -597,9 +613,6 @@ public:
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
/// Return the held element values as Attributes in 'values'.
void getValues(SmallVectorImpl<Attribute> &values) const;
/// Return the held element values as an array of integer or floating-point
/// values.
template <typename T, typename = typename std::enable_if<
@ -613,6 +626,16 @@ public:
rawData.size() / sizeof(T));
}
/// Return the held element values as a range of Attributes.
llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
template <typename T, typename = typename std::enable_if<
std::is_same<T, Attribute>::value>::type>
llvm::iterator_range<AttributeElementIterator> getValues() const {
return getAttributeValues();
}
AttributeElementIterator attr_value_begin() const;
AttributeElementIterator attr_value_end() const;
/// Return the held element values as a range of APInts. The element type of
/// this attribute must be of integer type.
llvm::iterator_range<IntElementIterator> getIntValues() const;
@ -621,6 +644,8 @@ public:
llvm::iterator_range<IntElementIterator> getValues() const {
return getIntValues();
}
IntElementIterator int_value_begin() const;
IntElementIterator int_value_end() const;
/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
@ -630,6 +655,8 @@ public:
llvm::iterator_range<FloatElementIterator> getValues() const {
return getFloatValues();
}
FloatElementIterator float_value_begin() const;
FloatElementIterator float_value_end() const;
//===--------------------------------------------------------------------===//
// Mutation Utilities
@ -704,8 +731,8 @@ public:
llvm::function_ref<APInt(const APFloat &)> mapping) const;
/// Iterator access to the float element values.
iterator begin() const { return getFloatValues().begin(); }
iterator end() const { return getFloatValues().end(); }
iterator begin() const { return float_value_begin(); }
iterator end() const { return float_value_end(); }
/// Method for supporting type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr);

View File

@ -460,6 +460,31 @@ static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
(type.getNumElements() == static_cast<int64_t>(values.size()));
}
//===----------------------------------------------------------------------===//
// DenseElementAttr Iterators
//===----------------------------------------------------------------------===//
/// Constructs a new iterator.
DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
DenseElementsAttr attr, size_t index)
: indexed_accessor_iterator<AttributeElementIterator, const void *,
Attribute, Attribute, Attribute>(
attr.getAsOpaquePointer(), index) {}
/// Accesses the Attribute value at this iterator position.
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = getFromOpaquePointer(object).cast<DenseElementsAttr>();
Type eltTy = owner.getType().getElementType();
if (eltTy.isa<IntegerType>())
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
IntElementIterator intIt(owner, index);
FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
return FloatAttr::get(eltTy, *floatIt);
}
llvm_unreachable("unexpected element type");
}
/// Constructs a new iterator.
DenseElementsAttr::IntElementIterator::IntElementIterator(
DenseElementsAttr attr, size_t index)
@ -636,16 +661,7 @@ bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
/// If this attribute corresponds to a splat, then get the splat value.
/// Otherwise, return null.
Attribute DenseElementsAttr::getSplatValue() const {
if (!isSplat())
return Attribute();
auto elementType = getType().getElementType();
if (elementType.isa<IntegerType>())
return IntegerAttr::get(elementType, *raw_int_begin());
if (auto fType = elementType.dyn_cast<FloatType>())
return FloatAttr::get(elementType,
APFloat(fType.getFloatSemantics(), *raw_int_begin()));
llvm_unreachable("unexpected element type");
return isSplat() ? *attr_value_begin() : Attribute();
}
/// Return the value at the given index. If index does not refer to a valid
@ -692,23 +708,16 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
llvm_unreachable("unexpected element type");
}
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
values.reserve(rawSize());
auto elementType = getType().getElementType();
if (elementType.isa<IntegerType>()) {
// Convert each value to an IntegerAttr.
for (auto intVal : getIntValues())
values.push_back(IntegerAttr::get(elementType, intVal));
return;
}
if (elementType.isa<FloatType>()) {
// Convert each value to a FloatAttr.
for (auto floatVal : getFloatValues())
values.push_back(FloatAttr::get(elementType, floatVal));
return;
}
llvm_unreachable("unexpected element type");
/// Return the held element values as a range of Attributes.
auto DenseElementsAttr::getAttributeValues() const
-> llvm::iterator_range<AttributeElementIterator> {
return {attr_value_begin(), attr_value_end()};
}
auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
return AttributeElementIterator(*this, 0);
}
auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
return AttributeElementIterator(*this, rawSize());
}
/// Return the held element values as a range of APInts. The element type of
@ -719,6 +728,16 @@ auto DenseElementsAttr::getIntValues() const
"expected integer type");
return {raw_int_begin(), raw_int_end()};
}
auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
assert(getType().getElementType().isa<IntegerType>() &&
"expected integer type");
return raw_int_begin();
}
auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
assert(getType().getElementType().isa<IntegerType>() &&
"expected integer type");
return raw_int_end();
}
/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
@ -730,6 +749,12 @@ auto DenseElementsAttr::getFloatValues() const
return {FloatElementIterator(elementSemantics, raw_int_begin()),
FloatElementIterator(elementSemantics, raw_int_end())};
}
auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
return getFloatValues().begin();
}
auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
return getFloatValues().end();
}
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the

View File

@ -96,9 +96,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
SmallVector<llvm::Constant *, 8> constants;
uint64_t numElements = vectorType->getNumElements();
constants.reserve(numElements);
SmallVector<Attribute, 8> nested;
denseAttr.getValues(nested);
for (auto n : nested) {
for (auto n : denseAttr.getAttributeValues()) {
constants.push_back(
getLLVMConstant(vectorType->getElementType(), n, loc));
if (!constants.back())