forked from OSchip/llvm-project
Remove the ability to directly construct a DenseElementsAttr with a raw character buffer. This made assumptions about how DenseElementsAttr structured its internal storage, which may change in the future. To replace the existing use cases, a few utility methods have been added:
* 'get' methods that allow constructing from an ArrayRef of integer or floating point values. * A 'reshape' method to allow for changing the shape without changing the underlying data. PiperOrigin-RevId: 252067898
This commit is contained in:
parent
62facfaf42
commit
0cadec8ae6
|
@ -81,8 +81,7 @@ struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
|
|||
// auto oldType = constantOp.getType();
|
||||
auto newType = rewriter.getTensorType(
|
||||
reshapeType.getShape(), valueAttr.getType().getElementType());
|
||||
auto newAttr =
|
||||
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
|
||||
auto newAttr = valueAttr.reshape(newType);
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||
newAttr);
|
||||
} else if (auto valueAttr =
|
||||
|
|
|
@ -83,8 +83,7 @@ struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
|
|||
// auto oldType = constantOp.getType();
|
||||
auto newType = rewriter.getTensorType(
|
||||
reshapeType.getShape(), valueAttr.getType().getElementType());
|
||||
auto newAttr =
|
||||
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
|
||||
auto newAttr = valueAttr.reshape(newType);
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||
newAttr);
|
||||
} else if (auto valueAttr =
|
||||
|
|
|
@ -502,16 +502,32 @@ class DenseElementsAttr
|
|||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// It assumes the elements in the input array have been truncated to the bits
|
||||
/// width specified by the element type. 'type' must be a vector or tensor
|
||||
/// with static shape.
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<char> data);
|
||||
|
||||
/// Constructs a dense elements attribute from an array of element values.
|
||||
/// Each element attribute value is expected to be an element of 'type'.
|
||||
/// 'type' must be a vector or tensor with static shape.
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
|
||||
|
||||
/// Constructs a dense integer elements attribute from an array of integer
|
||||
/// or floating-point values. Each value is expected to be the same bitwidth
|
||||
/// of the element type of 'type'. 'type' must be a vector or tensor with
|
||||
/// static shape.
|
||||
template <typename ShapeT, typename T>
|
||||
static DenseElementsAttr get(ShapeT type, ArrayRef<T> values) {
|
||||
static_assert(std::numeric_limits<T>::is_integer ||
|
||||
llvm::is_one_of<T, float, double>::value,
|
||||
"expected integer or floating point element type");
|
||||
|
||||
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
|
||||
assert(type.getElementTypeBitWidth() == (sizeof(T) * CHAR_BIT));
|
||||
const char *data = reinterpret_cast<const char *>(values.data());
|
||||
return getRawIntOrFloat(type,
|
||||
ArrayRef<char>(data, values.size() * sizeof(T)),
|
||||
/*isInt=*/std::numeric_limits<T>::is_integer);
|
||||
}
|
||||
|
||||
/// Overload of the above 'get' method that is specialized for boolean values.
|
||||
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
|
||||
|
||||
/// Returns the number of elements held by this attribute.
|
||||
size_t size() const;
|
||||
|
||||
|
@ -519,8 +535,14 @@ public:
|
|||
/// element, then a null attribute is returned.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||
|
||||
/// Return the held element values as Attributes in 'values'.
|
||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||
|
||||
/// Return a new DenseElementsAttr that has the same data as the current
|
||||
/// attribute, but has been reshaped to 'newType'. The new type must have the
|
||||
/// same total number of elements as well as element type.
|
||||
DenseElementsAttr reshape(ShapedType newType);
|
||||
|
||||
/// Generates a new DenseElementsAttr by mapping each int value to a new
|
||||
/// underlying APInt. The new values can represent either a integer or float.
|
||||
/// This underlying type must be an DenseIntElementsAttr.
|
||||
|
@ -600,6 +622,15 @@ protected:
|
|||
return RawElementIterator(*this, size());
|
||||
}
|
||||
|
||||
/// Get or create a new dense elements attribute instance with the given raw
|
||||
/// data buffer. 'type' must be a vector or tensor with static shape.
|
||||
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data);
|
||||
|
||||
/// Overload of the raw 'get' method that asserts that the given type is of
|
||||
/// integer or floating-point type.
|
||||
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
|
||||
ArrayRef<char> data, bool isInt);
|
||||
|
||||
/// Constructs a dense elements attribute from an array of raw APInt values.
|
||||
/// Each APInt value is expected to have the same bitwidth as the element type
|
||||
/// of 'type'. 'type' must be a vector or tensor with static shape.
|
||||
|
@ -624,11 +655,6 @@ public:
|
|||
/// shape.
|
||||
static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
|
||||
|
||||
/// Constructs a dense integer elements attribute from an array of integer
|
||||
/// values. Each value is expected to be within the bitwidth of the element
|
||||
/// type of 'type'. 'type' must be a vector or tensor with static shape.
|
||||
static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
|
||||
|
||||
/// Generates a new DenseElementsAttr by mapping each value attribute, and
|
||||
/// constructing the DenseElementsAttr given the new element type.
|
||||
DenseElementsAttr
|
||||
|
|
|
@ -116,7 +116,6 @@ public:
|
|||
FunctionAttr getFunctionAttr(Function *value);
|
||||
FunctionAttr getFunctionAttr(StringRef value);
|
||||
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
|
||||
ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<char> data);
|
||||
ElementsAttr getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<Attribute> values);
|
||||
ElementsAttr getDenseIntElementsAttr(ShapedType type,
|
||||
|
|
|
@ -532,7 +532,8 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const {
|
|||
// DenseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
|
||||
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
|
||||
ArrayRef<char> data) {
|
||||
assert((static_cast<uint64_t>(type.getSizeInBits()) <=
|
||||
data.size() * APInt::APINT_WORD_SIZE) &&
|
||||
"Input data bit size should be larger than that type requires");
|
||||
|
@ -543,6 +544,27 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
|
|||
data);
|
||||
}
|
||||
|
||||
/// Overload of the raw 'get' method that asserts that the given type is of
|
||||
/// integer type.
|
||||
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
|
||||
ArrayRef<char> data,
|
||||
bool isInt) {
|
||||
assert(isInt ? type.getElementType().isa<IntegerType>()
|
||||
: type.getElementType().isa<FloatType>());
|
||||
return getRaw(type, data);
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<bool> values) {
|
||||
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
|
||||
assert(type.getElementType().isInteger(1));
|
||||
|
||||
std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
|
||||
for (int i = 0, e = values.size(); i != e; ++i)
|
||||
writeBits(buff.data(), i, llvm::APInt(1, values[i]));
|
||||
return getRaw(type, buff);
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<Attribute> values) {
|
||||
assert(type.getElementType().isIntOrFloat() &&
|
||||
|
@ -579,7 +601,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
|||
"expected value to have same bitwidth as element type");
|
||||
writeBits(data.data(), i * storageBitWidth, intVal);
|
||||
}
|
||||
return get(type, data);
|
||||
return getRaw(type, data);
|
||||
}
|
||||
|
||||
/// Returns the number of elements held by this attribute.
|
||||
|
@ -650,6 +672,22 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
|||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
|
||||
/// Return a new DenseElementsAttr that has the same data as the current
|
||||
/// attribute, but has been reshaped to 'newType'. The new type must have the
|
||||
/// same total number of elements as well as element type.
|
||||
DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
|
||||
ShapedType curType = getType();
|
||||
if (curType == newType)
|
||||
return *this;
|
||||
|
||||
(void)curType;
|
||||
assert(newType.getElementType() == curType.getElementType() &&
|
||||
"expected the same element type");
|
||||
assert(newType.getNumElements() == curType.getNumElements() &&
|
||||
"expected the same number of elements");
|
||||
return getRaw(newType, getRawData());
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
||||
|
@ -681,7 +719,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
|||
assert(values[i].getBitWidth() == bitWidth);
|
||||
writeBits(elementData.data(), i * storageBitWidth, values[i]);
|
||||
}
|
||||
return get(type, elementData);
|
||||
return getRaw(type, elementData);
|
||||
}
|
||||
|
||||
/// Writes value to the bit position `bitPos` in array `rawData`.
|
||||
|
@ -728,22 +766,6 @@ DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
|
|||
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
|
||||
}
|
||||
|
||||
/// Constructs a dense integer elements attribute from an array of integer
|
||||
/// values. Each value is expected to be within the bitwidth of the element
|
||||
/// type of 'type'.
|
||||
DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
|
||||
ArrayRef<int64_t> values) {
|
||||
auto eltType = type.getElementType();
|
||||
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||
|
||||
// Convert the raw integer values to APInt.
|
||||
SmallVector<APInt, 8> apIntValues;
|
||||
apIntValues.reserve(values.size());
|
||||
for (auto value : values)
|
||||
apIntValues.emplace_back(APInt(bitWidth, value));
|
||||
return get(type, apIntValues);
|
||||
}
|
||||
|
||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
|
||||
values.reserve(size());
|
||||
values.assign(raw_begin(), raw_end());
|
||||
|
@ -786,7 +808,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
|
|||
auto newArrayType =
|
||||
mappingHelper(mapping, *this, getType(), newElementType, elementData);
|
||||
|
||||
return get(newArrayType, elementData);
|
||||
return getRaw(newArrayType, elementData);
|
||||
}
|
||||
|
||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||
|
@ -829,7 +851,7 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
|
|||
auto newArrayType =
|
||||
mappingHelper(mapping, *this, getType(), newElementType, elementData);
|
||||
|
||||
return get(newArrayType, elementData);
|
||||
return getRaw(newArrayType, elementData);
|
||||
}
|
||||
|
||||
/// Iterator access to the float element values.
|
||||
|
|
|
@ -188,11 +188,6 @@ ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
|
|||
return SplatElementsAttr::get(type, elt);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<char> data) {
|
||||
return DenseElementsAttr::get(type, data);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<Attribute> values) {
|
||||
return DenseElementsAttr::get(type, values);
|
||||
|
|
Loading…
Reference in New Issue