Add support for constructing DenseIntElementsAttr with an array of APInt and

DenseFPElementsAttr with an array of APFloat.

PiperOrigin-RevId: 235581794
This commit is contained in:
River Riddle 2019-02-25 12:32:43 -08:00 committed by jpienaar
parent 3f644705eb
commit f1f86eac60
3 changed files with 66 additions and 7 deletions

View File

@ -353,7 +353,7 @@ public:
using ImplType = detail::DenseElementsAttributeStorage;
/// It assumes the elements in the input array have been truncated to the bits
/// width specified by the element type (note all float type are 64 bits).
/// width specified by the element type.
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
// Constructs a dense elements attribute from an array of element values. Each
@ -383,6 +383,11 @@ public:
}
protected:
// Constructs a dense elements attribute from an array of raw APInt values.
// Each APInt value is expected to have the same bitwidth as the element type
// of 'type'.
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<APInt> values);
/// Parses the raw integer internal value for each dense element into
/// 'values'.
void getRawValues(SmallVectorImpl<APInt> &values) const;
@ -393,9 +398,16 @@ protected:
class DenseIntElementsAttr : public DenseElementsAttr {
public:
using DenseElementsAttr::DenseElementsAttr;
using DenseElementsAttr::get;
using DenseElementsAttr::getValues;
using DenseElementsAttr::ImplType;
// Constructs a dense integer elements attribute from an array of APInt
// values. Each APInt value is expected to have the same bitwidth as the
// element type of 'type'.
static DenseIntElementsAttr get(VectorOrTensorType type,
ArrayRef<APInt> values);
/// Gets the integer value of each of the dense elements.
void getValues(SmallVectorImpl<APInt> &values) const;
@ -410,9 +422,16 @@ public:
class DenseFPElementsAttr : public DenseElementsAttr {
public:
using DenseElementsAttr::DenseElementsAttr;
using DenseElementsAttr::get;
using DenseElementsAttr::getValues;
using DenseElementsAttr::ImplType;
// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
// element type of 'type'.
static DenseFPElementsAttr get(VectorOrTensorType type,
ArrayRef<APFloat> values);
/// Gets the float value of each of the dense elements.
void getValues(SmallVectorImpl<APFloat> &values) const;

View File

@ -251,6 +251,27 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<ImplType *>(attr)->data;
}
// Constructs a dense elements attribute from an array of raw APInt values.
// Each APInt value is expected to have the same bitwidth as the element type
// of 'type'.
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
ArrayRef<APInt> values) {
assert(values.size() == type.getNumElements() &&
"expected 'values' to contain the same number of elements as 'type'");
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
auto eltType = type.getElementType();
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
std::vector<char> elementData(APInt::getNumWords(bitWidth * values.size()) *
APInt::APINT_WORD_SIZE);
for (unsigned i = 0, e = values.size(); i != e; ++i) {
assert(values[i].getBitWidth() == bitWidth);
writeBits(elementData.data(), i * bitWidth, values[i]);
}
return get(type, elementData);
}
/// Parses the raw integer internal value for each dense element into
/// 'values'.
void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
@ -317,6 +338,14 @@ APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
/// DenseIntElementsAttr
// Constructs a dense integer elements attribute from an array of APInt
// values. Each APInt value is expected to have the same bitwidth as the
// element type of 'type'.
DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
ArrayRef<APInt> values) {
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
}
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
// Simply return the raw integer values.
getRawValues(values);
@ -324,6 +353,18 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
/// DenseFPElementsAttr
// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
// element type of 'type'.
DenseFPElementsAttr DenseFPElementsAttr::get(VectorOrTensorType type,
ArrayRef<APFloat> values) {
// Convert the APFloat values to APInt and create a dense elements attribute.
std::vector<APInt> intValues(values.size());
for (unsigned i = 0, e = values.size(); i != e; ++i)
intValues[i] = values[i].bitcastToAPInt();
return DenseElementsAttr::get(type, intValues).cast<DenseFPElementsAttr>();
}
void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
// Get the raw APInt element values.
SmallVector<APInt, 8> intValues;

View File

@ -1063,7 +1063,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
ArrayRef<char> data) {
auto bitsRequired = type.getSizeInBits();
(void)bitsRequired;
assert((bitsRequired <= data.size() * 8L) &&
assert((bitsRequired <= data.size() * APInt::APINT_WORD_SIZE) &&
"Input data bit size should be larger than that type requires");
auto &impl = type.getContext()->getImpl();
@ -1113,11 +1113,10 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
// Compress the attribute values into a character buffer.
SmallVector<char, 8> data(APInt::getNumWords(bitWidth * values.size()) * 8L);
SmallVector<char, 8> data(APInt::getNumWords(bitWidth * values.size()) *
APInt::APINT_WORD_SIZE);
APInt intVal;
for (unsigned i = 0, e = values.size(); i < e; ++i) {
unsigned bitPos = i * bitWidth;
APInt intVal;
switch (eltType.getKind()) {
case StandardTypes::BF16:
case StandardTypes::F16:
@ -1137,7 +1136,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
}
assert(intVal.getBitWidth() == bitWidth &&
"expected value to have same bitwidth as element type");
writeBits(data.data(), bitPos, intVal);
writeBits(data.data(), i * bitWidth, intVal);
}
return get(type, data);
}