Refactor DenseElementsAttr to support auto-splatting the dense data on construction. This essentially means that we always auto-detect splat data and only store the minimum amount of data necessary. Support for parsing dense splats, and removing SplatElementsAttr(now that it is redundant) will come in followup cls

PiperOrigin-RevId: 252720561
This commit is contained in:
River Riddle 2019-06-11 16:14:17 -07:00 committed by Mehdi Amini
parent 5da741f671
commit d8cd96bc8b
8 changed files with 444 additions and 104 deletions

View File

@ -47,8 +47,6 @@ struct IntegerSetAttributeStorage;
struct TypeAttributeStorage;
struct SplatElementsAttributeStorage;
struct DenseElementsAttributeStorage;
struct DenseIntElementsAttributeStorage;
struct DenseFPElementsAttributeStorage;
struct OpaqueElementsAttributeStorage;
struct SparseElementsAttributeStorage;
@ -516,23 +514,39 @@ public:
/// or floating-point values. Each value is expected to be the same bitwidth
/// of the element type of 'type'. 'type' must be a vector or tensor with
/// static shape.
template <typename ShapeT, typename T>
static DenseElementsAttr get(ShapeT type, ArrayRef<T> values) {
static_assert(std::numeric_limits<T>::is_integer ||
llvm::is_one_of<T, float, double>::value,
"expected integer or floating point element type");
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
assert(type.getElementTypeBitWidth() == (sizeof(T) * CHAR_BIT));
template <typename T, typename = typename std::enable_if<
std::numeric_limits<T>::is_integer ||
llvm::is_one_of<T, float, double>::value>::type>
static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
const char *data = reinterpret_cast<const char *>(values.data());
return getRawIntOrFloat(type,
ArrayRef<char>(data, values.size() * sizeof(T)),
/*isInt=*/std::numeric_limits<T>::is_integer);
return getRawIntOrFloat(
type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
/*isInt=*/std::numeric_limits<T>::is_integer);
}
/// Constructs a dense integer elements attribute from a single element.
template <typename T, typename = typename std::enable_if<
std::numeric_limits<T>::is_integer ||
llvm::is_one_of<T, float, double>::value>::type>
static DenseElementsAttr get(const ShapedType &type, T value) {
return get(type, llvm::makeArrayRef(value));
}
/// Overload of the above 'get' method that is specialized for boolean values.
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
//===--------------------------------------------------------------------===//
// Value Querying
//===--------------------------------------------------------------------===//
@ -540,8 +554,16 @@ public:
/// Return the raw storage data held by this attribute.
ArrayRef<char> getRawData() const;
/// Returns the number of elements held by this attribute.
size_t size() const;
/// Returns the number of raw elements held by this attribute.
size_t rawSize() const;
/// Returns if this attribute corresponds to a splat, i.e. if all element
/// values are the same.
bool isSplat() const;
/// If this attribute corresponds to a splat, then get the splat value.
/// Otherwise, return null.
Attribute getSplatValue() const;
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
@ -620,22 +642,25 @@ protected:
/// Raw element iterators for this attribute.
RawElementIterator raw_begin() const { return RawElementIterator(*this, 0); }
RawElementIterator raw_end() const {
return RawElementIterator(*this, size());
return RawElementIterator(*this, rawSize());
}
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'. 'type' must be a vector or tensor with static shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<APInt> values);
/// Get or create a new dense elements attribute instance with the given raw
/// data buffer. 'type' must be a vector or tensor with static shape.
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data);
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
bool isSplat);
/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type.
/// integer or floating-point type. This method is used to verify type
/// invariants that the templatized 'get' method cannot.
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
ArrayRef<char> data, bool isInt);
ArrayRef<char> data,
int64_t dataEltSize, bool isInt);
};
/// An attribute that represents a reference to a dense integer vector or tensor
@ -650,12 +675,6 @@ public:
using DenseElementsAttr::get;
using DenseElementsAttr::getValues;
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
/// Generates a new DenseElementsAttr by mapping each value attribute, and
/// constructing the DenseElementsAttr given the new element type.
DenseElementsAttr
@ -693,12 +712,6 @@ public:
using DenseElementsAttr::get;
using DenseElementsAttr::getValues;
/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseFPElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
/// Gets the float value of each of the dense elements.
void getValues(SmallVectorImpl<APFloat> &values) const;
@ -776,7 +789,7 @@ public:
using Base::Base;
/// 'type' must be a vector or tensor with static shape.
static SparseElementsAttr get(ShapedType type, DenseIntElementsAttr indices,
static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
DenseElementsAttr values);
DenseIntElementsAttr getIndices() const;

