forked from OSchip/llvm-project
Add support for constructing DenseIntElementsAttr with an array of APInt and
DenseFPElementsAttr with an array of APFloat. PiperOrigin-RevId: 235581794
This commit is contained in:
parent
3f644705eb
commit
f1f86eac60
|
@ -353,7 +353,7 @@ public:
|
|||
using ImplType = detail::DenseElementsAttributeStorage;
|
||||
|
||||
/// It assumes the elements in the input array have been truncated to the bits
|
||||
/// width specified by the element type (note all float type are 64 bits).
|
||||
/// width specified by the element type.
|
||||
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
|
||||
|
||||
// Constructs a dense elements attribute from an array of element values. Each
|
||||
|
@ -383,6 +383,11 @@ public:
|
|||
}
|
||||
|
||||
protected:
|
||||
// Constructs a dense elements attribute from an array of raw APInt values.
|
||||
// Each APInt value is expected to have the same bitwidth as the element type
|
||||
// of 'type'.
|
||||
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<APInt> values);
|
||||
|
||||
/// Parses the raw integer internal value for each dense element into
|
||||
/// 'values'.
|
||||
void getRawValues(SmallVectorImpl<APInt> &values) const;
|
||||
|
@ -393,9 +398,16 @@ protected:
|
|||
class DenseIntElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
using DenseElementsAttr::DenseElementsAttr;
|
||||
using DenseElementsAttr::get;
|
||||
using DenseElementsAttr::getValues;
|
||||
using DenseElementsAttr::ImplType;
|
||||
|
||||
// Constructs a dense integer elements attribute from an array of APInt
|
||||
// values. Each APInt value is expected to have the same bitwidth as the
|
||||
// element type of 'type'.
|
||||
static DenseIntElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<APInt> values);
|
||||
|
||||
/// Gets the integer value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APInt> &values) const;
|
||||
|
||||
|
@ -410,9 +422,16 @@ public:
|
|||
class DenseFPElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
using DenseElementsAttr::DenseElementsAttr;
|
||||
using DenseElementsAttr::get;
|
||||
using DenseElementsAttr::getValues;
|
||||
using DenseElementsAttr::ImplType;
|
||||
|
||||
// Constructs a dense float elements attribute from an array of APFloat
|
||||
// values. Each APFloat value is expected to have the same bitwidth as the
|
||||
// element type of 'type'.
|
||||
static DenseFPElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<APFloat> values);
|
||||
|
||||
/// Gets the float value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APFloat> &values) const;
|
||||
|
||||
|
|
|
@ -251,6 +251,27 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
|
|||
return static_cast<ImplType *>(attr)->data;
|
||||
}
|
||||
|
||||
// Constructs a dense elements attribute from an array of raw APInt values.
|
||||
// Each APInt value is expected to have the same bitwidth as the element type
|
||||
// of 'type'.
|
||||
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
||||
ArrayRef<APInt> values) {
|
||||
assert(values.size() == type.getNumElements() &&
|
||||
"expected 'values' to contain the same number of elements as 'type'");
|
||||
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
auto eltType = type.getElementType();
|
||||
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||
std::vector<char> elementData(APInt::getNumWords(bitWidth * values.size()) *
|
||||
APInt::APINT_WORD_SIZE);
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
||||
assert(values[i].getBitWidth() == bitWidth);
|
||||
writeBits(elementData.data(), i * bitWidth, values[i]);
|
||||
}
|
||||
return get(type, elementData);
|
||||
}
|
||||
|
||||
/// Parses the raw integer internal value for each dense element into
|
||||
/// 'values'.
|
||||
void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
|
||||
|
@ -317,6 +338,14 @@ APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
|
|||
|
||||
/// DenseIntElementsAttr
|
||||
|
||||
// Constructs a dense integer elements attribute from an array of APInt
|
||||
// values. Each APInt value is expected to have the same bitwidth as the
|
||||
// element type of 'type'.
|
||||
DenseIntElementsAttr DenseIntElementsAttr::get(VectorOrTensorType type,
|
||||
ArrayRef<APInt> values) {
|
||||
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
|
||||
}
|
||||
|
||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
|
||||
// Simply return the raw integer values.
|
||||
getRawValues(values);
|
||||
|
@ -324,6 +353,18 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
|
|||
|
||||
/// DenseFPElementsAttr
|
||||
|
||||
// Constructs a dense float elements attribute from an array of APFloat
|
||||
// values. Each APFloat value is expected to have the same bitwidth as the
|
||||
// element type of 'type'.
|
||||
DenseFPElementsAttr DenseFPElementsAttr::get(VectorOrTensorType type,
|
||||
ArrayRef<APFloat> values) {
|
||||
// Convert the APFloat values to APInt and create a dense elements attribute.
|
||||
std::vector<APInt> intValues(values.size());
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i)
|
||||
intValues[i] = values[i].bitcastToAPInt();
|
||||
return DenseElementsAttr::get(type, intValues).cast<DenseFPElementsAttr>();
|
||||
}
|
||||
|
||||
void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
|
||||
// Get the raw APInt element values.
|
||||
SmallVector<APInt, 8> intValues;
|
||||
|
|
|
@ -1063,7 +1063,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
ArrayRef<char> data) {
|
||||
auto bitsRequired = type.getSizeInBits();
|
||||
(void)bitsRequired;
|
||||
assert((bitsRequired <= data.size() * 8L) &&
|
||||
assert((bitsRequired <= data.size() * APInt::APINT_WORD_SIZE) &&
|
||||
"Input data bit size should be larger than that type requires");
|
||||
|
||||
auto &impl = type.getContext()->getImpl();
|
||||
|
@ -1113,11 +1113,10 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||
|
||||
// Compress the attribute values into a character buffer.
|
||||
SmallVector<char, 8> data(APInt::getNumWords(bitWidth * values.size()) * 8L);
|
||||
SmallVector<char, 8> data(APInt::getNumWords(bitWidth * values.size()) *
|
||||
APInt::APINT_WORD_SIZE);
|
||||
APInt intVal;
|
||||
for (unsigned i = 0, e = values.size(); i < e; ++i) {
|
||||
unsigned bitPos = i * bitWidth;
|
||||
|
||||
APInt intVal;
|
||||
switch (eltType.getKind()) {
|
||||
case StandardTypes::BF16:
|
||||
case StandardTypes::F16:
|
||||
|
@ -1137,7 +1136,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
}
|
||||
assert(intVal.getBitWidth() == bitWidth &&
|
||||
"expected value to have same bitwidth as element type");
|
||||
writeBits(data.data(), bitPos, intVal);
|
||||
writeBits(data.data(), i * bitWidth, intVal);
|
||||
}
|
||||
return get(type, data);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue