[DenseElementAttr] Simplify the public API for creating these.

Instead of requiring the client to compute the "isSplat" bit,
compute it internally.  This makes the logic more consistent
and defines away a lot of "elements.size()==1" in the clients.

This addresses Issue #55185

Differential Revision: https://reviews.llvm.org/D125447
This commit is contained in:
Chris Lattner 2022-05-12 05:32:16 +01:00
parent 6822ed035f
commit f21896f2c6
6 changed files with 71 additions and 55 deletions

View File

@ -193,13 +193,8 @@ public:
/// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to /// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to
/// the linear order of the shape type from MSB to LSB, padded to on the /// the linear order of the shape type from MSB to LSB, padded to on the
/// right. /// right.
///
/// If `isSplatBuffer` is true, then the raw buffer should contain a
/// single element (or for the case of 1-bit, a single byte of 0 or 255),
/// which will be used to construct a splat.
static DenseElementsAttr getFromRawBuffer(ShapedType type, static DenseElementsAttr getFromRawBuffer(ShapedType type,
ArrayRef<char> rawBuffer, ArrayRef<char> rawBuffer);
bool isSplatBuffer);
/// Returns true if the given buffer is a valid raw buffer for the given type. /// Returns true if the given buffer is a valid raw buffer for the given type.
/// `detectedSplat` is set if the buffer is valid and represents a splat /// `detectedSplat` is set if the buffer is valid and represents a splat

View File

@ -236,19 +236,27 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
/// values. Each APFloat value is expected to have the same bitwidth as the /// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static /// element type of 'type'. 'type' must be a vector or tensor with static
/// shape. /// shape.
///
/// If the `values` array only has a single element, then this constructs
/// splat of that value.
static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
ArrayRef<APFloat> values, bool isSplat); ArrayRef<APFloat> values);
/// Constructs a dense elements attribute from an array of raw APInt values. /// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element /// Each APInt value is expected to have the same bitwidth as the element
/// type of 'type'. 'type' must be a vector or tensor with static shape. /// type of 'type'. 'type' must be a vector or tensor with static shape.
///
/// If the `values` array only has a single element, then this constructs
/// splat of that value.
static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
ArrayRef<APInt> values, bool isSplat); ArrayRef<APInt> values);
/// Get or create a new dense elements attribute instance with the given raw /// Get or create a new dense elements attribute instance with the given raw
/// data buffer. 'type' must be a vector or tensor with static shape. /// data buffer. 'type' must be a vector or tensor with static shape.
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data, ///
bool isSplat); /// If the `values` array only has a single element, then this constructs
/// splat of that value.
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data);
/// Overload of the raw 'get' method that asserts that the given type is of /// Overload of the raw 'get' method that asserts that the given type is of
/// complex type. This method is used to verify type invariants that the /// complex type. This method is used to verify type invariants that the
@ -264,7 +272,6 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
ArrayRef<char> data, ArrayRef<char> data,
int64_t dataEltSize, bool isInt, int64_t dataEltSize, bool isInt,
bool isSigned); bool isSigned);
public: public:
}]; }];
let genAccessors = 0; let genAccessors = 0;

View File

@ -348,11 +348,9 @@ MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
rawBufferSize); rawBufferSize);
bool isSplat = false; bool isSplat = false;
if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp, if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
isSplat)) { isSplat))
return mlirAttributeGetNull(); return mlirAttributeGetNull();
} return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp,
isSplat));
} }
MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,

View File

@ -837,7 +837,7 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
if (!attr || !attr.isSplat()) if (!attr || !attr.isSplat())
return failure(); return failure();
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
reshapeOp.getResultType(), attr.getRawData(), true); reshapeOp.getResultType(), attr.getRawData());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr); rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
return success(); return success();
} }

View File

