Remove the ability to directly construct a DenseElementsAttr with a raw character buffer. This made assumptions about how DenseElementsAttr structured its internal storage, which may change in the future. To replace the existing use cases, a few utility methods have been added:

* 'get' methods that allow constructing from an ArrayRef of integer or floating point values.
* A 'reshape' method to allow for changing the shape without changing the underlying data.

PiperOrigin-RevId: 252067898
This commit is contained in:
River Riddle 2019-06-07 09:57:29 -07:00 committed by Mehdi Amini
parent 62facfaf42
commit 0cadec8ae6
6 changed files with 81 additions and 41 deletions

View File

@ -81,8 +81,7 @@ struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
// auto oldType = constantOp.getType();
auto newType = rewriter.getTensorType(
reshapeType.getShape(), valueAttr.getType().getElementType());
auto newAttr =
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
auto newAttr = valueAttr.reshape(newType);
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else if (auto valueAttr =

View File

@ -83,8 +83,7 @@ struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
// auto oldType = constantOp.getType();
auto newType = rewriter.getTensorType(
reshapeType.getShape(), valueAttr.getType().getElementType());
auto newAttr =
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
auto newAttr = valueAttr.reshape(newType);
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else if (auto valueAttr =

View File

@ -502,16 +502,32 @@ class DenseElementsAttr
public:
using Base::Base;
/// It assumes the elements in the input array have been truncated to the bits
/// width specified by the element type. 'type' must be a vector or tensor
/// with static shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<char> data);
/// Constructs a dense elements attribute from an array of element values.
/// Each element attribute value is expected to be an element of 'type'.
/// 'type' must be a vector or tensor with static shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
/// Constructs a dense integer elements attribute from an array of integer
/// or floating-point values. Each value is expected to be the same bitwidth
/// of the element type of 'type'. 'type' must be a vector or tensor with
/// static shape.
template <typename ShapeT, typename T>
static DenseElementsAttr get(ShapeT type, ArrayRef<T> values) {
static_assert(std::numeric_limits<T>::is_integer ||
llvm::is_one_of<T, float, double>::value,
"expected integer or floating point element type");
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
assert(type.getElementTypeBitWidth() == (sizeof(T) * CHAR_BIT));
const char *data = reinterpret_cast<const char *>(values.data());
return getRawIntOrFloat(type,
ArrayRef<char>(data, values.size() * sizeof(T)),
/*isInt=*/std::numeric_limits<T>::is_integer);
}
/// Overload of the above 'get' method that is specialized for boolean values.
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
/// Returns the number of elements held by this attribute.
size_t size() const;
@ -519,8 +535,14 @@ public:
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
/// Return the held element values as Attributes in 'values'.
void getValues(SmallVectorImpl<Attribute> &values) const;
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
DenseElementsAttr reshape(ShapedType newType);
/// Generates a new DenseElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This underlying type must be an DenseIntElementsAttr.
@ -600,6 +622,15 @@ protected:
return RawElementIterator(*this, size());
}
/// Get or create a new dense elements attribute instance with the given raw
/// data buffer. 'type' must be a vector or tensor with static shape.
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data);
/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type.
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
ArrayRef<char> data, bool isInt);
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'. 'type' must be a vector or tensor with static shape.
@ -624,11 +655,6 @@ public:
/// shape.
static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
/// Constructs a dense integer elements attribute from an array of integer
/// values. Each value is expected to be within the bitwidth of the element
/// type of 'type'. 'type' must be a vector or tensor with static shape.
static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
/// Generates a new DenseElementsAttr by mapping each value attribute, and
/// constructing the DenseElementsAttr given the new element type.
DenseElementsAttr

View File

@ -116,7 +116,6 @@ public:
FunctionAttr getFunctionAttr(Function *value);
FunctionAttr getFunctionAttr(StringRef value);
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<char> data);
ElementsAttr getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values);
ElementsAttr getDenseIntElementsAttr(ShapedType type,

View File

@ -532,7 +532,8 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const {
// DenseElementsAttr
//===----------------------------------------------------------------------===//
DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
ArrayRef<char> data) {
assert((static_cast<uint64_t>(type.getSizeInBits()) <=
data.size() * APInt::APINT_WORD_SIZE) &&
"Input data bit size should be larger than that type requires");
@ -543,6 +544,27 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
data);
}
/// Overload of the raw 'get' method that asserts that the given type is of
/// integer type.
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
ArrayRef<char> data,
bool isInt) {
assert(isInt ? type.getElementType().isa<IntegerType>()
: type.getElementType().isa<FloatType>());
return getRaw(type, data);
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<bool> values) {
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
assert(type.getElementType().isInteger(1));
std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
for (int i = 0, e = values.size(); i != e; ++i)
writeBits(buff.data(), i, llvm::APInt(1, values[i]));
return getRaw(type, buff);
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(type.getElementType().isIntOrFloat() &&
@ -579,7 +601,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
"expected value to have same bitwidth as element type");
writeBits(data.data(), i * storageBitWidth, intVal);
}
return get(type, data);
return getRaw(type, data);
}
/// Returns the number of elements held by this attribute.
@ -650,6 +672,22 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
llvm_unreachable("unexpected element type");
}
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
ShapedType curType = getType();
if (curType == newType)
return *this;
(void)curType;
assert(newType.getElementType() == curType.getElementType() &&
"expected the same element type");
assert(newType.getNumElements() == curType.getNumElements() &&
"expected the same number of elements");
return getRaw(newType, getRawData());
}
DenseElementsAttr DenseElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const {
@ -681,7 +719,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
assert(values[i].getBitWidth() == bitWidth);
writeBits(elementData.data(), i * storageBitWidth, values[i]);
}
return get(type, elementData);
return getRaw(type, elementData);
}
/// Writes value to the bit position `bitPos` in array `rawData`.
@ -728,22 +766,6 @@ DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
}
/// Constructs a dense integer elements attribute from an array of integer
/// values. Each value is expected to be within the bitwidth of the element
/// type of 'type'.
DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
ArrayRef<int64_t> values) {
auto eltType = type.getElementType();
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
// Convert the raw integer values to APInt.
SmallVector<APInt, 8> apIntValues;
apIntValues.reserve(values.size());
for (auto value : values)
apIntValues.emplace_back(APInt(bitWidth, value));
return get(type, apIntValues);
}
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
values.reserve(size());
values.assign(raw_begin(), raw_end());
@ -786,7 +808,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
auto newArrayType =
mappingHelper(mapping, *this, getType(), newElementType, elementData);
return get(newArrayType, elementData);
return getRaw(newArrayType, elementData);
}
/// Method for supporting type inquiry through isa, cast and dyn_cast.
@ -829,7 +851,7 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
auto newArrayType =
mappingHelper(mapping, *this, getType(), newElementType, elementData);
return get(newArrayType, elementData);
return getRaw(newArrayType, elementData);
}
/// Iterator access to the float element values.

View File

@ -188,11 +188,6 @@ ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
return SplatElementsAttr::get(type, elt);
}
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
ArrayRef<char> data) {
return DenseElementsAttr::get(type, data);
}
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values) {
return DenseElementsAttr::get(type, values);