View File

@ -49,7 +49,7 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
const UniformQuantizedValueConverter &converter) {
// Convert to corresponding quantized value attributes.
SmallVector<APInt, 8> quantValues;
quantValues.reserve(realFPElementsAttr.size());
quantValues.reserve(realFPElementsAttr.rawSize());
for (APFloat realVal : realFPElementsAttr) {
quantValues.push_back(converter.quantizeFloatToInt(realVal));
}

View File

@ -747,6 +747,11 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
return;
}
// If this is a splat, make sure to print all of the elements.
// TODO: This should be removed when the parser supports dense splats.
if (attr.isSplat())
elements.resize(type.getNumElements(), elements.front());
// We use a mixed-radix counter to iterate through the shape. When we bump a
// non-least-significant digit, we emit a close bracket. When we next emit an
// element we re-open all closed brackets.

View File

@ -359,24 +359,147 @@ struct SplatElementsAttributeStorage : public AttributeStorage {
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public AttributeStorage {
using KeyTy = std::pair<Type, ArrayRef<char>>;
struct KeyTy {
KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
bool isSplat = false)
: type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
DenseElementsAttributeStorage(Type ty, ArrayRef<char> data,
/// The type of the dense elements.
ShapedType type;
/// The raw buffer for the data storage.
ArrayRef<char> data;
/// The computed hash code for the storage data.
llvm::hash_code hashCode;
/// A boolean that indicates if this data is a splat or not.
bool isSplat;
};
DenseElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
bool isSplat = false)
: AttributeStorage(ty), data(data), isSplat(isSplat) {}
/// Key equality and hash functions.
/// Compare this storage instance with the provided key.
bool operator==(const KeyTy &key) const {
return key == KeyTy(getType(), data);
if (key.type != getType())
return false;
// For boolean splats we need to explicitly check that the first bit is the
// same. Boolean values are packed at the bit level, and even though a splat
// is detected the rest of the bits in the first byte may differ from the
// splat value.
if (key.type.getElementTypeBitWidth() == 1) {
if (key.isSplat != isSplat)
return false;
if (isSplat)
return (key.data.front() & 1) == data.front();
}
// Otherwise, we can default to just checking the data.
return key.data == data;
}
/// Construct a key from a shaped type, raw data buffer, and a flag that
/// signals if the data is already known to be a splat. Callers to this
/// function are expected to tag preknown splat values when possible, e.g. one
/// element shapes.
static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
// Handle an empty storage instance.
if (data.empty())
return KeyTy(ty, data, 0);
// If the data is already known to be a splat, the key hash value is
// directly the data buffer.
if (isKnownSplat)
return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
// Otherwise, we need to check if the data corresponds to a splat or not.
// Handle the simple case of only one element.
size_t numElements = ty.getNumElements();
assert(numElements != 1 && "splat of 1 element should already be detected");
// Handle boolean values directly as they are packed to 1-bit.
size_t elementWidth = ty.getElementTypeBitWidth();
if (elementWidth == 1)
return getKeyForBoolData(ty, data, numElements);
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
if (ty.getElementType().isBF16())
elementWidth = 64;
// Non 1-bit dense elements are padded to 8-bits.
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
assert(((data.size() / storageSize) == numElements) &&
"data does not hold expected number of elements");
// Create the initial hash value with just the first element.
auto firstElt = data.take_front(storageSize);
auto hashVal = llvm::hash_value(firstElt);
// Check to see if this storage represents a splat. If it doesn't then
// combine the hash for the data starting with the first non splat element.
for (size_t i = storageSize, e = data.size(); i != e; i += storageSize)
if (memcmp(data.data(), &data[i], storageSize))
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
// Otherwise, this is a splat so just return the hash of the first element.
return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
}
/// Construct a key with a set of boolean data.
static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
size_t numElements) {
ArrayRef<char> splatData = data;
bool splatValue = splatData.front() & 1;
// Helper functor to generate a KeyTy for a boolean splat value.
auto generateSplatKey = [=] {
return KeyTy(ty, data.take_front(1),
llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)),
/*isSplat=*/true);
};
// Handle the case where the potential splat value is 1 and the number of
// elements is non 8-bit aligned.
size_t numOddElements = numElements % CHAR_BIT;
if (splatValue && numOddElements != 0) {
// Check that all bits are set in the last value.
char lastElt = splatData.back();
if (lastElt != llvm::maskTrailingOnes<char>(numOddElements))
return KeyTy(ty, data, llvm::hash_value(data));
// If this is the only element, the data is known to be a splat.
if (splatData.size() == 1)
return generateSplatKey();
splatData = splatData.drop_back();
}
// Check that the data buffer corresponds to a splat.
return llvm::is_splat(splatData) ? generateSplatKey()
: KeyTy(ty, data, llvm::hash_value(data));
}
/// Hash the key for the storage.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine(key.type, key.hashCode);
}
/// Construct a new storage instance.
static DenseElementsAttributeStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
// If the data buffer is non-empty, we copy it into the allocator.
ArrayRef<char> data = allocator.copyInto(key.second);
ArrayRef<char> data = allocator.copyInto(key.data);
// If this is a boolean splat, make sure only the first bit is used.
if (key.isSplat && key.type.getElementTypeBitWidth() == 1)
const_cast<char &>(data.front()) &= 1;
return new (allocator.allocate<DenseElementsAttributeStorage>())
DenseElementsAttributeStorage(key.first, data);
DenseElementsAttributeStorage(key.type, data, key.isSplat);
}
ArrayRef<char> data;