@ -713,8 +713,12 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
"expected value to have same bitwidth as element type"); "expected value to have same bitwidth as element type");
writeBits(data.data(), i * storageBitWidth, intVal); writeBits(data.data(), i * storageBitWidth, intVal);
} }
return DenseIntOrFPElementsAttr::getRaw(type, data,
/*isSplat=*/(values.size() == 1)); // Handle the special encoding of splat of bool.
if (values.size() == 1 && values[0].getType().isInteger(1))
data[0] = data[0] ? -1 : 0;
return DenseIntOrFPElementsAttr::getRaw(type, data);
} }
DenseElementsAttr DenseElementsAttr::get(ShapedType type, DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@ -723,10 +727,22 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
assert(type.getElementType().isInteger(1)); assert(type.getElementType().isInteger(1));
std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
for (int i = 0, e = values.size(); i != e; ++i)
setBit(buff.data(), i, values[i]); if (!values.empty()) {
return DenseIntOrFPElementsAttr::getRaw(type, buff, bool isSplat = true;
/*isSplat=*/(values.size() == 1)); bool firstValue = values[0];
for (int i = 0, e = values.size(); i != e; ++i) {
isSplat &= values[i] == firstValue;
setBit(buff.data(), i, values[i]);
}
if (isSplat) { // special encoding for splat.
buff.resize(1);
buff[0] = values[0] ? -1 : 0;
}
}
return DenseIntOrFPElementsAttr::getRaw(type, buff);
} }
DenseElementsAttr DenseElementsAttr::get(ShapedType type, DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@ -743,8 +759,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
assert(type.getElementType().isIntOrIndex()); assert(type.getElementType().isIntOrIndex());
assert(hasSameElementsOrSplat(type, values)); assert(hasSameElementsOrSplat(type, values));
size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
/*isSplat=*/(values.size() == 1));
} }
DenseElementsAttr DenseElementsAttr::get(ShapedType type, DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<std::complex<APInt>> values) { ArrayRef<std::complex<APInt>> values) {
@ -754,8 +769,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
values.size() * 2); values.size() * 2);
return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
/*isSplat=*/(values.size() == 1));
} }
// Constructs a dense float elements attribute from an array of APFloat // Constructs a dense float elements attribute from an array of APFloat
@ -766,8 +780,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
assert(type.getElementType().isa<FloatType>()); assert(type.getElementType().isa<FloatType>());
assert(hasSameElementsOrSplat(type, values)); assert(hasSameElementsOrSplat(type, values));
size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
/*isSplat=*/(values.size() == 1));
} }
DenseElementsAttr DenseElementsAttr
DenseElementsAttr::get(ShapedType type, DenseElementsAttr::get(ShapedType type,
@ -778,17 +791,15 @@ DenseElementsAttr::get(ShapedType type,
ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
values.size() * 2); values.size() * 2);
size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
/*isSplat=*/(values.size() == 1));
} }
/// Construct a dense elements attribute from a raw buffer representing the /// Construct a dense elements attribute from a raw buffer representing the
/// data for this attribute. Users should generally not use this methods as /// data for this attribute. Users should generally not use this methods as
/// the expected buffer format may not be a form the user expects. /// the expected buffer format may not be a form the user expects.
DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, DenseElementsAttr
ArrayRef<char> rawBuffer, DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
bool isSplatBuffer) { return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
} }
/// Returns true if the given buffer is a valid raw buffer for the given type. /// Returns true if the given buffer is a valid raw buffer for the given type.
@ -964,7 +975,7 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
"expected the same element type"); "expected the same element type");
assert(newType.getNumElements() == curType.getNumElements() && assert(newType.getNumElements() == curType.getNumElements() &&
"expected the same number of elements"); "expected the same number of elements");
return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
} }
DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
@ -976,7 +987,7 @@ DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
assert(newType.getElementType() == curType.getElementType() && assert(newType.getElementType() == curType.getElementType() &&
"expected the same element type"); "expected the same element type");
return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), true); return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
} }
/// Return a new DenseElementsAttr that has the same data as the current /// Return a new DenseElementsAttr that has the same data as the current
@ -993,7 +1004,7 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
getDenseElementBitWidth(curElType) && getDenseElementBitWidth(curElType) &&
"expected element types with the same bitwidth"); "expected element types with the same bitwidth");
return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
getRawData(), isSplat()); getRawData());
} }
DenseElementsAttr DenseElementsAttr
@ -1027,13 +1038,18 @@ int64_t DenseElementsAttr::getNumElements() const {
template <typename APRangeT> template <typename APRangeT>
static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
APRangeT &&values) { APRangeT &&values) {
data.resize(llvm::divideCeil(storageWidth * llvm::size(values), CHAR_BIT)); size_t numValues = llvm::size(values);
data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT));
size_t offset = 0; size_t offset = 0;
for (auto it = values.begin(), e = values.end(); it != e; for (auto it = values.begin(), e = values.end(); it != e;
++it, offset += storageWidth) { ++it, offset += storageWidth) {
assert((*it).getBitWidth() <= storageWidth); assert((*it).getBitWidth() <= storageWidth);
writeBits(data.data(), offset, *it); writeBits(data.data(), offset, *it);
} }
// Handle the special encoding of splat of a boolean.
if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
data[0] = data[0] ? -1 : 0;
} }
/// Constructs a dense elements attribute from an array of raw APFloat values. /// Constructs a dense elements attribute from an array of raw APFloat values.
@ -1041,12 +1057,11 @@ static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
/// type of 'type'. 'type' must be a vector or tensor with static shape. /// type of 'type'. 'type' must be a vector or tensor with static shape.
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
size_t storageWidth, size_t storageWidth,
ArrayRef<APFloat> values, ArrayRef<APFloat> values) {
bool isSplat) {
std::vector<char> data; std::vector<char> data;
auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); return DenseIntOrFPElementsAttr::getRaw(type, data);
} }
/// Constructs a dense elements attribute from an array of raw APInt values. /// Constructs a dense elements attribute from an array of raw APInt values.
@ -1054,19 +1069,21 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
/// of 'type'. /// of 'type'.
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
size_t storageWidth, size_t storageWidth,
ArrayRef<APInt> values, ArrayRef<APInt> values) {
bool isSplat) {
std::vector<char> data; std::vector<char> data;
writeAPIntsToBuffer(storageWidth, data, values); writeAPIntsToBuffer(storageWidth, data, values);
return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); return DenseIntOrFPElementsAttr::getRaw(type, data);
} }
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
ArrayRef<char> data, ArrayRef<char> data) {
bool isSplat) {
assert((type.isa<RankedTensorType, VectorType>()) && assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector"); "type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape"); assert(type.hasStaticShape() && "type must have static shape");
bool isSplat = false;
bool isValid = isValidRawBuffer(type, data, isSplat);
assert(isValid);
(void)isValid;
return Base::get(type.getContext(), type, data, isSplat); return Base::get(type.getContext(), type, data, isSplat);
} }
@ -1084,7 +1101,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
int64_t numElements = data.size() / dataEltSize; int64_t numElements = data.size() / dataEltSize;
assert(numElements == 1 || numElements == type.getNumElements()); assert(numElements == 1 || numElements == type.getNumElements());
return getRaw(type, data, /*isSplat=*/numElements == 1); return getRaw(type, data);
} }
/// Overload of the 'getRaw' method that asserts that the given type is of /// Overload of the 'getRaw' method that asserts that the given type is of
@ -1099,7 +1116,8 @@ DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
int64_t numElements = data.size() / dataEltSize; int64_t numElements = data.size() / dataEltSize;
assert(numElements == 1 || numElements == type.getNumElements()); assert(numElements == 1 || numElements == type.getNumElements());
return getRaw(type, data, /*isSplat=*/numElements == 1); (void)numElements;
return getRaw(type, data);
} }
void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
@ -1212,7 +1230,7 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
auto newArrayType = auto newArrayType =
mappingHelper(mapping, *this, getType(), newElementType, elementData); mappingHelper(mapping, *this, getType(), newElementType, elementData);
return getRaw(newArrayType, elementData, isSplat()); return getRaw(newArrayType, elementData);
} }
/// Method for supporting type inquiry through isa, cast and dyn_cast. /// Method for supporting type inquiry through isa, cast and dyn_cast.
@ -1230,8 +1248,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
llvm::SmallVector<char, 8> elementData; llvm::SmallVector<char, 8> elementData;
auto newArrayType = auto newArrayType =
mappingHelper(mapping, *this, getType(), newElementType, elementData); mappingHelper(mapping, *this, getType(), newElementType, elementData);
return getRaw(newArrayType, elementData);
return getRaw(newArrayType, elementData, isSplat());
} }
/// Method for supporting type inquiry through isa, cast and dyn_cast. /// Method for supporting type inquiry through isa, cast and dyn_cast.

View File

@ -717,11 +717,10 @@ DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc,
MutableArrayRef<char> convRawData(outDataVec); MutableArrayRef<char> convRawData(outDataVec);
DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
rawData, convRawData, type); rawData, convRawData, type);
return DenseElementsAttr::getFromRawBuffer(type, convRawData, return DenseElementsAttr::getFromRawBuffer(type, convRawData);
detectedSplat);
} }
return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat); return DenseElementsAttr::getFromRawBuffer(type, rawData);
} }
ParseResult TensorLiteralParser::parseElement() { ParseResult TensorLiteralParser::parseElement() {