[mlir] Fix DenseElementsAttr treatment of bool splat of "true"

Boolean splats currently can't roundtrip via the "raw" DenseElementsAttr
API. This is because internally we treat true splats in some cases as "1"(one bit set)
and in other cases as "0xFF"(all bits set). This commit cleans up this handling to
consistently use 0xFF (all bits set) as the value for a splat of true.

Differential Revision: https://reviews.llvm.org/D133743
This commit is contained in:
River Riddle 2022-09-12 18:42:08 -07:00
parent 34300ee369
commit 9e0900cbf1
2 changed files with 37 additions and 30 deletions

View File

@ -76,22 +76,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
/// Compare this storage instance with the provided key.
bool operator==(const KeyTy &key) const {
if (key.type != type)
return false;
// For boolean splats we need to explicitly check that the first bit is the
// same. Boolean values are packed at the bit level, and even though a splat
// is detected the rest of the bits in the first byte may differ from the
// splat value.
if (key.type.getElementType().isInteger(1)) {
if (key.isSplat != isSplat)
return false;
if (isSplat)
return (key.data.front() & 1) == data.front();
}
// Otherwise, we can default to just checking the data.
return key.data == data;
return key.type == type && key.data == data;
}
/// Construct a key from a shaped type, raw data buffer, and a flag that
@ -105,8 +90,12 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
// If the data is already known to be a splat, the key hash value is
// directly the data buffer.
if (isKnownSplat)
bool isBoolData = ty.getElementType().isInteger(1);
if (isKnownSplat) {
if (isBoolData)
return getKeyForSplatBoolData(ty, data[0] != 0);
return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
}
// Otherwise, we need to check if the data corresponds to a splat or not.
@ -115,7 +104,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
assert(numElements != 1 && "splat of 1 element should already be detected");
// Handle boolean values directly as they are packed to 1-bit.
if (ty.getElementType().isInteger(1) == 1)
if (isBoolData)
return getKeyForBoolData(ty, data, numElements);
size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
@ -144,12 +133,9 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
ArrayRef<char> splatData = data;
bool splatValue = splatData.front() & 1;
// Helper functor to generate a KeyTy for a boolean splat value.
auto generateSplatKey = [=] {
return KeyTy(ty, data.take_front(1),
llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)),
/*isSplat=*/true);
};
// Check the simple case where the data matches the known splat value.
if (splatData == ArrayRef<char>(splatValue ? kSplatTrue : kSplatFalse))
return getKeyForSplatBoolData(ty, splatValue);
// Handle the case where the potential splat value is 1 and the number of
// elements is non 8-bit aligned.
@ -162,17 +148,24 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
// If this is the only element, the data is known to be a splat.
if (splatData.size() == 1)
return generateSplatKey();
return getKeyForSplatBoolData(ty, splatValue);
splatData = splatData.drop_back();
}
// Check that the data buffer corresponds to a splat of the proper mask.
char mask = splatValue ? ~0 : 0;
return llvm::all_of(splatData, [mask](char c) { return c == mask; })
? generateSplatKey()
? getKeyForSplatBoolData(ty, splatValue)
: KeyTy(ty, data, llvm::hash_value(data));
}
/// Return a key to use for a boolean splat of the given value.
static KeyTy getKeyForSplatBoolData(ShapedType type, bool splatValue) {
const char &splatData = splatValue ? kSplatTrue : kSplatFalse;
return KeyTy(type, splatData, llvm::hash_value(splatData),
/*isSplat=*/true);
}
/// Hash the key for the storage.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine(key.type, key.hashCode);
@ -188,10 +181,6 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
char *rawData = reinterpret_cast<char *>(
allocator.allocate(data.size(), alignof(uint64_t)));
std::memcpy(rawData, data.data(), data.size());
// If this is a boolean splat, make sure only the first bit is used.
if (key.isSplat && key.type.getElementType().isInteger(1))
rawData[0] &= 1;
copy = ArrayRef<char>(rawData, data.size());
}
@ -200,6 +189,10 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
}
ArrayRef<char> data;
/// The values used to denote a boolean splat value.
static constexpr char kSplatTrue = ~0;
static constexpr char kSplatFalse = 0;
};
/// An attribute representing a reference to a dense vector or tensor object

View File

@ -58,6 +58,20 @@ TEST(DenseSplatTest, BoolSplat) {
detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
EXPECT_EQ(detectedSplat, falseSplat);
}
TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
MLIRContext context;
IntegerType boolTy = IntegerType::get(&context, 1);
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
// Check that splat booleans properly round trip via the raw API.
DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
EXPECT_TRUE(trueSplat.isSplat());
DenseElementsAttr trueSplatFromRaw =
DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData());
EXPECT_TRUE(trueSplatFromRaw.isSplat());
EXPECT_EQ(trueSplat, trueSplatFromRaw);
}
TEST(DenseSplatTest, LargeBoolSplat) {
constexpr int64_t boolCount = 56;