[MLIR] Add functionality for constructing a DenseElementsAttr from an array of attributes and rewrite DenseElementsAttr::writeBits/readBits to handle non-uniform bitwidths. This fixes asan failures that happen when using non-uniform bitwidths.

PiperOrigin-RevId: 229815107
This commit is contained in:
River Riddle 2019-01-17 14:11:05 -08:00 committed by jpienaar
parent 40f7535571
commit 0e81d7c420
9 changed files with 182 additions and 197 deletions

View File

@ -337,13 +337,10 @@ public:
/// It assumes the elements in the input array have been truncated to the bits
/// width specified by the element type (note all float type are 64 bits).
/// When the value is retrieved, the bits are read from the storage and extend
/// to 64 bits if necessary.
static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
// TODO: Read the data from the attribute list and compress them
// to a character array. Then call the above method to construct the
// attribute.
// Constructs a dense elements attribute from an array of element values. Each
// element attribute value is expected to be an element of 'type'.
static DenseElementsAttr get(VectorOrTensorType type,
ArrayRef<Attribute> values);
@ -351,20 +348,23 @@ public:
ArrayRef<char> getRawData() const;
/// Writes the lowest `bitWidth` bits of `value` to the bit position `bitPos`
/// in array `rawData`.
static void writeBits(char *rawData, size_t bitPos, size_t bitWidth,
uint64_t value);
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
/// expected to be a 64-bit aligned storage address.
static void writeBits(char *rawData, size_t bitPos, APInt value);
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData` and return them as the lowest bits of an uint64 integer.
static uint64_t readBits(const char *rawData, size_t bitPos,
size_t bitsWidth);
/// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address.
static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth);
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(Kind kind) {
return kind == Kind::DenseIntElements || kind == Kind::DenseFPElements;
}
protected:
/// Parses the raw integer internal value for each dense element into
/// 'values'.
void getRawValues(SmallVectorImpl<APInt> &values) const;
};
/// An attribute represents a reference to a dense integer vector or tensor
@ -372,10 +372,11 @@ public:
class DenseIntElementsAttr : public DenseElementsAttr {
public:
using DenseElementsAttr::DenseElementsAttr;
using ImplType = detail::DenseIntElementsAttributeStorage;
using DenseElementsAttr::getValues;
using DenseElementsAttr::ImplType;
// TODO: returns APInts instead of IntegerAttr.
void getValues(SmallVectorImpl<Attribute> &values) const;
/// Gets the integer value of each of the dense elements.
void getValues(SmallVectorImpl<APInt> &values) const;
APInt getValue(ArrayRef<unsigned> indices) const;
@ -388,10 +389,11 @@ public:
class DenseFPElementsAttr : public DenseElementsAttr {
public:
using DenseElementsAttr::DenseElementsAttr;
using ImplType = detail::DenseFPElementsAttributeStorage;
using DenseElementsAttr::getValues;
using DenseElementsAttr::ImplType;
// TODO: returns APFPs instead of FloatAttr.
void getValues(SmallVectorImpl<Attribute> &values) const;
/// Gets the float value of each of the dense elements.
void getValues(SmallVectorImpl<APFloat> &values) const;
APFloat getValue(ArrayRef<unsigned> indices) const;

View File

@ -107,6 +107,8 @@ public:
ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt);
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
ArrayRef<char> data);
ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
ArrayRef<Attribute> values);
ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values);

View File

@ -134,17 +134,6 @@ struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
ArrayRef<char> data;
};
/// An attribute representing a reference to a dense integer vector or tensor
/// object.
struct DenseIntElementsAttributeStorage : public DenseElementsAttributeStorage {
size_t bitsWidth;
};
/// An attribute representing a reference to a dense float vector or tensor
/// object.
struct DenseFPElementsAttributeStorage : public DenseElementsAttributeStorage {
};
/// An attribute representing a reference to a tensor constant with opaque
/// content.
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {

View File

@ -150,13 +150,28 @@ Attribute SplatElementsAttr::getValue() const {
/// DenseElementsAttr
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto elementType = getType().getElementType();
switch (getKind()) {
case Attribute::Kind::DenseIntElements:
cast<DenseIntElementsAttr>().getValues(values);
case Attribute::Kind::DenseIntElements: {
// Get the raw APInt values.
SmallVector<APInt, 8> intValues;
cast<DenseIntElementsAttr>().getValues(intValues);
// Convert each to an IntegerAttr.
for (auto &intVal : intValues)
values.push_back(IntegerAttr::get(elementType, intVal));
return;
case Attribute::Kind::DenseFPElements:
cast<DenseFPElementsAttr>().getValues(values);
}
case Attribute::Kind::DenseFPElements: {
// Get the raw APFloat values.
SmallVector<APFloat, 8> floatValues;
cast<DenseFPElementsAttr>().getValues(floatValues);
// Convert each to a FloatAttr.
for (auto &floatVal : floatValues)
values.push_back(FloatAttr::get(elementType, floatVal));
return;
}
default:
llvm_unreachable("unexpected element type");
}
@ -166,112 +181,89 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<ImplType *>(attr)->data;
}
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
/// starting from `rawData`.
void DenseElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
uint64_t value) {
assert(bitWidth <= 64 && "expected bitWidth to be within 64-bits");
// Read the destination bytes which will be written to.
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitWidth;
auto start = data + bitPos / 8;
auto end = data + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
// Clean up the invalid bits in the destination bytes.
dst &= ~(-1UL << (bitPos % 8));
// Get the valid bits of the source value, shift them to right position,
// then add them to the destination bytes.
value <<= bitPos % 8;
dst |= value;
// Write the destination bytes back.
ArrayRef<char> range({dstData, (size_t)(end - start)});
std::copy(range.begin(), range.end(), start);
}
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
/// and put them in the lowest bits.
uint64_t DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
size_t bitsWidth) {
assert(bitsWidth <= 64 && "expected bitWidth to be within 64-bits");
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitsWidth;
auto start = rawData + bitPos / 8;
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
dst >>= bitPos % 8;
dst &= ~(-1UL << bitsWidth);
return dst;
}
/// DenseIntElementsAttr
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
/// Parses the raw integer internal value for each dense element into
/// 'values'.
void DenseElementsAttr::getRawValues(SmallVectorImpl<APInt> &values) const {
auto elementType = getType().getElementType();
auto elementNum = getType().getNumElements();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
{reinterpret_cast<const int64_t *>(getRawData().data()),
getRawData().size() / 8});
for (auto value : vs) {
auto attr = IntegerAttr::get(getType().getElementType(), value);
values.push_back(attr);
}
} else {
const auto *rawData = getRawData().data();
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
uint64_t bits = readBits(rawData, pos, bitsWidth);
APInt value(bitsWidth, bits, /*isSigned=*/true);
auto attr =
IntegerAttr::get(getType().getElementType(), value.getSExtValue());
values.push_back(attr);
}
}
}
/// DenseFPElementsAttr
// Construct a FloatAttr wrapping a float value of `elementType` type from its
// bit representation. The APFloat stored in the attribute will have the
// semantics defined by the float semantics of the element type.
static inline FloatAttr makeFloatAttrFromBits(size_t bitWidth, uint64_t bits,
FloatType elementType) {
auto apint = APInt(bitWidth, bits);
auto apfloat = APFloat(elementType.getFloatSemantics(), apint);
return FloatAttr::get(elementType, apfloat);
}
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto elementNum = getType().getNumElements();
auto elementType = getType().getElementType().dyn_cast<FloatType>();
assert(elementType && "non-float type in FP attribute");
// FIXME: using 64 bits for BF16 because it is currently stored with double
// semantics.
size_t bitWidth =
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
const auto *rawData = getRawData().data();
for (size_t i = 0, e = elementNum; i != e; ++i)
values.push_back(readBits(rawData, i * bitWidth, bitWidth));
}
values.reserve(elementNum);
if (bitWidth == 64) {
ArrayRef<int64_t> vs(
{reinterpret_cast<const int64_t *>(getRawData().data()),
getRawData().size() / 8});
for (auto bitValue : vs) {
values.push_back(makeFloatAttrFromBits(64, bitValue, elementType));
}
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
/// expected to be a 64-bit aligned storage address.
void DenseElementsAttr::writeBits(char *rawData, size_t bitPos, APInt value) {
size_t bitWidth = value.getBitWidth();
// If the bitwidth is 1 we just toggle the specific bit.
if (bitWidth == 1) {
auto *rawIntData = reinterpret_cast<uint64_t *>(rawData);
if (value.isOneValue())
APInt::tcSetBit(rawIntData, bitPos);
else
APInt::tcClearBit(rawIntData, bitPos);
return;
}
for (unsigned i = 0; i < elementNum; ++i) {
uint64_t bits = readBits(getRawData().data(), i * bitWidth, bitWidth);
values.push_back(makeFloatAttrFromBits(bitWidth, bits, elementType));
// If the bit position and width are byte aligned, write the storage directly
// to the data.
if ((bitWidth % 8) == 0 && (bitPos % 8) == 0) {
std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
bitWidth / 8, rawData + (bitPos / 8));
return;
}
// Otherwise, convert the raw data into an APInt and insert the value at the
// specified bit position.
size_t totalWords = APInt::getNumWords((bitPos % 64) + bitWidth);
llvm::MutableArrayRef<uint64_t> rawIntData(
reinterpret_cast<uint64_t *>(rawData) + (bitPos / 64), totalWords);
APInt tempStorage(totalWords * 64, rawIntData);
tempStorage.insertBits(value, bitPos % 64);
// Copy the value back to the raw data.
std::copy_n(tempStorage.getRawData(), rawIntData.size(), rawIntData.data());
}
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData`. 'rawData' is expected to be a 64-bit aligned storage address.
APInt DenseElementsAttr::readBits(const char *rawData, size_t bitPos,
                                  size_t bitWidth) {
  // Start from a zero-initialized value of the requested width; its backing
  // words are filled in place below.
  APInt extracted(bitWidth, 0);

  // View the packed storage as an array of 64-bit words, which is the
  // representation the APInt word-level helpers operate on.
  const auto *srcWords = reinterpret_cast<const uint64_t *>(rawData);

  // getRawData() only exposes a const pointer, so the constness is cast away
  // to let tcExtract write the 'bitWidth' bits found at 'bitPos' directly
  // into the result's storage.
  auto *dstWords = const_cast<uint64_t *>(extracted.getRawData());
  APInt::tcExtract(dstWords, extracted.getNumWords(), srcWords, bitWidth,
                   bitPos);
  return extracted;
}
/// DenseIntElementsAttr
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
  // Integer elements need no conversion: the raw values parsed by the base
  // class are already the result.
  DenseElementsAttr::getRawValues(values);
}
/// DenseFPElementsAttr
/// Gets the float value of each of the dense elements, appending one APFloat
/// per element to 'values'.
void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
  // Get the raw APInt element values.
  SmallVector<APInt, 8> intValues;
  getRawValues(intValues);

  // The element type of an FP elements attribute must be a float type. Use
  // 'cast' so that a non-float type trips an assertion here instead of the
  // unchecked 'dyn_cast' result being used as if it were valid.
  auto elementType = getType().getElementType().cast<FloatType>();
  const auto &elementSemantics = elementType.getFloatSemantics();

  // Convert each of the APInt bit patterns to an APFloat with the semantics
  // of the element type.
  values.reserve(values.size() + intValues.size());
  for (const auto &intValue : intValues)
    values.push_back(APFloat(elementSemantics, intValue));
}
/// OpaqueElementsAttr