View File

@ -488,7 +488,7 @@ SplatElementsAttr SplatElementsAttr::mapValues(
}
//===----------------------------------------------------------------------===//
// RawElementIterator
// DenseElementAttr Utilities
//===----------------------------------------------------------------------===//
static size_t getDenseElementBitwidth(Type eltType) {
@ -547,6 +547,14 @@ static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
return result;
}
/// Returns if 'values' corresponds to a splat, i.e. one element, or has the
/// same element count as 'type'.
template <typename Values>
static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
return (values.size() == 1) ||
(type.getNumElements() == static_cast<int64_t>(values.size()));
}
/// Constructs a new iterator.
DenseElementsAttr::RawElementIterator::RawElementIterator(
DenseElementsAttr attr, size_t index)
@ -567,30 +575,30 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(type.getElementType().isIntOrFloat() &&
"expected int or float element type");
assert(static_cast<int64_t>(values.size()) == type.getNumElements() &&
"expected 'values' to contain the same number of elements as 'type'");
assert(hasSameElementsOrSplat(type, values));
auto eltType = type.getElementType();
size_t bitWidth = getDenseElementBitwidth(eltType);
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
// Compress the attribute values into a character buffer.
SmallVector<char, 8> data(storageBitWidth * type.getNumElements());
SmallVector<char, 8> data((storageBitWidth / CHAR_BIT) * values.size());
APInt intVal;
for (unsigned i = 0, e = values.size(); i < e; ++i) {
assert(eltType == values[i].getType() &&
"expected attribute value to have element type");
switch (eltType.getKind()) {
case StandardTypes::BF16:
case StandardTypes::F16:
case StandardTypes::F32:
case StandardTypes::F64:
assert(eltType == values[i].cast<FloatAttr>().getType() &&
"expected attribute value to have element type");
intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
break;
case StandardTypes::Integer:
assert(eltType == values[i].cast<IntegerAttr>().getType() &&
"expected attribute value to have element type");
intVal = values[i].cast<IntegerAttr>().getValue();
intVal = values[i].isa<BoolAttr>()
? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
: values[i].cast<IntegerAttr>().getValue();
break;
default:
llvm_unreachable("unexpected element type");
@ -599,58 +607,83 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
"expected value to have same bitwidth as element type");
writeBits(data.data(), i * storageBitWidth, intVal);
}
return getRaw(type, data);
return getRaw(type, data, /*isSplat=*/(values.size() == 1));
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<bool> values) {
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
assert(hasSameElementsOrSplat(type, values));
assert(type.getElementType().isInteger(1));
std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
for (int i = 0, e = values.size(); i != e; ++i)
writeBits(buff.data(), i, llvm::APInt(1, values[i]));
return getRaw(type, buff);
setBit(buff.data(), i, values[i]);
return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
}
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<APInt> values) {
assert(type.getElementType().isa<IntegerType>());
return getRaw(type, values);
}
// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
// element type of 'type'.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<APFloat> values) {
assert(type.getElementType().isa<FloatType>());
// Convert the APFloat values to APInt and create a dense elements attribute.
std::vector<APInt> intValues(values.size());
for (unsigned i = 0, e = values.size(); i != e; ++i)
intValues[i] = values[i].bitcastToAPInt();
return getRaw(type, intValues);
}
// Constructs a dense elements attribute from an array of raw APInt values.
// Each APInt value is expected to have the same bitwidth as the element type
// of 'type'.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<APInt> values) {
assert(static_cast<int64_t>(values.size()) == type.getNumElements() &&
"expected 'values' to contain the same number of elements as 'type'");
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
ArrayRef<APInt> values) {
assert(hasSameElementsOrSplat(type, values));
size_t bitWidth = getDenseElementBitwidth(type.getElementType());
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
std::vector<char> elementData(bitWidth * values.size());
std::vector<char> elementData((storageBitWidth / CHAR_BIT) * values.size());
for (unsigned i = 0, e = values.size(); i != e; ++i) {
assert(values[i].getBitWidth() == bitWidth);
writeBits(elementData.data(), i * storageBitWidth, values[i]);
}
return getRaw(type, elementData);
return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
}
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
ArrayRef<char> data) {
assert((static_cast<uint64_t>(type.getSizeInBits()) <=
data.size() * APInt::APINT_WORD_SIZE) &&
"Input data bit size should be larger than that type requires");
ArrayRef<char> data, bool isSplat) {
assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
data);
data, isSplat);
}
/// Overload of the 'getRaw' method that asserts that the given type is of
/// integer type.
/// integer type. This method is used to verify type invariants that the
/// templatized 'get' method cannot.
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
ArrayRef<char> data,
int64_t dataEltSize,
bool isInt) {
assert(isInt ? type.getElementType().isa<IntegerType>()
: type.getElementType().isa<FloatType>());
return getRaw(type, data);
assert((dataEltSize * CHAR_BIT) == type.getElementTypeBitWidth());
int64_t numElements = data.size() / dataEltSize;
assert(numElements == 1 || numElements == type.getNumElements());
return getRaw(type, data, /*isSplat=*/numElements == 1);
}
/// Return the raw storage data held by this attribute.
@ -658,8 +691,29 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<ImplType *>(impl)->data;
}
/// Returns the number of elements held by this attribute.
size_t DenseElementsAttr::size() const { return getType().getNumElements(); }
/// Returns the number of raw elements held by this attribute.
size_t DenseElementsAttr::rawSize() const {
return isSplat() ? 1 : getType().getNumElements();
}
/// Returns if this attribute corresponds to a splat, i.e. if all element
/// values are the same.
bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
/// If this attribute corresponds to a splat, then get the splat value.
/// Otherwise, return null.
Attribute DenseElementsAttr::getSplatValue() const {
if (!isSplat())
return Attribute();
auto elementType = getType().getElementType();
if (elementType.isa<IntegerType>())
return IntegerAttr::get(elementType, *raw_begin());
if (auto fType = elementType.dyn_cast<FloatType>())
return FloatAttr::get(elementType,
APFloat(fType.getFloatSemantics(), *raw_begin()));
llvm_unreachable("unexpected element type");
}
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
@ -677,6 +731,10 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
if (shape[i] <= static_cast<int64_t>(index[i]))
return Attribute();
// If this is a splat, return the splat value directly.
if (isSplat())
return getSplatValue();
// Reduce the provided multidimensional index into a 1D index.
uint64_t valueIndex = 0;
uint64_t dimMultiplier = 1;
@ -688,9 +746,9 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
// Return the element stored at the 1D index.
auto elementType = getType().getElementType();
size_t bitWidth = getDenseElementBitwidth(elementType);
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
size_t storageWidth = getDenseElementStorageWidth(bitWidth);
APInt rawValueData =
readBits(getRawData().data(), valueIndex * storageBitWidth, bitWidth);
readBits(getRawData().data(), valueIndex * storageWidth, bitWidth);
// Convert the raw value data to an attribute value.
if (elementType.isa<IntegerType>())
@ -739,7 +797,7 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
"expected the same element type");
assert(newType.getNumElements() == curType.getNumElements() &&
"expected the same number of elements");
return getRaw(newType, getRawData());
return getRaw(newType, getRawData(), isSplat());
}
DenseElementsAttr DenseElementsAttr::mapValues(
@ -758,16 +816,8 @@ DenseElementsAttr DenseElementsAttr::mapValues(
// DenseIntElementsAttr
//===----------------------------------------------------------------------===//
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'.
DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
ArrayRef<APInt> values) {
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
}
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
values.reserve(size());
values.reserve(rawSize());
values.assign(raw_begin(), raw_end());
}
@ -808,7 +858,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
auto newArrayType =
mappingHelper(mapping, *this, getType(), newElementType, elementData);
return getRaw(newArrayType, elementData);
return getRaw(newArrayType, elementData, isSplat());
}
/// Method for supporting type inquiry through isa, cast and dyn_cast.
@ -827,20 +877,8 @@ DenseFPElementsAttr::ElementIterator::ElementIterator(
std::function<APFloat(const APInt &)>>(
it, [&](const APInt &val) { return APFloat(smt, val); }) {}
// Constructs a dense float elements attribute from an array of APFloat
// values. Each APFloat value is expected to have the same bitwidth as the
// element type of 'type'.
DenseFPElementsAttr DenseFPElementsAttr::get(ShapedType type,
ArrayRef<APFloat> values) {
// Convert the APFloat values to APInt and create a dense elements attribute.
std::vector<APInt> intValues(values.size());
for (unsigned i = 0, e = values.size(); i != e; ++i)
intValues[i] = values[i].bitcastToAPInt();
return DenseElementsAttr::get(type, intValues).cast<DenseFPElementsAttr>();
}
void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
values.reserve(size());
values.reserve(rawSize());
values.assign(begin(), end());
}
@ -851,7 +889,7 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
auto newArrayType =
mappingHelper(mapping, *this, getType(), newElementType, elementData);
return getRaw(newArrayType, elementData);
return getRaw(newArrayType, elementData, isSplat());
}
/// Iterator access to the float element values.
@ -907,7 +945,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
//===----------------------------------------------------------------------===//
SparseElementsAttr SparseElementsAttr::get(ShapedType type,
DenseIntElementsAttr indices,
DenseElementsAttr indices,
DenseElementsAttr values) {
assert(indices.getType().getElementType().isInteger(64) &&
"expected sparse indices to be 64-bit integer values");
@ -915,7 +953,7 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
indices, values);
indices.cast<DenseIntElementsAttr>(), values);
}
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
@ -935,12 +973,34 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
if (rank != index.size())
return Attribute();
/// Return an attribute corresponding to '0' for the element type.
auto getZeroAttr = [=]() -> Attribute {
auto eltType = type.getElementType();
if (eltType.isa<FloatType>())
return FloatAttr::get(eltType, 0);
assert(eltType.isa<IntegerType>() && "unexpected element type");
return IntegerAttr::get(eltType, 0);
};
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
// as a 1-D index array.
auto sparseIndices = getIndices();
const uint64_t *sparseIndexValues =
reinterpret_cast<const uint64_t *>(sparseIndices.getRawData().data());
// Check to see if the indices are a splat.
if (sparseIndices.isSplat()) {
// If the index is also not a splat of the index value, we know that the
// value is zero.
auto splatIndex = *sparseIndexValues;
if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
return getZeroAttr();
// If the indices are a splat, we also expect the values to be a splat.
assert(getValues().isSplat() && "expected splat values");
return getValues().getSplatValue();
}
// Build a mapping between known indices and the offset of the stored element.
llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
auto numSparseIndices = sparseIndices.getType().getDimSize(0);
@ -950,13 +1010,8 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
// Look for the provided index key within the mapped indices. If the provided
// index is not found, then return a zero attribute.
auto it = mappedIndices.find(index);
if (it == mappedIndices.end()) {
auto eltType = type.getElementType();
if (eltType.isa<FloatType>())
return FloatAttr::get(eltType, 0);
assert(eltType.isa<IntegerType>() && "unexpected element type");
return IntegerAttr::get(eltType, 0);
}
if (it == mappedIndices.end())
return getZeroAttr();
// Otherwise, return the held sparse value element.
return getValues().getValue(it->second);

