forked from OSchip/llvm-project
[MLIR] Add functionality for constructing a DenseElementAttr from an array of attributes and rerwite DenseElementsAttr::writeBits/readBits to handle non uniform bitwidths. This fixes asan failures that happen when using non uniform bitwidths.
PiperOrigin-RevId: 229815107
This commit is contained in:
parent
40f7535571
commit
0e81d7c420
|
@ -337,13 +337,10 @@ public:
|
|||
|
||||
/// It assumes the elements in the input array have been truncated to the bits
|
||||
/// width specified by the element type (note all float type are 64 bits).
|
||||
/// When the value is retrieved, the bits are read from the storage and extend
|
||||
/// to 64 bits if necessary.
|
||||
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
|
||||
|
||||
// TODO: Read the data from the attribute list and compress them
|
||||
// to a character array. Then call the above method to construct the
|
||||
// attribute.
|
||||
// Constructs a dense elements attribute from an array of element values. Each
|
||||
// element attribute value is expected to be an element of 'type'.
|
||||
static DenseElementsAttr get(VectorOrTensorType type,
|
||||
ArrayRef<Attribute> values);
|
||||
|
||||
|
@ -351,20 +348,23 @@ public:
|
|||
|
||||
ArrayRef<char> getRawData() const;
|
||||
|
||||
/// Writes the lowest `bitWidth` bits of `value` to the bit position `bitPos`
|
||||
/// in array `rawData`.
|
||||
static void writeBits(char *rawData, size_t bitPos, size_t bitWidth,
|
||||
uint64_t value);
|
||||
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
|
||||
/// expected to be a 64-bit aligned storage address.
|
||||
static void writeBits(char *rawData, size_t bitPos, APInt value);
|
||||
|
||||
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
|
||||
/// `rawData` and return them as the lowest bits of an uint64 integer.
|
||||
static uint64_t readBits(const char *rawData, size_t bitPos,
|
||||
size_t bitsWidth);
|
||||
/// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address.
|
||||
static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth);
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(Kind kind) {
|
||||
return kind == Kind::DenseIntElements || kind == Kind::DenseFPElements;
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Parses the raw integer internal value for each dense element into
|
||||
/// 'values'.
|
||||
void getRawValues(SmallVectorImpl<APInt> &values) const;
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a dense integer vector or tensor
|
||||
|
@ -372,10 +372,11 @@ public:
|
|||
class DenseIntElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
using DenseElementsAttr::DenseElementsAttr;
|
||||
using ImplType = detail::DenseIntElementsAttributeStorage;
|
||||
using DenseElementsAttr::getValues;
|
||||
using DenseElementsAttr::ImplType;
|
||||
|
||||
// TODO: returns APInts instead of IntegerAttr.
|
||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||
/// Gets the integer value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APInt> &values) const;
|
||||
|
||||
APInt getValue(ArrayRef<unsigned> indices) const;
|
||||
|
||||
|
@ -388,10 +389,11 @@ public:
|
|||
class DenseFPElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
using DenseElementsAttr::DenseElementsAttr;
|
||||
using ImplType = detail::DenseFPElementsAttributeStorage;
|
||||
using DenseElementsAttr::getValues;
|
||||
using DenseElementsAttr::ImplType;
|
||||
|
||||
// TODO: returns APFPs instead of FloatAttr.
|
||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||
/// Gets the float value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APFloat> &values) const;
|
||||
|
||||
APFloat getValue(ArrayRef<unsigned> indices) const;
|
||||
|
||||
|
|
|
@ -107,6 +107,8 @@ public:
|
|||
ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt);
|
||||
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
|
||||
ArrayRef<char> data);
|
||||
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
|
||||
ArrayRef<Attribute> values);
|
||||
ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values);
|
||||
|
|
|
@ -134,17 +134,6 @@ struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
|
|||
ArrayRef<char> data;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a dense integer vector or tensor
|
||||
/// object.
|
||||
struct DenseIntElementsAttributeStorage : public DenseElementsAttributeStorage {
|
||||
size_t bitsWidth;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a dense float vector or tensor
|
||||
/// object.
|
||||
struct DenseFPElementsAttributeStorage : public DenseElementsAttributeStorage {
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a tensor constant with opaque
|
||||
/// content.
|
||||
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
|
||||
|
|
|
@ -150,13 +150,28 @@ Attribute SplatElementsAttr::getValue() const {
|
|||
/// DenseElementsAttr
|
||||
|
||||
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto elementType = getType().getElementType();
|
||||
switch (getKind()) {
|
||||
case Attribute::Kind::DenseIntElements:
|
||||
cast<DenseIntElementsAttr>().getValues(values);
|
||||
case Attribute::Kind::DenseIntElements: {
|
||||
// Get the raw APInt values.
|
||||
SmallVector<APInt, 8> intValues;
|
||||
cast<DenseIntElementsAttr>().getValues(intValues);
|
||||
|
||||
// Convert each to an IntegerAttr.
|
||||
for (auto &intVal : intValues)
|
||||
values.push_back(IntegerAttr::get(elementType, intVal));
|
||||
return;
|
||||
case Attribute::Kind::DenseFPElements:
|
||||
cast<DenseFPElementsAttr>().getValues(values);
|
||||
}
|
||||
case Attribute::Kind::DenseFPElements: {
|
||||
// Get the raw APFloat values.
|
||||
SmallVector<APFloat, 8> floatValues;
|
||||
cast<DenseFPElementsAttr>().getValues(floatValues);
|
||||
|
||||
// Convert each to an FloatAttr.
|
||||
for (auto &floatVal : floatValues)
|
||||
values.push_back(FloatAttr::get(elementType, floatVal));
|
||||
return;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
|
@ -166,112 +181,89 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
|
|||
return static_cast<ImplType *>(attr)->data;
|
||||
}
|
||||
|
||||
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
|
||||
/// starting from `rawData`.
|
||||
void DenseElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
|
||||
uint64_t value) {
|
||||
assert(bitWidth <= 64 && "expected bitWidth to be within 64-bits");
|
||||
|
||||
// Read the destination bytes which will be written to.
|
||||
uint64_t dst = 0;
|
||||
auto dstData = reinterpret_cast<char *>(&dst);
|
||||
auto endPos = bitPos + bitWidth;
|
||||
auto start = data + bitPos / 8;
|
||||
auto end = data + endPos / 8 + (endPos % 8 != 0);
|
||||
std::copy(start, end, dstData);
|
||||
|
||||
// Clean up the invalid bits in the destination bytes.
|
||||
dst &= ~(-1UL << (bitPos % 8));
|
||||
|
||||
// Get the valid bits of the source value, shift them to right position,
|
||||
// then add them to the destination bytes.
|
||||
value <<= bitPos % 8;
|
||||
dst |= value;
|
||||
|
||||
// Write the destination bytes back.
|
||||
ArrayRef<char> range({dstData, (size_t)(end - start)});
|
||||
std::copy(range.begin(), range.end(), start);
|
||||
}
|
||||
|
||||
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
|
||||
/// and put them in the lowest bits.
|
||||
uint64_t DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
|
||||
size_t bitsWidth) {
|
||||
assert(bitsWidth <= 64 && "expected bitWidth to be within 64-bits");
|
||||
|
||||
uint64_t dst = 0;
|
||||
auto dstData = reinterpret_cast<char *>(&dst);
|
||||
auto endPos = bitPos + bitsWidth;
|
||||
auto start = rawData + bitPos / 8;
|
||||
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
|
||||
std::copy(start, end, dstData);
|
||||
|
||||
dst >>= bitPos % 8;
|
||||
dst &= ~(-1UL << bitsWidth);
|
||||
return dst;
|
||||
}
|
||||
|
||||
/// DenseIntElementsAttr
|
||||
|
||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
|
||||
/// Parses the raw integer internal value for each dense element into
|
||||
/// 'values'.
|
||||
void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
|
||||
auto elementType = getType().getElementType();
|
||||
auto elementNum = getType().getNumElements();
|
||||
values.reserve(elementNum);
|
||||
if (bitsWidth == 64) {
|
||||
ArrayRef<int64_t> vs(
|
||||
{reinterpret_cast<const int64_t *>(getRawData().data()),
|
||||
getRawData().size() / 8});
|
||||
for (auto value : vs) {
|
||||
auto attr = IntegerAttr::get(getType().getElementType(), value);
|
||||
values.push_back(attr);
|
||||
}
|
||||
} else {
|
||||
const auto *rawData = getRawData().data();
|
||||
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
|
||||
uint64_t bits = readBits(rawData, pos, bitsWidth);
|
||||
APInt value(bitsWidth, bits, /*isSigned=*/true);
|
||||
auto attr =
|
||||
IntegerAttr::get(getType().getElementType(), value.getSExtValue());
|
||||
values.push_back(attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// DenseFPElementsAttr
|
||||
|
||||
// Construct a FloatAttr wrapping a float value of `elementType` type from its
|
||||
// bit representation. The APFloat stored in the attribute will have the
|
||||
// semantics defined by the float semantics of the element type.
|
||||
static inline FloatAttr makeFloatAttrFromBits(size_t bitWidth, uint64_t bits,
|
||||
FloatType elementType) {
|
||||
auto apint = APInt(bitWidth, bits);
|
||||
auto apfloat = APFloat(elementType.getFloatSemantics(), apint);
|
||||
return FloatAttr::get(elementType, apfloat);
|
||||
}
|
||||
|
||||
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto elementNum = getType().getNumElements();
|
||||
auto elementType = getType().getElementType().dyn_cast<FloatType>();
|
||||
assert(elementType && "non-float type in FP attribute");
|
||||
// FIXME: using 64 bits for BF16 because it is currently stored with double
|
||||
// semantics.
|
||||
size_t bitWidth =
|
||||
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
|
||||
const auto *rawData = getRawData().data();
|
||||
for (size_t i = 0, e = elementNum; i != e; ++i)
|
||||
values.push_back(readBits(rawData, i * bitWidth, bitWidth));
|
||||
}
|
||||
|
||||
values.reserve(elementNum);
|
||||
if (bitWidth == 64) {
|
||||
ArrayRef<int64_t> vs(
|
||||
{reinterpret_cast<const int64_t *>(getRawData().data()),
|
||||
getRawData().size() / 8});
|
||||
for (auto bitValue : vs) {
|
||||
values.push_back(makeFloatAttrFromBits(64, bitValue, elementType));
|
||||
}
|
||||
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
|
||||
/// expected to be a 64-bit aligned storage address.
|
||||
void DenseElementsAttr::writeBits(char *rawData, size_t bitPos, APInt value) {
|
||||
size_t bitWidth = value.getBitWidth();
|
||||
|
||||
// If the bitwidth is 1 we just toggle the specific bit.
|
||||
if (bitWidth == 1) {
|
||||
auto *rawIntData = reinterpret_cast<uint64_t *>(rawData);
|
||||
if (value.isOneValue())
|
||||
APInt::tcSetBit(rawIntData, bitPos);
|
||||
else
|
||||
APInt::tcClearBit(rawIntData, bitPos);
|
||||
return;
|
||||
}
|
||||
for (unsigned i = 0; i < elementNum; ++i) {
|
||||
uint64_t bits = readBits(getRawData().data(), i * bitWidth, bitWidth);
|
||||
values.push_back(makeFloatAttrFromBits(bitWidth, bits, elementType));
|
||||
|
||||
// If the bit position and width are byte aligned, write the storage directly
|
||||
// to the data.
|
||||
if ((bitWidth % 8) == 0 && (bitPos % 8) == 0) {
|
||||
std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
|
||||
bitWidth / 8, rawData + (bitPos / 8));
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, convert the raw data into an APInt and insert the value at the
|
||||
// specified bit position.
|
||||
size_t totalWords = APInt::getNumWords((bitPos % 64) + bitWidth);
|
||||
llvm::MutableArrayRef<uint64_t> rawIntData(
|
||||
reinterpret_cast<uint64_t *>(rawData) + (bitPos / 64), totalWords);
|
||||
APInt tempStorage(totalWords * 64, rawIntData);
|
||||
tempStorage.insertBits(value, bitPos % 64);
|
||||
|
||||
// Copy the value back to the raw data.
|
||||
std::copy_n(tempStorage.getRawData(), rawIntData.size(), rawIntData.data());
|
||||
}
|
||||
|
||||
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
|
||||
/// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address.
|
||||
APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
|
||||
size_t bitWidth) {
|
||||
// Reinterpret the raw data as a uint64_t word array and extract the value
|
||||
// starting at 'bitPos'.
|
||||
APInt result(bitWidth, 0);
|
||||
const uint64_t *intData = reinterpret_cast<const uint64_t *>(rawData);
|
||||
APInt::tcExtract(const_cast<uint64_t *>(result.getRawData()),
|
||||
result.getNumWords(), intData, bitWidth, bitPos);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// DenseIntElementsAttr
|
||||
|
||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
|
||||
// Simply return the raw integer values.
|
||||
getRawValues(values);
|
||||
}
|
||||
|
||||
/// DenseFPElementsAttr
|
||||
|
||||
void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
|
||||
// Get the raw APInt element values.
|
||||
SmallVector<APInt, 8> intValues;
|
||||
getRawValues(intValues);
|
||||
|
||||
// Convert each of the APInt values to an APFloat.
|
||||
auto elementType = getType().getElementType().dyn_cast<FloatType>();
|
||||
const auto &elementSemantics = elementType.getFloatSemantics();
|
||||
for (auto &intValue : intValues)
|
||||
values.push_back(APFloat(elementSemantics, intValue));
|
||||
}
|
||||
|
||||
/// OpaqueElementsAttr
|
||||
|
|
|
@ -171,6 +171,11 @@ ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
|
|||
return DenseElementsAttr::get(type, data);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
|
||||
ArrayRef<Attribute> values) {
|
||||
return DenseElementsAttr::get(type, values);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values) {
|
||||
|
|
|
@ -1050,37 +1050,68 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
if (!existing.second)
|
||||
return *existing.first;
|
||||
|
||||
// Otherwise, allocate a new one, unique it and return it.
|
||||
auto eltType = type.getElementType();
|
||||
switch (eltType.getKind()) {
|
||||
Attribute::Kind kind;
|
||||
switch (type.getElementType().getKind()) {
|
||||
case StandardTypes::BF16:
|
||||
case StandardTypes::F16:
|
||||
case StandardTypes::F32:
|
||||
case StandardTypes::F64: {
|
||||
auto *result = impl.allocator.Allocate<DenseFPElementsAttributeStorage>();
|
||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||
new (result) DenseFPElementsAttributeStorage{
|
||||
{{{Attribute::Kind::DenseFPElements, /*isOrContainsFunction=*/false},
|
||||
type},
|
||||
{copy, data.size()}}};
|
||||
return *existing.first = result;
|
||||
}
|
||||
case StandardTypes::Integer: {
|
||||
auto width = eltType.cast<IntegerType>().getWidth();
|
||||
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
|
||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||
new (result) DenseIntElementsAttributeStorage{
|
||||
{{{Attribute::Kind::DenseIntElements, /*isOrContainsFunction=*/false},
|
||||
type},
|
||||
{copy, data.size()}},
|
||||
width};
|
||||
return *existing.first = result;
|
||||
}
|
||||
case StandardTypes::F64:
|
||||
kind = Attribute::Kind::DenseFPElements;
|
||||
break;
|
||||
case StandardTypes::Integer:
|
||||
kind = Attribute::Kind::DenseIntElements;
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
|
||||
// Otherwise, allocate a new one, unique it and return it.
|
||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||
auto *result = impl.allocator.Allocate<DenseElementsAttributeStorage>();
|
||||
new (result) DenseElementsAttributeStorage{
|
||||
{{kind, /*isOrContainsFunction=*/false}, type}, {copy, data.size()}};
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
||||
ArrayRef<Attribute> values) {
|
||||
assert(type.getElementType().isIntOrFloat() &&
|
||||
"expected int or float element type");
|
||||
|
||||
// FIXME: using 64 bits for BF16 because it is currently stored with double
|
||||
// semantics.
|
||||
auto eltType = type.getElementType();
|
||||
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||
|
||||
// Compress the attribute values into a character buffer.
|
||||
SmallVector<char, 8> data(type.getSizeInBits() * 8L);
|
||||
for (unsigned i = 0, e = values.size(); i < e; ++i) {
|
||||
unsigned bitPos = i * bitWidth;
|
||||
|
||||
APInt intVal;
|
||||
switch (eltType.getKind()) {
|
||||
case StandardTypes::BF16:
|
||||
case StandardTypes::F16:
|
||||
case StandardTypes::F32:
|
||||
case StandardTypes::F64:
|
||||
assert(eltType == values[i].cast<FloatAttr>().getType() &&
|
||||
"expected attribute value to have element type");
|
||||
intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
|
||||
break;
|
||||
case StandardTypes::Integer:
|
||||
assert(eltType == values[i].cast<IntegerAttr>().getType() &&
|
||||
"expected attribute value to have element type");
|
||||
intVal = values[i].cast<IntegerAttr>().getValue();
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
assert(intVal.getBitWidth() == bitWidth &&
|
||||
"expected value to have same bitwidth as element type");
|
||||
writeBits(data.data(), bitPos, intVal);
|
||||
}
|
||||
return get(type, data);
|
||||
}
|
||||
|
||||
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
|
||||
|
|
|
@ -670,8 +670,7 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
|
|||
namespace {
|
||||
class TensorLiteralParser {
|
||||
public:
|
||||
TensorLiteralParser(Parser &p, Type eltTy)
|
||||
: p(p), eltTy(eltTy), currBitPos(0) {}
|
||||
TensorLiteralParser(Parser &p, Type eltTy) : p(p), eltTy(eltTy) {}
|
||||
|
||||
ParseResult parse() {
|
||||
if (p.getToken().isNot(Token::l_square))
|
||||
|
@ -679,9 +678,7 @@ public:
|
|||
return parseList(shape);
|
||||
}
|
||||
|
||||
ArrayRef<char> getValues() const {
|
||||
return {reinterpret_cast<const char *>(storage.data()), storage.size() * 8};
|
||||
}
|
||||
ArrayRef<Attribute> getValues() const { return storage; }
|
||||
|
||||
ArrayRef<int> getShape() const { return shape; }
|
||||
|
||||
|
@ -698,28 +695,10 @@ private:
|
|||
/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
|
||||
ParseResult parseList(llvm::SmallVectorImpl<int> &dims);
|
||||
|
||||
void addToStorage(uint64_t value) {
|
||||
// Only tensors of integers or floats are supported.
|
||||
// FIXME: use full word to store BF16 as double because APFloat, which we
|
||||
// use to work with floats, does not have support for BF16 yet.
|
||||
size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
|
||||
|
||||
if (bitWidth == 64)
|
||||
storage.push_back(value);
|
||||
|
||||
if (currBitPos + bitWidth > storage.size() * 64)
|
||||
storage.push_back(0L);
|
||||
|
||||
auto *rawData = reinterpret_cast<char *>(storage.data());
|
||||
DenseElementsAttr::writeBits(rawData, currBitPos, bitWidth, value);
|
||||
currBitPos += bitWidth;
|
||||
}
|
||||
|
||||
Parser &p;
|
||||
Type eltTy;
|
||||
size_t currBitPos;
|
||||
SmallVector<int, 4> shape;
|
||||
std::vector<uint64_t> storage;
|
||||
std::vector<Attribute> storage;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -754,30 +733,22 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
|||
size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
|
||||
assert(apInt.getBitWidth() == bitWidth);
|
||||
(void)bitWidth;
|
||||
|
||||
addToStorage(apInt.getRawData()[0]);
|
||||
(void)apInt;
|
||||
break;
|
||||
}
|
||||
case StandardTypes::Integer: {
|
||||
if (!result.isa<IntegerAttr>())
|
||||
return p.emitError("expected tensor literal element has integer type");
|
||||
auto value = result.cast<IntegerAttr>().getValue();
|
||||
auto bitWidth = eltTy.getIntOrFloatBitWidth();
|
||||
if (value.getMinSignedBits() > bitWidth)
|
||||
if (value.getMinSignedBits() > eltTy.getIntOrFloatBitWidth())
|
||||
return p.emitError("tensor literal element has more bits than that "
|
||||
"specified in the type");
|
||||
|
||||
// FIXME: Handle larger than 64-bit types more gracefully.
|
||||
if (bitWidth > 64)
|
||||
return p.emitError("tensor literal element with more than 64-bits is "
|
||||
"not currently supported");
|
||||
|
||||
addToStorage(value.getSExtValue());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return p.emitError("expected integer or float tensor element");
|
||||
}
|
||||
storage.push_back(result);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -750,15 +750,6 @@ func @elementsattr_toolarge2() -> () {
|
|||
|
||||
// -----
|
||||
|
||||
// FIXME: Handle larger than 64-bit types more gracefully.
|
||||
func @elementsattr_larger_than_64_bits() -> () {
|
||||
^bb0:
|
||||
"fooi67"(){bar: dense<vector<1x1x1xi67>, [[[-5]]]>} : () -> () // expected-error {{tensor literal element with more than 64-bits is not currently supported}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @elementsattr_malformed_opaque() -> () {
|
||||
^bb0:
|
||||
"foo"(){bar: opaque<tensor<1xi8>, "0xQZz123">} : () -> () // expected-error {{opaque string only contains hex digits}}
|
||||
|
|
|
@ -593,6 +593,8 @@ func @densetensorattr() -> () {
|
|||
"fooi64"(){bar: dense<tensor<2x1x4xi64>, [[[1, -2, 1, 2]], [[0, 3, -1, 2]]]>} : () -> ()
|
||||
// CHECK: "fooi64"() {bar: dense<tensor<1x1x1xi64>, {{\[\[\[}}-5]]]>} : () -> ()
|
||||
"fooi64"(){bar: dense<tensor<1x1x1xi64>, [[[-5]]]>} : () -> ()
|
||||
// CHECK: "fooi67"() {bar: dense<vector<1x1x4xi67>, {{\[\[\[}}-5, 4, 6, 2]]]>} : () -> ()
|
||||
"fooi67"(){bar: dense<vector<1x1x4xi67>, [[[-5, 4, 6, 2]]]>} : () -> ()
|
||||
|
||||
// CHECK: "foo2"() {bar: dense<tensor<0xi32>, []>} : () -> ()
|
||||
"foo2"(){bar: dense<tensor<0 x i32>, []>} : () -> ()
|
||||
|
|
Loading…
Reference in New Issue