Add iterator support to DenseIntElementsAttr and DenseFPElementsAttr. This avoids the need to load all of the values from a DenseElementsAttr inorder to process them.

--

PiperOrigin-RevId: 242212741
This commit is contained in:
River Riddle 2019-04-05 16:11:24 -07:00 committed by Mehdi Amini
parent fe1211edf2
commit 67653d9881
3 changed files with 182 additions and 66 deletions

View File

@ -387,6 +387,9 @@ public:
static DenseElementsAttr get(VectorOrTensorType type,
ArrayRef<Attribute> values);
/// Returns the number of elements held by this attribute.
size_t size() const;
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
@ -409,20 +412,69 @@ public:
}
protected:
/// A utility iterator that allows walking over the internal raw APInt values.
class RawElementIterator
: public llvm::iterator_facade_base<RawElementIterator,
std::bidirectional_iterator_tag,
APInt, std::ptrdiff_t, APInt, APInt> {
public:
/// Iterator movement.
RawElementIterator &operator++() {
++index;
return *this;
}
RawElementIterator &operator--() {
--index;
return *this;
}
/// Accesses the raw APInt value at this iterator position.
APInt operator*() const;
/// Iterator equality.
bool operator==(const RawElementIterator &rhs) const {
return rawData == rhs.rawData && index == rhs.index;
}
bool operator!=(const RawElementIterator &rhs) const {
return !(*this == rhs);
}
private:
friend DenseElementsAttr;
/// Constructs a new iterator.
RawElementIterator(DenseElementsAttr attr, size_t index);
/// The base address of the raw data buffer.
const char *rawData;
/// The current element index.
size_t index;
/// The bitwidth of the element type.
size_t bitWidth;
};
/// Raw element iterators for this attribute.
RawElementIterator raw_begin() const { return RawElementIterator(*this, 0); }
RawElementIterator raw_end() const {
return RawElementIterator(*this, size());
}
// Constructs a dense elements attribute from an array of raw APInt values.
// Each APInt value is expected to have the same bitwidth as the element type
// of 'type'.
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<APInt> values);
/// Parses the raw integer internal value for each dense element into
/// 'values'.
void getRawValues(SmallVectorImpl<APInt> &values) const;
};
/// An attribute that represents a reference to a dense integer vector or tensor
/// object.
class DenseIntElementsAttr : public DenseElementsAttr {
public:
/// DenseIntElementsAttr iterates on APInt, so we can use the raw element
/// iterator directly.
using iterator = DenseElementsAttr::RawElementIterator;
using DenseElementsAttr::DenseElementsAttr;
using DenseElementsAttr::get;
using DenseElementsAttr::getValues;
@ -443,6 +495,10 @@ public:
/// Gets the integer value of each of the dense elements.
void getValues(SmallVectorImpl<APInt> &values) const;
/// Iterator access to the integer element values.
iterator begin() const { return raw_begin(); }
iterator end() const { return raw_end(); }
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::DenseIntElements; }
};
@ -451,6 +507,18 @@ public:
/// object. Each element is stored as a double.
class DenseFPElementsAttr : public DenseElementsAttr {
public:
/// DenseFPElementsAttr iterates on APFloat, so we need to wrap the raw
/// element iterator.
class ElementIterator final
: public llvm::mapped_iterator<RawElementIterator,
std::function<APFloat(const APInt &)>> {
friend DenseFPElementsAttr;
/// Initializes the float element iterator to the specified iterator.
ElementIterator(const llvm::fltSemantics &smt, RawElementIterator it);
};
using iterator = ElementIterator;
using DenseElementsAttr::DenseElementsAttr;
using DenseElementsAttr::get;
using DenseElementsAttr::getValues;
@ -465,6 +533,10 @@ public:
/// Gets the float value of each of the dense elements.
void getValues(SmallVectorImpl<APFloat> &values) const;
/// Iterator access to the float element values.
iterator begin() const;
iterator end() const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::DenseFPElements; }
};

View File