View File

@ -97,6 +97,13 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
}
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
if (denseAttr.isSplat()) {
auto *child = getLLVMConstant(vectorType->getElementType(),
denseAttr.getSplatValue(), loc);
return llvm::ConstantVector::getSplat(vectorType->getNumElements(),
child);
}
SmallVector<llvm::Constant *, 8> constants;
uint64_t numElements = vectorType->getNumElements();
constants.reserve(numElements);

View File

@ -0,0 +1,136 @@
//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::detail;
template <typename EltTy>
static void testSplat(Type eltType, const EltTy &splatElt) {
VectorType shape = VectorType::get({2, 1}, eltType);
// Check that the generated splat is the same for 1 element and N elements.
DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
EXPECT_TRUE(splat.isSplat());
auto detectedSplat =
DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
EXPECT_EQ(detectedSplat, splat);
}
namespace {
TEST(DenseSplatTest, BoolSplat) {
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
VectorType shape = VectorType::get({2, 2}, boolTy);
// Check that splat is automatically detected for boolean values.
/// True.
DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
EXPECT_TRUE(trueSplat.isSplat());
/// False.
DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
EXPECT_TRUE(falseSplat.isSplat());
EXPECT_NE(falseSplat, trueSplat);
/// Detect and handle splat within 8 elements (bool values are bit-packed).
/// True.
auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
EXPECT_EQ(detectedSplat, trueSplat);
/// False.
detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
EXPECT_EQ(detectedSplat, falseSplat);
}
TEST(DenseSplatTest, LargeBoolSplat) {
constexpr size_t boolCount = 56;
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
VectorType shape = VectorType::get({boolCount}, boolTy);
// Check that splat is automatically detected for boolean values.
/// True.
DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
EXPECT_TRUE(trueSplat.isSplat());
EXPECT_TRUE(falseSplat.isSplat());
/// Detect that the large boolean arrays are properly splatted.
/// True.
SmallVector<bool, 64> trueValues(boolCount, true);
auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
EXPECT_EQ(detectedSplat, trueSplat);
/// False.
SmallVector<bool, 64> falseValues(boolCount, false);
detectedSplat = DenseElementsAttr::get(shape, falseValues);
EXPECT_EQ(detectedSplat, falseSplat);
}
TEST(DenseSplatTest, OddIntSplat) {
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
MLIRContext context;
constexpr size_t intWidth = 19;
IntegerType intTy = IntegerType::get(intWidth, &context);
APInt value(intWidth, 10);
testSplat(intTy, value);
}
TEST(DenseSplatTest, Int32Splat) {
MLIRContext context;
IntegerType intTy = IntegerType::get(32, &context);
int value = 64;
testSplat(intTy, value);
}
TEST(DenseSplatTest, IntAttrSplat) {
MLIRContext context;
IntegerType intTy = IntegerType::get(85, &context);
Attribute value = IntegerAttr::get(intTy, 109);
testSplat(intTy, value);
}
TEST(DenseSplatTest, F32Splat) {
MLIRContext context;
FloatType floatTy = FloatType::getF32(&context);
float value = 10.0;
testSplat(floatTy, value);
}
TEST(DenseSplatTest, F64Splat) {
MLIRContext context;
FloatType floatTy = FloatType::getF64(&context);
double value = 10.0;
testSplat(floatTy, APFloat(value));
}
TEST(DenseSplatTest, FloatAttrSplat) {
MLIRContext context;
FloatType floatTy = FloatType::getBF16(&context);
Attribute value = FloatAttr::get(floatTy, 10.0);
testSplat(floatTy, value);
}
} // end namespace

View File

@ -1,4 +1,5 @@
add_mlir_unittest(MLIRIRTests
AttributeTest.cpp
DialectTest.cpp
OperationSupportTest.cpp
)