From f21896f2c6dc6f4c2c3d0f192f7fefd178f5d5f7 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Thu, 12 May 2022 05:32:16 +0100 Subject: [PATCH] [DenseElementAttr] Simplify the public API for creating these. Instead of requiring the client to compute the "isSplat" bit, compute it internally. This makes the logic more consistent and defines away a lot of "elements.size()==1" in the clients. This addresses Issue #55185 Differential Revision: https://reviews.llvm.org/D125447 --- mlir/include/mlir/IR/BuiltinAttributes.h | 7 +- mlir/include/mlir/IR/BuiltinAttributes.td | 19 +++-- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 6 +- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +- mlir/lib/IR/BuiltinAttributes.cpp | 87 ++++++++++++++--------- mlir/lib/Parser/AttributeParser.cpp | 5 +- 6 files changed, 71 insertions(+), 55 deletions(-) diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index 4371a1cb088f..85f6d3f4e638 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -193,13 +193,8 @@ public: /// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to /// the linear order of the shape type from MSB to LSB, padded to on the /// right. - /// - /// If `isSplatBuffer` is true, then the raw buffer should contain a - /// single element (or for the case of 1-bit, a single byte of 0 or 255), - /// which will be used to construct a splat. static DenseElementsAttr getFromRawBuffer(ShapedType type, - ArrayRef rawBuffer, - bool isSplatBuffer); + ArrayRef rawBuffer); /// Returns true if the given buffer is a valid raw buffer for the given type. /// `detectedSplat` is set if the buffer is valid and represents a splat diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 1fbc7eb18e5d..19c8a07b94cd 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -236,19 +236,27 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< /// values. Each APFloat value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. + /// + /// If the `values` array only has a single element, then this constructs + /// splat of that value. static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, - ArrayRef values, bool isSplat); + ArrayRef values); /// Constructs a dense elements attribute from an array of raw APInt values. /// Each APInt value is expected to have the same bitwidth as the element /// type of 'type'. 'type' must be a vector or tensor with static shape. + /// + /// If the `values` array only has a single element, then this constructs + /// splat of that value. static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, - ArrayRef values, bool isSplat); + ArrayRef values); /// Get or create a new dense elements attribute instance with the given raw /// data buffer. 'type' must be a vector or tensor with static shape. - static DenseElementsAttr getRaw(ShapedType type, ArrayRef data, - bool isSplat); + /// + /// If the `values` array only has a single element, then this constructs + /// splat of that value. + static DenseElementsAttr getRaw(ShapedType type, ArrayRef data); /// Overload of the raw 'get' method that asserts that the given type is of /// complex type. This method is used to verify type invariants that the @@ -264,7 +272,6 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); - public: }]; let genAccessors = 0; @@ -308,7 +315,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr< let builders = [ AttrBuilderWithInferredContext<(ins "ShapedType":$type, "ArrayRef":$values), [{ - return $_get(type.getContext(), type, values, + return $_get(type.getContext(), type, values, /* isSplat */(values.size() == 1)); }]>, ]; diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index aa498b2c1e18..759b708952e2 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -348,11 +348,9 @@ MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, rawBufferSize); bool isSplat = false; if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp, - isSplat)) { + isSplat)) return mlirAttributeGetNull(); - } - return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp, - isSplat)); + return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp)); } MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index ee5acdc34dc0..932973a13c21 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -837,7 +837,7 @@ struct FoldReshapeWithConstant : OpRewritePattern { if (!attr || !attr.isSplat()) return failure(); DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( - reshapeOp.getResultType(), attr.getRawData(), true); + reshapeOp.getResultType(), attr.getRawData()); rewriter.replaceOpWithNewOp(reshapeOp, newAttr); return success(); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 0004fe90fe87..1ecdf183d9ec 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -713,8 +713,12 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, "expected value to have same bitwidth as element type"); writeBits(data.data(), i * storageBitWidth, intVal); } - return DenseIntOrFPElementsAttr::getRaw(type, data, - /*isSplat=*/(values.size() == 1)); + + // Handle the special encoding of splat of bool. + if (values.size() == 1 && values[0].getType().isInteger(1)) + data[0] = data[0] ? -1 : 0; + + return DenseIntOrFPElementsAttr::getRaw(type, data); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, @@ -723,10 +727,22 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, assert(type.getElementType().isInteger(1)); std::vector buff(llvm::divideCeil(values.size(), CHAR_BIT)); - for (int i = 0, e = values.size(); i != e; ++i) - setBit(buff.data(), i, values[i]); - return DenseIntOrFPElementsAttr::getRaw(type, buff, - /*isSplat=*/(values.size() == 1)); + + if (!values.empty()) { + bool isSplat = true; + bool firstValue = values[0]; + for (int i = 0, e = values.size(); i != e; ++i) { + isSplat &= values[i] == firstValue; + setBit(buff.data(), i, values[i]); + } + + if (isSplat) { // special encoding for splat. + buff.resize(1); + buff[0] = values[0] ? -1 : 0; + } + } + + return DenseIntOrFPElementsAttr::getRaw(type, buff); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, @@ -743,8 +759,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, assert(type.getElementType().isIntOrIndex()); assert(hasSameElementsOrSplat(type, values)); size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); - return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, - /*isSplat=*/(values.size() == 1)); + return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef> values) { @@ -754,8 +769,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; ArrayRef intVals(reinterpret_cast(values.data()), values.size() * 2); - return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, - /*isSplat=*/(values.size() == 1)); + return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals); } // Constructs a dense float elements attribute from an array of APFloat @@ -766,8 +780,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, assert(type.getElementType().isa()); assert(hasSameElementsOrSplat(type, values)); size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); - return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, - /*isSplat=*/(values.size() == 1)); + return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, @@ -778,17 +791,15 @@ DenseElementsAttr::get(ShapedType type, ArrayRef apVals(reinterpret_cast(values.data()), values.size() * 2); size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; - return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, - /*isSplat=*/(values.size() == 1)); + return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals); } /// Construct a dense elements attribute from a raw buffer representing the /// data for this attribute. Users should generally not use this methods as /// the expected buffer format may not be a form the user expects. -DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, - ArrayRef rawBuffer, - bool isSplatBuffer) { - return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); +DenseElementsAttr +DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef rawBuffer) { + return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); } /// Returns true if the given buffer is a valid raw buffer for the given type. @@ -964,7 +975,7 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { "expected the same element type"); assert(newType.getNumElements() == curType.getNumElements() && "expected the same number of elements"); - return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); + return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); } DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { @@ -976,7 +987,7 @@ DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { assert(newType.getElementType() == curType.getElementType() && "expected the same element type"); - return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), true); + return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); } /// Return a new DenseElementsAttr that has the same data as the current @@ -993,7 +1004,7 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"); return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), - getRawData(), isSplat()); + getRawData()); } DenseElementsAttr @@ -1027,13 +1038,18 @@ int64_t DenseElementsAttr::getNumElements() const { template static void writeAPIntsToBuffer(size_t storageWidth, std::vector &data, APRangeT &&values) { - data.resize(llvm::divideCeil(storageWidth * llvm::size(values), CHAR_BIT)); + size_t numValues = llvm::size(values); + data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT)); size_t offset = 0; for (auto it = values.begin(), e = values.end(); it != e; ++it, offset += storageWidth) { assert((*it).getBitWidth() <= storageWidth); writeBits(data.data(), offset, *it); } + + // Handle the special encoding of splat of a boolean. + if (numValues == 1 && (*values.begin()).getBitWidth() == 1) + data[0] = data[0] ? -1 : 0; } /// Constructs a dense elements attribute from an array of raw APFloat values. @@ -1041,12 +1057,11 @@ static void writeAPIntsToBuffer(size_t storageWidth, std::vector &data, /// type of 'type'. 'type' must be a vector or tensor with static shape. DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, size_t storageWidth, - ArrayRef values, - bool isSplat) { + ArrayRef values) { std::vector data; auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); - return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); + return DenseIntOrFPElementsAttr::getRaw(type, data); } /// Constructs a dense elements attribute from an array of raw APInt values. @@ -1054,19 +1069,21 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, /// of 'type'. DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, size_t storageWidth, - ArrayRef values, - bool isSplat) { + ArrayRef values) { std::vector data; writeAPIntsToBuffer(storageWidth, data, values); - return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); + return DenseIntOrFPElementsAttr::getRaw(type, data); } DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, - ArrayRef data, - bool isSplat) { + ArrayRef data) { assert((type.isa()) && "type must be ranked tensor or vector"); assert(type.hasStaticShape() && "type must have static shape"); + bool isSplat = false; + bool isValid = isValidRawBuffer(type, data, isSplat); + assert(isValid); + (void)isValid; return Base::get(type.getContext(), type, data, isSplat); } @@ -1084,7 +1101,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, int64_t numElements = data.size() / dataEltSize; assert(numElements == 1 || numElements == type.getNumElements()); - return getRaw(type, data, /*isSplat=*/numElements == 1); + return getRaw(type, data); } /// Overload of the 'getRaw' method that asserts that the given type is of @@ -1099,7 +1116,8 @@ DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, int64_t numElements = data.size() / dataEltSize; assert(numElements == 1 || numElements == type.getNumElements()); - return getRaw(type, data, /*isSplat=*/numElements == 1); + (void)numElements; + return getRaw(type, data); } void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( @@ -1212,7 +1230,7 @@ DenseElementsAttr DenseFPElementsAttr::mapValues( auto newArrayType = mappingHelper(mapping, *this, getType(), newElementType, elementData); - return getRaw(newArrayType, elementData, isSplat()); + return getRaw(newArrayType, elementData); } /// Method for supporting type inquiry through isa, cast and dyn_cast. @@ -1230,8 +1248,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues( llvm::SmallVector elementData; auto newArrayType = mappingHelper(mapping, *this, getType(), newElementType, elementData); - - return getRaw(newArrayType, elementData, isSplat()); + return getRaw(newArrayType, elementData); } /// Method for supporting type inquiry through isa, cast and dyn_cast. diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp index 0161a3c15120..3618962fa1d5 100644 --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -717,11 +717,10 @@ DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, MutableArrayRef convRawData(outDataVec); DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( rawData, convRawData, type); - return DenseElementsAttr::getFromRawBuffer(type, convRawData, - detectedSplat); + return DenseElementsAttr::getFromRawBuffer(type, convRawData); } - return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat); + return DenseElementsAttr::getFromRawBuffer(type, rawData); } ParseResult TensorLiteralParser::parseElement() {