@ -65,7 +65,9 @@ Attribute Attribute::remapFunctionAttrs(
return ArrayAttr::get(remappedElts, context);
}
/// NumericAttr
//===----------------------------------------------------------------------===//
// NumericAttr
//===----------------------------------------------------------------------===//
Type NumericAttr::getType() const {
if (auto boolAttr = dyn_cast<BoolAttr>())
@ -85,13 +87,17 @@ bool NumericAttr::kindof(Kind kind) {
FloatAttr::kindof(kind) || ElementsAttr::kindof(kind);
}
/// BoolAttr
//===----------------------------------------------------------------------===//
// BoolAttr
//===----------------------------------------------------------------------===//
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
Type BoolAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
/// IntegerAttr
//===----------------------------------------------------------------------===//
// IntegerAttr
//===----------------------------------------------------------------------===//
APInt IntegerAttr::getValue() const {
return static_cast<ImplType *>(attr)->getValue();
@ -103,7 +109,9 @@ Type IntegerAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
}
/// FloatAttr
//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//
APFloat FloatAttr::getValue() const {
return static_cast<ImplType *>(attr)->getValue();
@ -122,35 +130,47 @@ double FloatAttr::getValueAsDouble() const {
return value.convertToDouble();
}
/// StringAttr
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
StringRef StringAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
/// ArrayAttr
//===----------------------------------------------------------------------===//
// ArrayAttr
//===----------------------------------------------------------------------===//
ArrayRef<Attribute> ArrayAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
/// AffineMapAttr
//===----------------------------------------------------------------------===//
// AffineMapAttr
//===----------------------------------------------------------------------===//
AffineMap AffineMapAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
/// IntegerSetAttr
//===----------------------------------------------------------------------===//
// IntegerSetAttr
//===----------------------------------------------------------------------===//
IntegerSet IntegerSetAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
/// TypeAttr
//===----------------------------------------------------------------------===//
// TypeAttr
//===----------------------------------------------------------------------===//
Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
/// FunctionAttr
//===----------------------------------------------------------------------===//
// FunctionAttr
//===----------------------------------------------------------------------===//
Function *FunctionAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
@ -158,7 +178,9 @@ Function *FunctionAttr::getValue() const {
FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
/// ElementsAttr
//===----------------------------------------------------------------------===//
// ElementsAttr
//===----------------------------------------------------------------------===//
VectorOrTensorType ElementsAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
@ -182,13 +204,41 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
}
}
/// SplatElementsAttr
//===----------------------------------------------------------------------===//
// SplatElementsAttr
//===----------------------------------------------------------------------===//
Attribute SplatElementsAttr::getValue() const {
return static_cast<ImplType *>(attr)->elt;
}
/// DenseElementsAttr
//===----------------------------------------------------------------------===//
// RawElementIterator
//===----------------------------------------------------------------------===//
static size_t getDenseElementBitwidth(Type eltType) {
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
}
/// Constructs a new iterator.
DenseElementsAttr::RawElementIterator::RawElementIterator(
DenseElementsAttr attr, size_t index)
: rawData(attr.getRawData().data()), index(index),
bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
/// Accesses the raw APInt value at this iterator position.
APInt DenseElementsAttr::RawElementIterator::operator*() const {
return readBits(rawData, index * bitWidth, bitWidth);
}
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//
/// Returns the number of elements held by this attribute.
size_t DenseElementsAttr::size() const { return getType().getNumElements(); }
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
@ -215,12 +265,8 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
}
// Return the element stored at the 1D index.
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
auto elementType = getType().getElementType();
size_t bitWidth =
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
size_t bitWidth = getDenseElementBitwidth(elementType);
APInt rawValueData =
readBits(getRawData().data(), valueIndex * bitWidth, bitWidth);
@ -277,10 +323,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
assert(values.size() == type.getNumElements() &&
"expected 'values' to contain the same number of elements as 'type'");
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
auto eltType = type.getElementType();
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
size_t bitWidth = getDenseElementBitwidth(type.getElementType());
std::vector<char> elementData(APInt::getNumWords(bitWidth * values.size()) *
APInt::APINT_WORD_SIZE);
for (unsigned i = 0, e = values.size(); i != e; ++i) {
@ -290,22 +333,6 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
return get(type, elementData);
}
/// Parses the raw integer internal value for each dense element into
/// 'values'.
void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
auto elementType = getType().getElementType();
auto elementNum = getType().getNumElements();
values.reserve(elementNum);
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
size_t bitWidth =
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
const auto *rawData = getRawData().data();
for (size_t i = 0, e = elementNum; i != e; ++i)
values.push_back(readBits(rawData, i * bitWidth, bitWidth));
}
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
/// expected to be a 64-bit aligned storage address.
void DenseElementsAttr::writeBits(char *rawData, size_t bitPos, APInt value) {
@ -354,7 +381,9 @@ APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
return result;
}
/// DenseIntElementsAttr
//===----------------------------------------------------------------------===//
// DenseIntElementsAttr
//===----------------------------------------------------------------------===//
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
@ -381,11 +410,19 @@ DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
}
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
// Simply return the raw integer values.
getRawValues(values);
values.reserve(size());
values.assign(raw_begin(), raw_end());
}
/// DenseFPElementsAttr
//===----------------------------------------------------------------------===//
// DenseFPElementsAttr
//===----------------------------------------------------------------------===//
DenseFPElementsAttr::ElementIterator::ElementIterator(
const llvm::fltSemantics &smt, RawElementIterator it)
: llvm::mapped_iterator<RawElementIterator,
std::function<APFloat(const APInt &)>>(
it, [&](const APInt &val) { return APFloat(smt, val); }) {}
// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
@ -400,18 +437,25 @@ DenseFPElementsAttr DenseFPElementsAttr::get(VectorOrTensorType type,
}
void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
// Get the raw APInt element values.
SmallVector<APInt, 8> intValues;
getRawValues(intValues);
// Convert each of the APInt values to an APFloat.
auto elementType = getType().getElementType().dyn_cast<FloatType>();
const auto &elementSemantics = elementType.getFloatSemantics();
for (auto &intValue : intValues)
values.push_back(APFloat(elementSemantics, intValue));
values.reserve(size());
values.assign(begin(), end());
}
/// OpaqueElementsAttr
/// Iterator access to the float element values.
DenseFPElementsAttr::iterator DenseFPElementsAttr::begin() const {
auto elementType = getType().getElementType().cast<FloatType>();
const auto &elementSemantics = elementType.getFloatSemantics();
return {elementSemantics, raw_begin()};
}
DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const {
auto elementType = getType().getElementType().cast<FloatType>();
const auto &elementSemantics = elementType.getFloatSemantics();
return {elementSemantics, raw_end()};
}
//===----------------------------------------------------------------------===//
// OpaqueElementsAttr
//===----------------------------------------------------------------------===//
StringRef OpaqueElementsAttr::getValue() const {
return static_cast<ImplType *>(attr)->bytes;
@ -435,7 +479,9 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
return true;
}
/// SparseElementsAttr
//===----------------------------------------------------------------------===//
// SparseElementsAttr
//===----------------------------------------------------------------------===//
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
return static_cast<ImplType *>(attr)->indices;
@ -482,7 +528,9 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
return getValues().getValue(it->second);
}
/// NamedAttributeList
//===----------------------------------------------------------------------===//
// NamedAttributeList
//===----------------------------------------------------------------------===//
NamedAttributeList::NamedAttributeList(MLIRContext *context,
ArrayRef<NamedAttribute> attributes) {

View File

@ -47,15 +47,11 @@ static DenseElementsAttr
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
QuantizedType quantizedElementType,
const UniformQuantizedValueConverter &converter) {
// Read real expressed values.
SmallVector<APFloat, 8> realValues;
realValues.reserve(realFPElementsAttr.getType().getNumElements());
realFPElementsAttr.getValues(realValues);
// Convert to corresponding quantized value attributes.
SmallVector<APInt, 8> quantValues(realValues.size());
for (size_t i = 0, e = realValues.size(); i < e; ++i) {
quantValues[i] = converter.quantizeFloatToInt(realValues[i]);
SmallVector<APInt, 8> quantValues;
quantValues.reserve(realFPElementsAttr.size());
for (APFloat realVal : realFPElementsAttr) {
quantValues.push_back(converter.quantizeFloatToInt(realVal));
}
// Cast from an expressed-type-based type to storage-type-based type,