View File

@ -171,6 +171,11 @@ ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
return DenseElementsAttr::get(type, data);
}
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
                                           ArrayRef<Attribute> values) {
  // Forward to the attribute-list constructor of DenseElementsAttr; the
  // result is handed back through the more general ElementsAttr interface.
  DenseElementsAttr denseAttr = DenseElementsAttr::get(type, values);
  return denseAttr;
}
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {

View File

@ -1050,37 +1050,68 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
if (!existing.second)
return *existing.first;
// Otherwise, allocate a new one, unique it and return it.
auto eltType = type.getElementType();
switch (eltType.getKind()) {
Attribute::Kind kind;
switch (type.getElementType().getKind()) {
case StandardTypes::BF16:
case StandardTypes::F16:
case StandardTypes::F32:
case StandardTypes::F64: {
auto *result = impl.allocator.Allocate<DenseFPElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
new (result) DenseFPElementsAttributeStorage{
{{{Attribute::Kind::DenseFPElements, /*isOrContainsFunction=*/false},
type},
{copy, data.size()}}};
return *existing.first = result;
}
case StandardTypes::Integer: {
auto width = eltType.cast<IntegerType>().getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
new (result) DenseIntElementsAttributeStorage{
{{{Attribute::Kind::DenseIntElements, /*isOrContainsFunction=*/false},
type},
{copy, data.size()}},
width};
return *existing.first = result;
}
case StandardTypes::F64:
kind = Attribute::Kind::DenseFPElements;
break;
case StandardTypes::Integer:
kind = Attribute::Kind::DenseIntElements;
break;
default:
llvm_unreachable("unexpected element type");
}
// Otherwise, allocate a new one, unique it and return it.
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
auto *result = impl.allocator.Allocate<DenseElementsAttributeStorage>();
new (result) DenseElementsAttributeStorage{
{{kind, /*isOrContainsFunction=*/false}, type}, {copy, data.size()}};
return *existing.first = result;
}
/// Constructs a dense elements attribute from an array of element values.
/// Each value must be a FloatAttr or IntegerAttr whose type is the element
/// type of 'type'. The values are bit-packed, in order, into a raw buffer that
/// is then handed to the raw-data overload of 'get'.
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
                                         ArrayRef<Attribute> values) {
  assert(type.getElementType().isIntOrFloat() &&
         "expected int or float element type");

  // FIXME: using 64 bits for BF16 because it is currently stored with double
  // semantics.
  auto eltType = type.getElementType();
  size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();

  // Size the buffer from the same per-element storage width used to compute
  // the bit positions below, so every write is guaranteed to fit. The buffer
  // covers whichever is larger: the element count of the type or the number
  // of values supplied.
  size_t numElements = static_cast<size_t>(type.getNumElements());
  if (values.size() > numElements)
    numElements = values.size();

  // writeBits/readBits operate on whole 64-bit words, so round the storage up
  // to a word boundary. Backing it with uint64_t (rather than char) also
  // provides the 64-bit alignment that writeBits expects of its input.
  const size_t numWords = (numElements * bitWidth + 63) / 64;
  SmallVector<uint64_t, 4> storage(numWords);
  char *rawData = reinterpret_cast<char *>(storage.data());

  // Compress the attribute values into the raw buffer.
  for (size_t i = 0, e = values.size(); i < e; ++i) {
    APInt intVal;
    switch (eltType.getKind()) {
    case StandardTypes::BF16:
    case StandardTypes::F16:
    case StandardTypes::F32:
    case StandardTypes::F64:
      assert(eltType == values[i].cast<FloatAttr>().getType() &&
             "expected attribute value to have element type");
      intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
      break;
    case StandardTypes::Integer:
      assert(eltType == values[i].cast<IntegerAttr>().getType() &&
             "expected attribute value to have element type");
      intVal = values[i].cast<IntegerAttr>().getValue();
      break;
    default:
      llvm_unreachable("unexpected element type");
    }
    assert(intVal.getBitWidth() == bitWidth &&
           "expected value to have same bitwidth as element type");

    // Compute the position in size_t so large tensors do not overflow the
    // narrower 'unsigned' range.
    writeBits(rawData, i * bitWidth, intVal);
  }
  return get(type, ArrayRef<char>(rawData, numWords * sizeof(uint64_t)));
}
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,

