forked from OSchip/llvm-project
Refactor DenseElementsAttr to support auto-splatting the dense data on construction. This essentially means that we always auto-detect splat data and only store the minimum amount of data necessary. Support for parsing dense splats, and removing SplatElementsAttr(now that it is redundant) will come in followup cls
PiperOrigin-RevId: 252720561
This commit is contained in:
parent
5da741f671
commit
d8cd96bc8b
|
@ -47,8 +47,6 @@ struct IntegerSetAttributeStorage;
|
|||
struct TypeAttributeStorage;
|
||||
struct SplatElementsAttributeStorage;
|
||||
struct DenseElementsAttributeStorage;
|
||||
struct DenseIntElementsAttributeStorage;
|
||||
struct DenseFPElementsAttributeStorage;
|
||||
struct OpaqueElementsAttributeStorage;
|
||||
struct SparseElementsAttributeStorage;
|
||||
|
||||
|
@ -516,23 +514,39 @@ public:
|
|||
/// or floating-point values. Each value is expected to be the same bitwidth
|
||||
/// of the element type of 'type'. 'type' must be a vector or tensor with
|
||||
/// static shape.
|
||||
template <typename ShapeT, typename T>
|
||||
static DenseElementsAttr get(ShapeT type, ArrayRef<T> values) {
|
||||
static_assert(std::numeric_limits<T>::is_integer ||
|
||||
llvm::is_one_of<T, float, double>::value,
|
||||
"expected integer or floating point element type");
|
||||
|
||||
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
|
||||
assert(type.getElementTypeBitWidth() == (sizeof(T) * CHAR_BIT));
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::numeric_limits<T>::is_integer ||
|
||||
llvm::is_one_of<T, float, double>::value>::type>
|
||||
static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
|
||||
const char *data = reinterpret_cast<const char *>(values.data());
|
||||
return getRawIntOrFloat(type,
|
||||
ArrayRef<char>(data, values.size() * sizeof(T)),
|
||||
/*isInt=*/std::numeric_limits<T>::is_integer);
|
||||
return getRawIntOrFloat(
|
||||
type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T),
|
||||
/*isInt=*/std::numeric_limits<T>::is_integer);
|
||||
}
|
||||
|
||||
/// Constructs a dense integer elements attribute from a single element.
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::numeric_limits<T>::is_integer ||
|
||||
llvm::is_one_of<T, float, double>::value>::type>
|
||||
static DenseElementsAttr get(const ShapedType &type, T value) {
|
||||
return get(type, llvm::makeArrayRef(value));
|
||||
}
|
||||
|
||||
/// Overload of the above 'get' method that is specialized for boolean values.
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
|
||||
|
||||
/// Constructs a dense integer elements attribute from an array of APInt
|
||||
/// values. Each APInt value is expected to have the same bitwidth as the
|
||||
/// element type of 'type'. 'type' must be a vector or tensor with static
|
||||
/// shape.
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
|
||||
|
||||
/// Constructs a dense float elements attribute from an array of APFloat
|
||||
/// values. Each APFloat value is expected to have the same bitwidth as the
|
||||
/// element type of 'type'. 'type' must be a vector or tensor with static
|
||||
/// shape.
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Value Querying
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -540,8 +554,16 @@ public:
|
|||
/// Return the raw storage data held by this attribute.
|
||||
ArrayRef<char> getRawData() const;
|
||||
|
||||
/// Returns the number of elements held by this attribute.
|
||||
size_t size() const;
|
||||
/// Returns the number of raw elements held by this attribute.
|
||||
size_t rawSize() const;
|
||||
|
||||
/// Returns if this attribute corresponds to a splat, i.e. if all element
|
||||
/// values are the same.
|
||||
bool isSplat() const;
|
||||
|
||||
/// If this attribute corresponds to a splat, then get the splat value.
|
||||
/// Otherwise, return null.
|
||||
Attribute getSplatValue() const;
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
/// element, then a null attribute is returned.
|
||||
|
@ -620,22 +642,25 @@ protected:
|
|||
/// Raw element iterators for this attribute.
|
||||
RawElementIterator raw_begin() const { return RawElementIterator(*this, 0); }
|
||||
RawElementIterator raw_end() const {
|
||||
return RawElementIterator(*this, size());
|
||||
return RawElementIterator(*this, rawSize());
|
||||
}
|
||||
|
||||
/// Constructs a dense elements attribute from an array of raw APInt values.
|
||||
/// Each APInt value is expected to have the same bitwidth as the element type
|
||||
/// of 'type'. 'type' must be a vector or tensor with static shape.
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
|
||||
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<APInt> values);
|
||||
|
||||
/// Get or create a new dense elements attribute instance with the given raw
|
||||
/// data buffer. 'type' must be a vector or tensor with static shape.
|
||||
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data);
|
||||
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
|
||||
bool isSplat);
|
||||
|
||||
/// Overload of the raw 'get' method that asserts that the given type is of
|
||||
/// integer or floating-point type.
|
||||
/// integer or floating-point type. This method is used to verify type
|
||||
/// invariants that the templatized 'get' method cannot.
|
||||
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
|
||||
ArrayRef<char> data, bool isInt);
|
||||
ArrayRef<char> data,
|
||||
int64_t dataEltSize, bool isInt);
|
||||
};
|
||||
|
||||
/// An attribute that represents a reference to a dense integer vector or tensor
|
||||
|
@ -650,12 +675,6 @@ public:
|
|||
using DenseElementsAttr::get;
|
||||
using DenseElementsAttr::getValues;
|
||||
|
||||
/// Constructs a dense integer elements attribute from an array of APInt
|
||||
/// values. Each APInt value is expected to have the same bitwidth as the
|
||||
/// element type of 'type'. 'type' must be a vector or tensor with static
|
||||
/// shape.
|
||||
static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
|
||||
|
||||
/// Generates a new DenseElementsAttr by mapping each value attribute, and
|
||||
/// constructing the DenseElementsAttr given the new element type.
|
||||
DenseElementsAttr
|
||||
|
@ -693,12 +712,6 @@ public:
|
|||
using DenseElementsAttr::get;
|
||||
using DenseElementsAttr::getValues;
|
||||
|
||||
/// Constructs a dense float elements attribute from an array of APFloat
|
||||
/// values. Each APFloat value is expected to have the same bitwidth as the
|
||||
/// element type of 'type'. 'type' must be a vector or tensor with static
|
||||
/// shape.
|
||||
static DenseFPElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
|
||||
|
||||
/// Gets the float value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APFloat> &values) const;
|
||||
|
||||
|
@ -776,7 +789,7 @@ public:
|
|||
using Base::Base;
|
||||
|
||||
/// 'type' must be a vector or tensor with static shape.
|
||||
static SparseElementsAttr get(ShapedType type, DenseIntElementsAttr indices,
|
||||
static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
|
||||
DenseElementsAttr values);
|
||||
|
||||
DenseIntElementsAttr getIndices() const;
|
||||
|
|
|
@ -49,7 +49,7 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
|
|||
const UniformQuantizedValueConverter &converter) {
|
||||
// Convert to corresponding quantized value attributes.
|
||||
SmallVector<APInt, 8> quantValues;
|
||||
quantValues.reserve(realFPElementsAttr.size());
|
||||
quantValues.reserve(realFPElementsAttr.rawSize());
|
||||
for (APFloat realVal : realFPElementsAttr) {
|
||||
quantValues.push_back(converter.quantizeFloatToInt(realVal));
|
||||
}
|
||||
|
|
|
@ -747,6 +747,11 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
|||
return;
|
||||
}
|
||||
|
||||
// If this is a splat, make sure to print all of the elements.
|
||||
// TODO: This should be removed when the parser supports dense splats.
|
||||
if (attr.isSplat())
|
||||
elements.resize(type.getNumElements(), elements.front());
|
||||
|
||||
// We use a mixed-radix counter to iterate through the shape. When we bump a
|
||||
// non-least-significant digit, we emit a close bracket. When we next emit an
|
||||
// element we re-open all closed brackets.
|
||||
|
|
|
@ -359,24 +359,147 @@ struct SplatElementsAttributeStorage : public AttributeStorage {
|
|||
|
||||
/// An attribute representing a reference to a dense vector or tensor object.
|
||||
struct DenseElementsAttributeStorage : public AttributeStorage {
|
||||
using KeyTy = std::pair<Type, ArrayRef<char>>;
|
||||
struct KeyTy {
|
||||
KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
|
||||
bool isSplat = false)
|
||||
: type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
|
||||
|
||||
DenseElementsAttributeStorage(Type ty, ArrayRef<char> data,
|
||||
/// The type of the dense elements.
|
||||
ShapedType type;
|
||||
|
||||
/// The raw buffer for the data storage.
|
||||
ArrayRef<char> data;
|
||||
|
||||
/// The computed hash code for the storage data.
|
||||
llvm::hash_code hashCode;
|
||||
|
||||
/// A boolean that indicates if this data is a splat or not.
|
||||
bool isSplat;
|
||||
};
|
||||
|
||||
DenseElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
|
||||
bool isSplat = false)
|
||||
: AttributeStorage(ty), data(data), isSplat(isSplat) {}
|
||||
|
||||
/// Key equality and hash functions.
|
||||
/// Compare this storage instance with the provided key.
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(getType(), data);
|
||||
if (key.type != getType())
|
||||
return false;
|
||||
|
||||
// For boolean splats we need to explicitly check that the first bit is the
|
||||
// same. Boolean values are packed at the bit level, and even though a splat
|
||||
// is detected the rest of the bits in the first byte may differ from the
|
||||
// splat value.
|
||||
if (key.type.getElementTypeBitWidth() == 1) {
|
||||
if (key.isSplat != isSplat)
|
||||
return false;
|
||||
if (isSplat)
|
||||
return (key.data.front() & 1) == data.front();
|
||||
}
|
||||
|
||||
// Otherwise, we can default to just checking the data.
|
||||
return key.data == data;
|
||||
}
|
||||
|
||||
/// Construct a key from a shaped type, raw data buffer, and a flag that
|
||||
/// signals if the data is already known to be a splat. Callers to this
|
||||
/// function are expected to tag preknown splat values when possible, e.g. one
|
||||
/// element shapes.
|
||||
static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
|
||||
// Handle an empty storage instance.
|
||||
if (data.empty())
|
||||
return KeyTy(ty, data, 0);
|
||||
|
||||
// If the data is already known to be a splat, the key hash value is
|
||||
// directly the data buffer.
|
||||
if (isKnownSplat)
|
||||
return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
|
||||
|
||||
// Otherwise, we need to check if the data corresponds to a splat or not.
|
||||
|
||||
// Handle the simple case of only one element.
|
||||
size_t numElements = ty.getNumElements();
|
||||
assert(numElements != 1 && "splat of 1 element should already be detected");
|
||||
|
||||
// Handle boolean values directly as they are packed to 1-bit.
|
||||
size_t elementWidth = ty.getElementTypeBitWidth();
|
||||
if (elementWidth == 1)
|
||||
return getKeyForBoolData(ty, data, numElements);
|
||||
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
if (ty.getElementType().isBF16())
|
||||
elementWidth = 64;
|
||||
|
||||
// Non 1-bit dense elements are padded to 8-bits.
|
||||
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
|
||||
assert(((data.size() / storageSize) == numElements) &&
|
||||
"data does not hold expected number of elements");
|
||||
|
||||
// Create the initial hash value with just the first element.
|
||||
auto firstElt = data.take_front(storageSize);
|
||||
auto hashVal = llvm::hash_value(firstElt);
|
||||
|
||||
// Check to see if this storage represents a splat. If it doesn't then
|
||||
// combine the hash for the data starting with the first non splat element.
|
||||
for (size_t i = storageSize, e = data.size(); i != e; i += storageSize)
|
||||
if (memcmp(data.data(), &data[i], storageSize))
|
||||
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
|
||||
|
||||
// Otherwise, this is a splat so just return the hash of the first element.
|
||||
return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
|
||||
}
|
||||
|
||||
/// Construct a key with a set of boolean data.
|
||||
static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
|
||||
size_t numElements) {
|
||||
ArrayRef<char> splatData = data;
|
||||
bool splatValue = splatData.front() & 1;
|
||||
|
||||
// Helper functor to generate a KeyTy for a boolean splat value.
|
||||
auto generateSplatKey = [=] {
|
||||
return KeyTy(ty, data.take_front(1),
|
||||
llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)),
|
||||
/*isSplat=*/true);
|
||||
};
|
||||
|
||||
// Handle the case where the potential splat value is 1 and the number of
|
||||
// elements is non 8-bit aligned.
|
||||
size_t numOddElements = numElements % CHAR_BIT;
|
||||
if (splatValue && numOddElements != 0) {
|
||||
// Check that all bits are set in the last value.
|
||||
char lastElt = splatData.back();
|
||||
if (lastElt != llvm::maskTrailingOnes<char>(numOddElements))
|
||||
return KeyTy(ty, data, llvm::hash_value(data));
|
||||
|
||||
// If this is the only element, the data is known to be a splat.
|
||||
if (splatData.size() == 1)
|
||||
return generateSplatKey();
|
||||
splatData = splatData.drop_back();
|
||||
}
|
||||
|
||||
// Check that the data buffer corresponds to a splat.
|
||||
return llvm::is_splat(splatData) ? generateSplatKey()
|
||||
: KeyTy(ty, data, llvm::hash_value(data));
|
||||
}
|
||||
|
||||
/// Hash the key for the storage.
|
||||
static llvm::hash_code hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(key.type, key.hashCode);
|
||||
}
|
||||
|
||||
/// Construct a new storage instance.
|
||||
static DenseElementsAttributeStorage *
|
||||
construct(AttributeStorageAllocator &allocator, KeyTy key) {
|
||||
// If the data buffer is non-empty, we copy it into the allocator.
|
||||
ArrayRef<char> data = allocator.copyInto(key.second);
|
||||
ArrayRef<char> data = allocator.copyInto(key.data);
|
||||
|
||||
// If this is a boolean splat, make sure only the first bit is used.
|
||||
if (key.isSplat && key.type.getElementTypeBitWidth() == 1)
|
||||
const_cast<char &>(data.front()) &= 1;
|
||||
|
||||
return new (allocator.allocate<DenseElementsAttributeStorage>())
|
||||
DenseElementsAttributeStorage(key.first, data);
|
||||
DenseElementsAttributeStorage(key.type, data, key.isSplat);
|
||||
}
|
||||
|
||||
ArrayRef<char> data;
|
||||
|
|
|
@ -488,7 +488,7 @@ SplatElementsAttr SplatElementsAttr::mapValues(
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RawElementIterator
|
||||
// DenseElementAttr Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static size_t getDenseElementBitwidth(Type eltType) {
|
||||
|
@ -547,6 +547,14 @@ static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
|
|||
return result;
|
||||
}
|
||||
|
||||
/// Returns if 'values' corresponds to a splat, i.e. one element, or has the
|
||||
/// same element count as 'type'.
|
||||
template <typename Values>
|
||||
static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
|
||||
return (values.size() == 1) ||
|
||||
(type.getNumElements() == static_cast<int64_t>(values.size()));
|
||||
}
|
||||
|
||||
/// Constructs a new iterator.
|
||||
DenseElementsAttr::RawElementIterator::RawElementIterator(
|
||||
DenseElementsAttr attr, size_t index)
|
||||
|
@ -567,30 +575,30 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
|||
ArrayRef<Attribute> values) {
|
||||
assert(type.getElementType().isIntOrFloat() &&
|
||||
"expected int or float element type");
|
||||
assert(static_cast<int64_t>(values.size()) == type.getNumElements() &&
|
||||
"expected 'values' to contain the same number of elements as 'type'");
|
||||
assert(hasSameElementsOrSplat(type, values));
|
||||
|
||||
auto eltType = type.getElementType();
|
||||
size_t bitWidth = getDenseElementBitwidth(eltType);
|
||||
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
|
||||
|
||||
// Compress the attribute values into a character buffer.
|
||||
SmallVector<char, 8> data(storageBitWidth * type.getNumElements());
|
||||
SmallVector<char, 8> data((storageBitWidth / CHAR_BIT) * values.size());
|
||||
APInt intVal;
|
||||
for (unsigned i = 0, e = values.size(); i < e; ++i) {
|
||||
assert(eltType == values[i].getType() &&
|
||||
"expected attribute value to have element type");
|
||||
|
||||
switch (eltType.getKind()) {
|
||||
case StandardTypes::BF16:
|
||||
case StandardTypes::F16:
|
||||
case StandardTypes::F32:
|
||||
case StandardTypes::F64:
|
||||
assert(eltType == values[i].cast<FloatAttr>().getType() &&
|
||||
"expected attribute value to have element type");
|
||||
intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
|
||||
break;
|
||||
case StandardTypes::Integer:
|
||||
assert(eltType == values[i].cast<IntegerAttr>().getType() &&
|
||||
"expected attribute value to have element type");
|
||||
intVal = values[i].cast<IntegerAttr>().getValue();
|
||||
intVal = values[i].isa<BoolAttr>()
|
||||
? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
|
||||
: values[i].cast<IntegerAttr>().getValue();
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("unexpected element type");
|
||||
|
@ -599,58 +607,83 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
|||
"expected value to have same bitwidth as element type");
|
||||
writeBits(data.data(), i * storageBitWidth, intVal);
|
||||
}
|
||||
return getRaw(type, data);
|
||||
return getRaw(type, data, /*isSplat=*/(values.size() == 1));
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<bool> values) {
|
||||
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
|
||||
assert(hasSameElementsOrSplat(type, values));
|
||||
assert(type.getElementType().isInteger(1));
|
||||
|
||||
std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
|
||||
for (int i = 0, e = values.size(); i != e; ++i)
|
||||
writeBits(buff.data(), i, llvm::APInt(1, values[i]));
|
||||
return getRaw(type, buff);
|
||||
setBit(buff.data(), i, values[i]);
|
||||
return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
|
||||
}
|
||||
|
||||
/// Constructs a dense integer elements attribute from an array of APInt
|
||||
/// values. Each APInt value is expected to have the same bitwidth as the
|
||||
/// element type of 'type'.
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<APInt> values) {
|
||||
assert(type.getElementType().isa<IntegerType>());
|
||||
return getRaw(type, values);
|
||||
}
|
||||
|
||||
// Constructs a dense float elements attribute from an array of APFloat
|
||||
// values. Each APFloat value is expected to have the same bitwidth as the
|
||||
// element type of 'type'.
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<APFloat> values) {
|
||||
assert(type.getElementType().isa<FloatType>());
|
||||
|
||||
// Convert the APFloat values to APInt and create a dense elements attribute.
|
||||
std::vector<APInt> intValues(values.size());
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i)
|
||||
intValues[i] = values[i].bitcastToAPInt();
|
||||
return getRaw(type, intValues);
|
||||
}
|
||||
|
||||
// Constructs a dense elements attribute from an array of raw APInt values.
|
||||
// Each APInt value is expected to have the same bitwidth as the element type
|
||||
// of 'type'.
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<APInt> values) {
|
||||
assert(static_cast<int64_t>(values.size()) == type.getNumElements() &&
|
||||
"expected 'values' to contain the same number of elements as 'type'");
|
||||
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
|
||||
ArrayRef<APInt> values) {
|
||||
assert(hasSameElementsOrSplat(type, values));
|
||||
|
||||
size_t bitWidth = getDenseElementBitwidth(type.getElementType());
|
||||
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
|
||||
std::vector<char> elementData(bitWidth * values.size());
|
||||
std::vector<char> elementData((storageBitWidth / CHAR_BIT) * values.size());
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
||||
assert(values[i].getBitWidth() == bitWidth);
|
||||
writeBits(elementData.data(), i * storageBitWidth, values[i]);
|
||||
}
|
||||
return getRaw(type, elementData);
|
||||
return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
|
||||
ArrayRef<char> data) {
|
||||
assert((static_cast<uint64_t>(type.getSizeInBits()) <=
|
||||
data.size() * APInt::APINT_WORD_SIZE) &&
|
||||
"Input data bit size should be larger than that type requires");
|
||||
ArrayRef<char> data, bool isSplat) {
|
||||
assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
|
||||
"type must be ranked tensor or vector");
|
||||
assert(type.hasStaticShape() && "type must have static shape");
|
||||
return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
|
||||
data);
|
||||
data, isSplat);
|
||||
}
|
||||
|
||||
/// Overload of the 'getRaw' method that asserts that the given type is of
|
||||
/// integer type.
|
||||
/// integer type. This method is used to verify type invariants that the
|
||||
/// templatized 'get' method cannot.
|
||||
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
|
||||
ArrayRef<char> data,
|
||||
int64_t dataEltSize,
|
||||
bool isInt) {
|
||||
assert(isInt ? type.getElementType().isa<IntegerType>()
|
||||
: type.getElementType().isa<FloatType>());
|
||||
return getRaw(type, data);
|
||||
assert((dataEltSize * CHAR_BIT) == type.getElementTypeBitWidth());
|
||||
|
||||
int64_t numElements = data.size() / dataEltSize;
|
||||
assert(numElements == 1 || numElements == type.getNumElements());
|
||||
return getRaw(type, data, /*isSplat=*/numElements == 1);
|
||||
}
|
||||
|
||||
/// Return the raw storage data held by this attribute.
|
||||
|
@ -658,8 +691,29 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
|
|||
return static_cast<ImplType *>(impl)->data;
|
||||
}
|
||||
|
||||
/// Returns the number of elements held by this attribute.
|
||||
size_t DenseElementsAttr::size() const { return getType().getNumElements(); }
|
||||
/// Returns the number of raw elements held by this attribute.
|
||||
size_t DenseElementsAttr::rawSize() const {
|
||||
return isSplat() ? 1 : getType().getNumElements();
|
||||
}
|
||||
|
||||
/// Returns if this attribute corresponds to a splat, i.e. if all element
|
||||
/// values are the same.
|
||||
bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
|
||||
|
||||
/// If this attribute corresponds to a splat, then get the splat value.
|
||||
/// Otherwise, return null.
|
||||
Attribute DenseElementsAttr::getSplatValue() const {
|
||||
if (!isSplat())
|
||||
return Attribute();
|
||||
|
||||
auto elementType = getType().getElementType();
|
||||
if (elementType.isa<IntegerType>())
|
||||
return IntegerAttr::get(elementType, *raw_begin());
|
||||
if (auto fType = elementType.dyn_cast<FloatType>())
|
||||
return FloatAttr::get(elementType,
|
||||
APFloat(fType.getFloatSemantics(), *raw_begin()));
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
/// element, then a null attribute is returned.
|
||||
|
@ -677,6 +731,10 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
if (shape[i] <= static_cast<int64_t>(index[i]))
|
||||
return Attribute();
|
||||
|
||||
// If this is a splat, return the splat value directly.
|
||||
if (isSplat())
|
||||
return getSplatValue();
|
||||
|
||||
// Reduce the provided multidimensional index into a 1D index.
|
||||
uint64_t valueIndex = 0;
|
||||
uint64_t dimMultiplier = 1;
|
||||
|
@ -688,9 +746,9 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
// Return the element stored at the 1D index.
|
||||
auto elementType = getType().getElementType();
|
||||
size_t bitWidth = getDenseElementBitwidth(elementType);
|
||||
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
|
||||
size_t storageWidth = getDenseElementStorageWidth(bitWidth);
|
||||
APInt rawValueData =
|
||||
readBits(getRawData().data(), valueIndex * storageBitWidth, bitWidth);
|
||||
readBits(getRawData().data(), valueIndex * storageWidth, bitWidth);
|
||||
|
||||
// Convert the raw value data to an attribute value.
|
||||
if (elementType.isa<IntegerType>())
|
||||
|
@ -739,7 +797,7 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
|
|||
"expected the same element type");
|
||||
assert(newType.getNumElements() == curType.getNumElements() &&
|
||||
"expected the same number of elements");
|
||||
return getRaw(newType, getRawData());
|
||||
return getRaw(newType, getRawData(), isSplat());
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::mapValues(
|
||||
|
@ -758,16 +816,8 @@ DenseElementsAttr DenseElementsAttr::mapValues(
|
|||
// DenseIntElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Constructs a dense integer elements attribute from an array of APInt
|
||||
/// values. Each APInt value is expected to have the same bitwidth as the
|
||||
/// element type of 'type'.
|
||||
DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
|
||||
ArrayRef<APInt> values) {
|
||||
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
|
||||
}
|
||||
|
||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
|
||||
values.reserve(size());
|
||||
values.reserve(rawSize());
|
||||
values.assign(raw_begin(), raw_end());
|
||||
}
|
||||
|
||||
|
@ -808,7 +858,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
|
|||
auto newArrayType =
|
||||
mappingHelper(mapping, *this, getType(), newElementType, elementData);
|
||||
|
||||
return getRaw(newArrayType, elementData);
|
||||
return getRaw(newArrayType, elementData, isSplat());
|
||||
}
|
||||
|
||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||
|
@ -827,20 +877,8 @@ DenseFPElementsAttr::ElementIterator::ElementIterator(
|
|||
std::function<APFloat(const APInt &)>>(
|
||||
it, [&](const APInt &val) { return APFloat(smt, val); }) {}
|
||||
|
||||
// Constructs a dense float elements attribute from an array of APFloat
|
||||
// values. Each APFloat value is expected to have the same bitwidth as the
|
||||
// element type of 'type'.
|
||||
DenseFPElementsAttr DenseFPElementsAttr::get(ShapedType type,
|
||||
ArrayRef<APFloat> values) {
|
||||
// Convert the APFloat values to APInt and create a dense elements attribute.
|
||||
std::vector<APInt> intValues(values.size());
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i)
|
||||
intValues[i] = values[i].bitcastToAPInt();
|
||||
return DenseElementsAttr::get(type, intValues).cast<DenseFPElementsAttr>();
|
||||
}
|
||||
|
||||
void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
|
||||
values.reserve(size());
|
||||
values.reserve(rawSize());
|
||||
values.assign(begin(), end());
|
||||
}
|
||||
|
||||
|
@ -851,7 +889,7 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
|
|||
auto newArrayType =
|
||||
mappingHelper(mapping, *this, getType(), newElementType, elementData);
|
||||
|
||||
return getRaw(newArrayType, elementData);
|
||||
return getRaw(newArrayType, elementData, isSplat());
|
||||
}
|
||||
|
||||
/// Iterator access to the float element values.
|
||||
|
@ -907,7 +945,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SparseElementsAttr SparseElementsAttr::get(ShapedType type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr indices,
|
||||
DenseElementsAttr values) {
|
||||
assert(indices.getType().getElementType().isInteger(64) &&
|
||||
"expected sparse indices to be 64-bit integer values");
|
||||
|
@ -915,7 +953,7 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
|
|||
"type must be ranked tensor or vector");
|
||||
assert(type.hasStaticShape() && "type must have static shape");
|
||||
return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
|
||||
indices, values);
|
||||
indices.cast<DenseIntElementsAttr>(), values);
|
||||
}
|
||||
|
||||
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
|
||||
|
@ -935,12 +973,34 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
if (rank != index.size())
|
||||
return Attribute();
|
||||
|
||||
/// Return an attribute corresponding to '0' for the element type.
|
||||
auto getZeroAttr = [=]() -> Attribute {
|
||||
auto eltType = type.getElementType();
|
||||
if (eltType.isa<FloatType>())
|
||||
return FloatAttr::get(eltType, 0);
|
||||
assert(eltType.isa<IntegerType>() && "unexpected element type");
|
||||
return IntegerAttr::get(eltType, 0);
|
||||
};
|
||||
|
||||
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
|
||||
// as a 1-D index array.
|
||||
auto sparseIndices = getIndices();
|
||||
const uint64_t *sparseIndexValues =
|
||||
reinterpret_cast<const uint64_t *>(sparseIndices.getRawData().data());
|
||||
|
||||
// Check to see if the indices are a splat.
|
||||
if (sparseIndices.isSplat()) {
|
||||
// If the index is also not a splat of the index value, we know that the
|
||||
// value is zero.
|
||||
auto splatIndex = *sparseIndexValues;
|
||||
if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
|
||||
return getZeroAttr();
|
||||
|
||||
// If the indices are a splat, we also expect the values to be a splat.
|
||||
assert(getValues().isSplat() && "expected splat values");
|
||||
return getValues().getSplatValue();
|
||||
}
|
||||
|
||||
// Build a mapping between known indices and the offset of the stored element.
|
||||
llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
|
||||
auto numSparseIndices = sparseIndices.getType().getDimSize(0);
|
||||
|
@ -950,13 +1010,8 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
// Look for the provided index key within the mapped indices. If the provided
|
||||
// index is not found, then return a zero attribute.
|
||||
auto it = mappedIndices.find(index);
|
||||
if (it == mappedIndices.end()) {
|
||||
auto eltType = type.getElementType();
|
||||
if (eltType.isa<FloatType>())
|
||||
return FloatAttr::get(eltType, 0);
|
||||
assert(eltType.isa<IntegerType>() && "unexpected element type");
|
||||
return IntegerAttr::get(eltType, 0);
|
||||
}
|
||||
if (it == mappedIndices.end())
|
||||
return getZeroAttr();
|
||||
|
||||
// Otherwise, return the held sparse value element.
|
||||
return getValues().getValue(it->second);
|
||||
|
|
|
@ -97,6 +97,13 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
|
|||
}
|
||||
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
|
||||
auto *vectorType = cast<llvm::VectorType>(llvmType);
|
||||
if (denseAttr.isSplat()) {
|
||||
auto *child = getLLVMConstant(vectorType->getElementType(),
|
||||
denseAttr.getSplatValue(), loc);
|
||||
return llvm::ConstantVector::getSplat(vectorType->getNumElements(),
|
||||
child);
|
||||
}
|
||||
|
||||
SmallVector<llvm::Constant *, 8> constants;
|
||||
uint64_t numElements = vectorType->getNumElements();
|
||||
constants.reserve(numElements);
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
template <typename EltTy>
|
||||
static void testSplat(Type eltType, const EltTy &splatElt) {
|
||||
VectorType shape = VectorType::get({2, 1}, eltType);
|
||||
|
||||
// Check that the generated splat is the same for 1 element and N elements.
|
||||
DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
|
||||
EXPECT_TRUE(splat.isSplat());
|
||||
|
||||
auto detectedSplat =
|
||||
DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
|
||||
EXPECT_EQ(detectedSplat, splat);
|
||||
}
|
||||
|
||||
namespace {
|
||||
TEST(DenseSplatTest, BoolSplat) {
|
||||
MLIRContext context;
|
||||
IntegerType boolTy = IntegerType::get(1, &context);
|
||||
VectorType shape = VectorType::get({2, 2}, boolTy);
|
||||
|
||||
// Check that splat is automatically detected for boolean values.
|
||||
/// True.
|
||||
DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
|
||||
EXPECT_TRUE(trueSplat.isSplat());
|
||||
/// False.
|
||||
DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
|
||||
EXPECT_TRUE(falseSplat.isSplat());
|
||||
EXPECT_NE(falseSplat, trueSplat);
|
||||
|
||||
/// Detect and handle splat within 8 elements (bool values are bit-packed).
|
||||
/// True.
|
||||
auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
|
||||
EXPECT_EQ(detectedSplat, trueSplat);
|
||||
/// False.
|
||||
detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
|
||||
EXPECT_EQ(detectedSplat, falseSplat);
|
||||
}
|
||||
|
||||
TEST(DenseSplatTest, LargeBoolSplat) {
|
||||
constexpr size_t boolCount = 56;
|
||||
|
||||
MLIRContext context;
|
||||
IntegerType boolTy = IntegerType::get(1, &context);
|
||||
VectorType shape = VectorType::get({boolCount}, boolTy);
|
||||
|
||||
// Check that splat is automatically detected for boolean values.
|
||||
/// True.
|
||||
DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
|
||||
DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
|
||||
EXPECT_TRUE(trueSplat.isSplat());
|
||||
EXPECT_TRUE(falseSplat.isSplat());
|
||||
|
||||
/// Detect that the large boolean arrays are properly splatted.
|
||||
/// True.
|
||||
SmallVector<bool, 64> trueValues(boolCount, true);
|
||||
auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
|
||||
EXPECT_EQ(detectedSplat, trueSplat);
|
||||
/// False.
|
||||
SmallVector<bool, 64> falseValues(boolCount, false);
|
||||
detectedSplat = DenseElementsAttr::get(shape, falseValues);
|
||||
EXPECT_EQ(detectedSplat, falseSplat);
|
||||
}
|
||||
|
||||
TEST(DenseSplatTest, OddIntSplat) {
|
||||
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
|
||||
MLIRContext context;
|
||||
constexpr size_t intWidth = 19;
|
||||
IntegerType intTy = IntegerType::get(intWidth, &context);
|
||||
APInt value(intWidth, 10);
|
||||
|
||||
testSplat(intTy, value);
|
||||
}
|
||||
|
||||
TEST(DenseSplatTest, Int32Splat) {
|
||||
MLIRContext context;
|
||||
IntegerType intTy = IntegerType::get(32, &context);
|
||||
int value = 64;
|
||||
|
||||
testSplat(intTy, value);
|
||||
}
|
||||
|
||||
TEST(DenseSplatTest, IntAttrSplat) {
|
||||
MLIRContext context;
|
||||
IntegerType intTy = IntegerType::get(85, &context);
|
||||
Attribute value = IntegerAttr::get(intTy, 109);
|
||||
|
||||
testSplat(intTy, value);
|
||||
}
|
||||
|
||||
TEST(DenseSplatTest, F32Splat) {
|
||||
MLIRContext context;
|
||||
FloatType floatTy = FloatType::getF32(&context);
|
||||
float value = 10.0;
|
||||
|
||||
testSplat(floatTy, value);
|
||||
}
|
||||
|
||||
TEST(DenseSplatTest, F64Splat) {
|
||||
MLIRContext context;
|
||||
FloatType floatTy = FloatType::getF64(&context);
|
||||
double value = 10.0;
|
||||
|
||||
testSplat(floatTy, APFloat(value));
|
||||
}
|
||||
|
||||
TEST(DenseSplatTest, FloatAttrSplat) {
|
||||
MLIRContext context;
|
||||
FloatType floatTy = FloatType::getBF16(&context);
|
||||
Attribute value = FloatAttr::get(floatTy, 10.0);
|
||||
|
||||
testSplat(floatTy, value);
|
||||
}
|
||||
} // end namespace
|
|
@ -1,4 +1,5 @@
|
|||
add_mlir_unittest(MLIRIRTests
|
||||
AttributeTest.cpp
|
||||
DialectTest.cpp
|
||||
OperationSupportTest.cpp
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue