forked from OSchip/llvm-project
Add iterator support to DenseIntElementsAttr and DenseFPElementsAttr. This avoids the need to load all of the values from a DenseElementsAttr inorder to process them.
-- PiperOrigin-RevId: 242212741
This commit is contained in:
parent
fe1211edf2
commit
67653d9881
|
@ -387,6 +387,9 @@ public:
|
|||
static DenseElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<Attribute> values);
|
||||
|
||||
/// Returns the number of elements held by this attribute.
|
||||
size_t size() const;
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
/// element, then a null attribute is returned.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||
|
@ -409,20 +412,69 @@ public:
|
|||
}
|
||||
|
||||
protected:
|
||||
/// A utility iterator that allows walking over the internal raw APInt values.
|
||||
class RawElementIterator
|
||||
: public llvm::iterator_facade_base<RawElementIterator,
|
||||
std::bidirectional_iterator_tag,
|
||||
APInt, std::ptrdiff_t, APInt, APInt> {
|
||||
public:
|
||||
/// Iterator movement.
|
||||
RawElementIterator &operator++() {
|
||||
++index;
|
||||
return *this;
|
||||
}
|
||||
RawElementIterator &operator--() {
|
||||
--index;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Accesses the raw APInt value at this iterator position.
|
||||
APInt operator*() const;
|
||||
|
||||
/// Iterator equality.
|
||||
bool operator==(const RawElementIterator &rhs) const {
|
||||
return rawData == rhs.rawData && index == rhs.index;
|
||||
}
|
||||
bool operator!=(const RawElementIterator &rhs) const {
|
||||
return !(*this == rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
friend DenseElementsAttr;
|
||||
|
||||
/// Constructs a new iterator.
|
||||
RawElementIterator(DenseElementsAttr attr, size_t index);
|
||||
|
||||
/// The base address of the raw data buffer.
|
||||
const char *rawData;
|
||||
|
||||
/// The current element index.
|
||||
size_t index;
|
||||
|
||||
/// The bitwidth of the element type.
|
||||
size_t bitWidth;
|
||||
};
|
||||
|
||||
/// Raw element iterators for this attribute.
|
||||
RawElementIterator raw_begin() const { return RawElementIterator(*this, 0); }
|
||||
RawElementIterator raw_end() const {
|
||||
return RawElementIterator(*this, size());
|
||||
}
|
||||
|
||||
// Constructs a dense elements attribute from an array of raw APInt values.
|
||||
// Each APInt value is expected to have the same bitwidth as the element type
|
||||
// of 'type'.
|
||||
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<APInt> values);
|
||||
|
||||
/// Parses the raw integer internal value for each dense element into
|
||||
/// 'values'.
|
||||
void getRawValues(SmallVectorImpl<APInt> &values) const;
|
||||
};
|
||||
|
||||
/// An attribute that represents a reference to a dense integer vector or tensor
|
||||
/// object.
|
||||
class DenseIntElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
/// DenseIntElementsAttr iterates on APInt, so we can use the raw element
|
||||
/// iterator directly.
|
||||
using iterator = DenseElementsAttr::RawElementIterator;
|
||||
|
||||
using DenseElementsAttr::DenseElementsAttr;
|
||||
using DenseElementsAttr::get;
|
||||
using DenseElementsAttr::getValues;
|
||||
|
@ -443,6 +495,10 @@ public:
|
|||
/// Gets the integer value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APInt> &values) const;
|
||||
|
||||
/// Iterator access to the integer element values.
|
||||
iterator begin() const { return raw_begin(); }
|
||||
iterator end() const { return raw_end(); }
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::DenseIntElements; }
|
||||
};
|
||||
|
@ -451,6 +507,18 @@ public:
|
|||
/// object. Each element is stored as a double.
|
||||
class DenseFPElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
/// DenseFPElementsAttr iterates on APFloat, so we need to wrap the raw
|
||||
/// element iterator.
|
||||
class ElementIterator final
|
||||
: public llvm::mapped_iterator<RawElementIterator,
|
||||
std::function<APFloat(const APInt &)>> {
|
||||
friend DenseFPElementsAttr;
|
||||
|
||||
/// Initializes the float element iterator to the specified iterator.
|
||||
ElementIterator(const llvm::fltSemantics &smt, RawElementIterator it);
|
||||
};
|
||||
using iterator = ElementIterator;
|
||||
|
||||
using DenseElementsAttr::DenseElementsAttr;
|
||||
using DenseElementsAttr::get;
|
||||
using DenseElementsAttr::getValues;
|
||||
|
@ -465,6 +533,10 @@ public:
|
|||
/// Gets the float value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APFloat> &values) const;
|
||||
|
||||
/// Iterator access to the float element values.
|
||||
iterator begin() const;
|
||||
iterator end() const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::DenseFPElements; }
|
||||
};
|
||||
|
|
|
@ -65,7 +65,9 @@ Attribute Attribute::remapFunctionAttrs(
|
|||
return ArrayAttr::get(remappedElts, context);
|
||||
}
|
||||
|
||||
/// NumericAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NumericAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Type NumericAttr::getType() const {
|
||||
if (auto boolAttr = dyn_cast<BoolAttr>())
|
||||
|
@ -85,13 +87,17 @@ bool NumericAttr::kindof(Kind kind) {
|
|||
FloatAttr::kindof(kind) || ElementsAttr::kindof(kind);
|
||||
}
|
||||
|
||||
/// BoolAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BoolAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
|
||||
|
||||
Type BoolAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
|
||||
|
||||
/// IntegerAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IntegerAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
APInt IntegerAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->getValue();
|
||||
|
@ -103,7 +109,9 @@ Type IntegerAttr::getType() const {
|
|||
return static_cast<ImplType *>(attr)->type;
|
||||
}
|
||||
|
||||
/// FloatAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FloatAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
APFloat FloatAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->getValue();
|
||||
|
@ -122,35 +130,47 @@ double FloatAttr::getValueAsDouble() const {
|
|||
return value.convertToDouble();
|
||||
}
|
||||
|
||||
/// StringAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StringAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
StringRef StringAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
/// ArrayAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArrayAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ArrayRef<Attribute> ArrayAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
/// AffineMapAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineMapAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AffineMap AffineMapAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
/// IntegerSetAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IntegerSetAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
IntegerSet IntegerSetAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
/// TypeAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
|
||||
|
||||
/// FunctionAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FunctionAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Function *FunctionAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
|
@ -158,7 +178,9 @@ Function *FunctionAttr::getValue() const {
|
|||
|
||||
FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
|
||||
|
||||
/// ElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
VectorOrTensorType ElementsAttr::getType() const {
|
||||
return static_cast<ImplType *>(attr)->type;
|
||||
|
@ -182,13 +204,41 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
}
|
||||
}
|
||||
|
||||
/// SplatElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplatElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute SplatElementsAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->elt;
|
||||
}
|
||||
|
||||
/// DenseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RawElementIterator
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static size_t getDenseElementBitwidth(Type eltType) {
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
/// Constructs a new iterator.
|
||||
DenseElementsAttr::RawElementIterator::RawElementIterator(
|
||||
DenseElementsAttr attr, size_t index)
|
||||
: rawData(attr.getRawData().data()), index(index),
|
||||
bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
|
||||
|
||||
/// Accesses the raw APInt value at this iterator position.
|
||||
APInt DenseElementsAttr::RawElementIterator::operator*() const {
|
||||
return readBits(rawData, index * bitWidth, bitWidth);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the number of elements held by this attribute.
|
||||
size_t DenseElementsAttr::size() const { return getType().getNumElements(); }
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
/// element, then a null attribute is returned.
|
||||
|
@ -215,12 +265,8 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
}
|
||||
|
||||
// Return the element stored at the 1D index.
|
||||
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
auto elementType = getType().getElementType();
|
||||
size_t bitWidth =
|
||||
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
||||
size_t bitWidth = getDenseElementBitwidth(elementType);
|
||||
APInt rawValueData =
|
||||
readBits(getRawData().data(), valueIndex * bitWidth, bitWidth);
|
||||
|
||||
|
@ -277,10 +323,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
assert(values.size() == type.getNumElements() &&
|
||||
"expected 'values' to contain the same number of elements as 'type'");
|
||||
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
auto eltType = type.getElementType();
|
||||
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||
size_t bitWidth = getDenseElementBitwidth(type.getElementType());
|
||||
std::vector<char> elementData(APInt::getNumWords(bitWidth * values.size()) *
|
||||
APInt::APINT_WORD_SIZE);
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
||||
|
@ -290,22 +333,6 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
return get(type, elementData);
|
||||
}
|
||||
|
||||
/// Parses the raw integer internal value for each dense element into
|
||||
/// 'values'.
|
||||
void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
|
||||
auto elementType = getType().getElementType();
|
||||
auto elementNum = getType().getNumElements();
|
||||
values.reserve(elementNum);
|
||||
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
size_t bitWidth =
|
||||
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
||||
const auto *rawData = getRawData().data();
|
||||
for (size_t i = 0, e = elementNum; i != e; ++i)
|
||||
values.push_back(readBits(rawData, i * bitWidth, bitWidth));
|
||||
}
|
||||
|
||||
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
|
||||
/// expected to be a 64-bit aligned storage address.
|
||||
void DenseElementsAttr::writeBits(char *rawData, size_t bitPos, APInt value) {
|
||||
|
@ -354,7 +381,9 @@ APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
|
|||
return result;
|
||||
}
|
||||
|
||||
/// DenseIntElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseIntElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Constructs a dense integer elements attribute from an array of APInt
|
||||
/// values. Each APInt value is expected to have the same bitwidth as the
|
||||
|
@ -381,11 +410,19 @@ DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
|
|||
}
|
||||
|
||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
|
||||
// Simply return the raw integer values.
|
||||
getRawValues(values);
|
||||
values.reserve(size());
|
||||
values.assign(raw_begin(), raw_end());
|
||||
}
|
||||
|
||||
/// DenseFPElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseFPElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DenseFPElementsAttr::ElementIterator::ElementIterator(
|
||||
const llvm::fltSemantics &smt, RawElementIterator it)
|
||||
: llvm::mapped_iterator<RawElementIterator,
|
||||
std::function<APFloat(const APInt &)>>(
|
||||
it, [&](const APInt &val) { return APFloat(smt, val); }) {}
|
||||
|
||||
// Constructs a dense float elements attribute from an array of APFloat
|
||||
// values. Each APFloat value is expected to have the same bitwidth as the
|
||||
|
@ -400,18 +437,25 @@ DenseFPElementsAttr DenseFPElementsAttr::get(VectorOrTensorType type,
|
|||
}
|
||||
|
||||
void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
|
||||
// Get the raw APInt element values.
|
||||
SmallVector<APInt, 8> intValues;
|
||||
getRawValues(intValues);
|
||||
|
||||
// Convert each of the APInt values to an APFloat.
|
||||
auto elementType = getType().getElementType().dyn_cast<FloatType>();
|
||||
const auto &elementSemantics = elementType.getFloatSemantics();
|
||||
for (auto &intValue : intValues)
|
||||
values.push_back(APFloat(elementSemantics, intValue));
|
||||
values.reserve(size());
|
||||
values.assign(begin(), end());
|
||||
}
|
||||
|
||||
/// OpaqueElementsAttr
|
||||
/// Iterator access to the float element values.
|
||||
DenseFPElementsAttr::iterator DenseFPElementsAttr::begin() const {
|
||||
auto elementType = getType().getElementType().cast<FloatType>();
|
||||
const auto &elementSemantics = elementType.getFloatSemantics();
|
||||
return {elementSemantics, raw_begin()};
|
||||
}
|
||||
DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const {
|
||||
auto elementType = getType().getElementType().cast<FloatType>();
|
||||
const auto &elementSemantics = elementType.getFloatSemantics();
|
||||
return {elementSemantics, raw_end()};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpaqueElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
StringRef OpaqueElementsAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->bytes;
|
||||
|
@ -435,7 +479,9 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// SparseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SparseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
|
||||
return static_cast<ImplType *>(attr)->indices;
|
||||
|
@ -482,7 +528,9 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
return getValues().getValue(it->second);
|
||||
}
|
||||
|
||||
/// NamedAttributeList
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NamedAttributeList
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
NamedAttributeList::NamedAttributeList(MLIRContext *context,
|
||||
ArrayRef<NamedAttribute> attributes) {
|
||||
|
|
|
@ -47,15 +47,11 @@ static DenseElementsAttr
|
|||
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
|
||||
QuantizedType quantizedElementType,
|
||||
const UniformQuantizedValueConverter &converter) {
|
||||
// Read real expressed values.
|
||||
SmallVector<APFloat, 8> realValues;
|
||||
realValues.reserve(realFPElementsAttr.getType().getNumElements());
|
||||
realFPElementsAttr.getValues(realValues);
|
||||
|
||||
// Convert to corresponding quantized value attributes.
|
||||
SmallVector<APInt, 8> quantValues(realValues.size());
|
||||
for (size_t i = 0, e = realValues.size(); i < e; ++i) {
|
||||
quantValues[i] = converter.quantizeFloatToInt(realValues[i]);
|
||||
SmallVector<APInt, 8> quantValues;
|
||||
quantValues.reserve(realFPElementsAttr.size());
|
||||
for (APFloat realVal : realFPElementsAttr) {
|
||||
quantValues.push_back(converter.quantizeFloatToInt(realVal));
|
||||
}
|
||||
|
||||
// Cast from an expressed-type-based type to storage-type-based type,
|
||||
|
|
Loading…
Reference in New Issue