View File

@ -670,8 +670,7 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
namespace {
class TensorLiteralParser {
public:
TensorLiteralParser(Parser &p, Type eltTy)
: p(p), eltTy(eltTy), currBitPos(0) {}
TensorLiteralParser(Parser &p, Type eltTy) : p(p), eltTy(eltTy) {}
ParseResult parse() {
if (p.getToken().isNot(Token::l_square))
@ -679,9 +678,7 @@ public:
return parseList(shape);
}
ArrayRef<char> getValues() const {
return {reinterpret_cast<const char *>(storage.data()), storage.size() * 8};
}
ArrayRef<Attribute> getValues() const { return storage; }
ArrayRef<int> getShape() const { return shape; }
@ -698,28 +695,10 @@ private:
/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
ParseResult parseList(llvm::SmallVectorImpl<int> &dims);
void addToStorage(uint64_t value) {
// Only tensors of integers or floats are supported.
// FIXME: use full word to store BF16 as double because APFloat, which we
// use to work with floats, does not have support for BF16 yet.
size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
if (bitWidth == 64)
storage.push_back(value);
if (currBitPos + bitWidth > storage.size() * 64)
storage.push_back(0L);
auto *rawData = reinterpret_cast<char *>(storage.data());
DenseElementsAttr::writeBits(rawData, currBitPos, bitWidth, value);
currBitPos += bitWidth;
}
Parser &p;
Type eltTy;
size_t currBitPos;
SmallVector<int, 4> shape;
std::vector<uint64_t> storage;
std::vector<Attribute> storage;
};
} // namespace
@ -754,30 +733,22 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
size_t bitWidth = eltTy.isBF16() ? 64 : eltTy.getIntOrFloatBitWidth();
assert(apInt.getBitWidth() == bitWidth);
(void)bitWidth;
addToStorage(apInt.getRawData()[0]);
(void)apInt;
break;
}
case StandardTypes::Integer: {
if (!result.isa<IntegerAttr>())
return p.emitError("expected tensor literal element has integer type");
auto value = result.cast<IntegerAttr>().getValue();
auto bitWidth = eltTy.getIntOrFloatBitWidth();
if (value.getMinSignedBits() > bitWidth)
if (value.getMinSignedBits() > eltTy.getIntOrFloatBitWidth())
return p.emitError("tensor literal element has more bits than that "
"specified in the type");
// FIXME: Handle larger than 64-bit types more gracefully.
if (bitWidth > 64)
return p.emitError("tensor literal element with more than 64-bits is "
"not currently supported");
addToStorage(value.getSExtValue());
break;
}
default:
return p.emitError("expected integer or float tensor element");
}
storage.push_back(result);
break;
}
default:

View File

@ -750,15 +750,6 @@ func @elementsattr_toolarge2() -> () {
// -----
// FIXME: Handle larger than 64-bit types more gracefully.
func @elementsattr_larger_than_64_bits() -> () {
^bb0:
"fooi67"(){bar: dense<vector<1x1x1xi67>, [[[-5]]]>} : () -> () // expected-error {{tensor literal element with more than 64-bits is not currently supported}}
return
}
// -----
func @elementsattr_malformed_opaque() -> () {
^bb0:
"foo"(){bar: opaque<tensor<1xi8>, "0xQZz123">} : () -> () // expected-error {{opaque string only contains hex digits}}

View File

@ -593,6 +593,8 @@ func @densetensorattr() -> () {
"fooi64"(){bar: dense<tensor<2x1x4xi64>, [[[1, -2, 1, 2]], [[0, 3, -1, 2]]]>} : () -> ()
// CHECK: "fooi64"() {bar: dense<tensor<1x1x1xi64>, {{\[\[\[}}-5]]]>} : () -> ()
"fooi64"(){bar: dense<tensor<1x1x1xi64>, [[[-5]]]>} : () -> ()
// CHECK: "fooi67"() {bar: dense<vector<1x1x4xi67>, {{\[\[\[}}-5, 4, 6, 2]]]>} : () -> ()
"fooi67"(){bar: dense<vector<1x1x4xi67>, [[[-5, 4, 6, 2]]]>} : () -> ()
// CHECK: "foo2"() {bar: dense<tensor<0xi32>, []>} : () -> ()
"foo2"(){bar: dense<tensor<0 x i32>, []>} : () -> ()