forked from OSchip/llvm-project
Remove the ability to directly construct a DenseElementsAttr with a raw character buffer. This made assumptions about how DenseElementsAttr structured its internal storage, which may change in the future. To replace the existing use cases, a few utility methods have been added:
* 'get' methods that allow constructing from an ArrayRef of integer or floating point values. * A 'reshape' method to allow for changing the shape without changing the underlying data. PiperOrigin-RevId: 252067898
This commit is contained in:
parent
62facfaf42
commit
0cadec8ae6
|
@ -81,8 +81,7 @@ struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
|
||||||
// auto oldType = constantOp.getType();
|
// auto oldType = constantOp.getType();
|
||||||
auto newType = rewriter.getTensorType(
|
auto newType = rewriter.getTensorType(
|
||||||
reshapeType.getShape(), valueAttr.getType().getElementType());
|
reshapeType.getShape(), valueAttr.getType().getElementType());
|
||||||
auto newAttr =
|
auto newAttr = valueAttr.reshape(newType);
|
||||||
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
|
|
||||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||||
newAttr);
|
newAttr);
|
||||||
} else if (auto valueAttr =
|
} else if (auto valueAttr =
|
||||||
|
|
|
@ -83,8 +83,7 @@ struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
|
||||||
// auto oldType = constantOp.getType();
|
// auto oldType = constantOp.getType();
|
||||||
auto newType = rewriter.getTensorType(
|
auto newType = rewriter.getTensorType(
|
||||||
reshapeType.getShape(), valueAttr.getType().getElementType());
|
reshapeType.getShape(), valueAttr.getType().getElementType());
|
||||||
auto newAttr =
|
auto newAttr = valueAttr.reshape(newType);
|
||||||
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
|
|
||||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
|
||||||
newAttr);
|
newAttr);
|
||||||
} else if (auto valueAttr =
|
} else if (auto valueAttr =
|
||||||
|
|
|
@ -502,16 +502,32 @@ class DenseElementsAttr
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
/// It assumes the elements in the input array have been truncated to the bits
|
|
||||||
/// width specified by the element type. 'type' must be a vector or tensor
|
|
||||||
/// with static shape.
|
|
||||||
static DenseElementsAttr get(ShapedType type, ArrayRef<char> data);
|
|
||||||
|
|
||||||
/// Constructs a dense elements attribute from an array of element values.
|
/// Constructs a dense elements attribute from an array of element values.
|
||||||
/// Each element attribute value is expected to be an element of 'type'.
|
/// Each element attribute value is expected to be an element of 'type'.
|
||||||
/// 'type' must be a vector or tensor with static shape.
|
/// 'type' must be a vector or tensor with static shape.
|
||||||
static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
|
static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values);
|
||||||
|
|
||||||
|
/// Constructs a dense integer elements attribute from an array of integer
|
||||||
|
/// or floating-point values. Each value is expected to be the same bitwidth
|
||||||
|
/// of the element type of 'type'. 'type' must be a vector or tensor with
|
||||||
|
/// static shape.
|
||||||
|
template <typename ShapeT, typename T>
|
||||||
|
static DenseElementsAttr get(ShapeT type, ArrayRef<T> values) {
|
||||||
|
static_assert(std::numeric_limits<T>::is_integer ||
|
||||||
|
llvm::is_one_of<T, float, double>::value,
|
||||||
|
"expected integer or floating point element type");
|
||||||
|
|
||||||
|
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
|
||||||
|
assert(type.getElementTypeBitWidth() == (sizeof(T) * CHAR_BIT));
|
||||||
|
const char *data = reinterpret_cast<const char *>(values.data());
|
||||||
|
return getRawIntOrFloat(type,
|
||||||
|
ArrayRef<char>(data, values.size() * sizeof(T)),
|
||||||
|
/*isInt=*/std::numeric_limits<T>::is_integer);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Overload of the above 'get' method that is specialized for boolean values.
|
||||||
|
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
|
||||||
|
|
||||||
/// Returns the number of elements held by this attribute.
|
/// Returns the number of elements held by this attribute.
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
|
|
||||||
|
@ -519,8 +535,14 @@ public:
|
||||||
/// element, then a null attribute is returned.
|
/// element, then a null attribute is returned.
|
||||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||||
|
|
||||||
|
/// Return the held element values as Attributes in 'values'.
|
||||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||||
|
|
||||||
|
/// Return a new DenseElementsAttr that has the same data as the current
|
||||||
|
/// attribute, but has been reshaped to 'newType'. The new type must have the
|
||||||
|
/// same total number of elements as well as element type.
|
||||||
|
DenseElementsAttr reshape(ShapedType newType);
|
||||||
|
|
||||||
/// Generates a new DenseElementsAttr by mapping each int value to a new
|
/// Generates a new DenseElementsAttr by mapping each int value to a new
|
||||||
/// underlying APInt. The new values can represent either a integer or float.
|
/// underlying APInt. The new values can represent either a integer or float.
|
||||||
/// This underlying type must be an DenseIntElementsAttr.
|
/// This underlying type must be an DenseIntElementsAttr.
|
||||||
|
@ -600,6 +622,15 @@ protected:
|
||||||
return RawElementIterator(*this, size());
|
return RawElementIterator(*this, size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get or create a new dense elements attribute instance with the given raw
|
||||||
|
/// data buffer. 'type' must be a vector or tensor with static shape.
|
||||||
|
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data);
|
||||||
|
|
||||||
|
/// Overload of the raw 'get' method that asserts that the given type is of
|
||||||
|
/// integer or floating-point type.
|
||||||
|
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
|
||||||
|
ArrayRef<char> data, bool isInt);
|
||||||
|
|
||||||
/// Constructs a dense elements attribute from an array of raw APInt values.
|
/// Constructs a dense elements attribute from an array of raw APInt values.
|
||||||
/// Each APInt value is expected to have the same bitwidth as the element type
|
/// Each APInt value is expected to have the same bitwidth as the element type
|
||||||
/// of 'type'. 'type' must be a vector or tensor with static shape.
|
/// of 'type'. 'type' must be a vector or tensor with static shape.
|
||||||
|
@ -624,11 +655,6 @@ public:
|
||||||
/// shape.
|
/// shape.
|
||||||
static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
|
static DenseIntElementsAttr get(ShapedType type, ArrayRef<APInt> values);
|
||||||
|
|
||||||
/// Constructs a dense integer elements attribute from an array of integer
|
|
||||||
/// values. Each value is expected to be within the bitwidth of the element
|
|
||||||
/// type of 'type'. 'type' must be a vector or tensor with static shape.
|
|
||||||
static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
|
|
||||||
|
|
||||||
/// Generates a new DenseElementsAttr by mapping each value attribute, and
|
/// Generates a new DenseElementsAttr by mapping each value attribute, and
|
||||||
/// constructing the DenseElementsAttr given the new element type.
|
/// constructing the DenseElementsAttr given the new element type.
|
||||||
DenseElementsAttr
|
DenseElementsAttr
|
||||||
|
|
|
@ -116,7 +116,6 @@ public:
|
||||||
FunctionAttr getFunctionAttr(Function *value);
|
FunctionAttr getFunctionAttr(Function *value);
|
||||||
FunctionAttr getFunctionAttr(StringRef value);
|
FunctionAttr getFunctionAttr(StringRef value);
|
||||||
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
|
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
|
||||||
ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<char> data);
|
|
||||||
ElementsAttr getDenseElementsAttr(ShapedType type,
|
ElementsAttr getDenseElementsAttr(ShapedType type,
|
||||||
ArrayRef<Attribute> values);
|
ArrayRef<Attribute> values);
|
||||||
ElementsAttr getDenseIntElementsAttr(ShapedType type,
|
ElementsAttr getDenseIntElementsAttr(ShapedType type,
|
||||||
|
|
|
@ -532,7 +532,8 @@ APInt DenseElementsAttr::RawElementIterator::operator*() const {
|
||||||
// DenseElementsAttr
|
// DenseElementsAttr
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
|
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
|
||||||
|
ArrayRef<char> data) {
|
||||||
assert((static_cast<uint64_t>(type.getSizeInBits()) <=
|
assert((static_cast<uint64_t>(type.getSizeInBits()) <=
|
||||||
data.size() * APInt::APINT_WORD_SIZE) &&
|
data.size() * APInt::APINT_WORD_SIZE) &&
|
||||||
"Input data bit size should be larger than that type requires");
|
"Input data bit size should be larger than that type requires");
|
||||||
|
@ -543,6 +544,27 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
|
||||||
data);
|
data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Overload of the raw 'get' method that asserts that the given type is of
|
||||||
|
/// integer type.
|
||||||
|
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
|
||||||
|
ArrayRef<char> data,
|
||||||
|
bool isInt) {
|
||||||
|
assert(isInt ? type.getElementType().isa<IntegerType>()
|
||||||
|
: type.getElementType().isa<FloatType>());
|
||||||
|
return getRaw(type, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||||
|
ArrayRef<bool> values) {
|
||||||
|
assert(type.getNumElements() == static_cast<int64_t>(values.size()));
|
||||||
|
assert(type.getElementType().isInteger(1));
|
||||||
|
|
||||||
|
std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
|
||||||
|
for (int i = 0, e = values.size(); i != e; ++i)
|
||||||
|
writeBits(buff.data(), i, llvm::APInt(1, values[i]));
|
||||||
|
return getRaw(type, buff);
|
||||||
|
}
|
||||||
|
|
||||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||||
ArrayRef<Attribute> values) {
|
ArrayRef<Attribute> values) {
|
||||||
assert(type.getElementType().isIntOrFloat() &&
|
assert(type.getElementType().isIntOrFloat() &&
|
||||||
|
@ -579,7 +601,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||||
"expected value to have same bitwidth as element type");
|
"expected value to have same bitwidth as element type");
|
||||||
writeBits(data.data(), i * storageBitWidth, intVal);
|
writeBits(data.data(), i * storageBitWidth, intVal);
|
||||||
}
|
}
|
||||||
return get(type, data);
|
return getRaw(type, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of elements held by this attribute.
|
/// Returns the number of elements held by this attribute.
|
||||||
|
@ -650,6 +672,22 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||||
llvm_unreachable("unexpected element type");
|
llvm_unreachable("unexpected element type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return a new DenseElementsAttr that has the same data as the current
|
||||||
|
/// attribute, but has been reshaped to 'newType'. The new type must have the
|
||||||
|
/// same total number of elements as well as element type.
|
||||||
|
DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
|
||||||
|
ShapedType curType = getType();
|
||||||
|
if (curType == newType)
|
||||||
|
return *this;
|
||||||
|
|
||||||
|
(void)curType;
|
||||||
|
assert(newType.getElementType() == curType.getElementType() &&
|
||||||
|
"expected the same element type");
|
||||||
|
assert(newType.getNumElements() == curType.getNumElements() &&
|
||||||
|
"expected the same number of elements");
|
||||||
|
return getRaw(newType, getRawData());
|
||||||
|
}
|
||||||
|
|
||||||
DenseElementsAttr DenseElementsAttr::mapValues(
|
DenseElementsAttr DenseElementsAttr::mapValues(
|
||||||
Type newElementType,
|
Type newElementType,
|
||||||
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
||||||
|
@ -681,7 +719,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||||
assert(values[i].getBitWidth() == bitWidth);
|
assert(values[i].getBitWidth() == bitWidth);
|
||||||
writeBits(elementData.data(), i * storageBitWidth, values[i]);
|
writeBits(elementData.data(), i * storageBitWidth, values[i]);
|
||||||
}
|
}
|
||||||
return get(type, elementData);
|
return getRaw(type, elementData);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Writes value to the bit position `bitPos` in array `rawData`.
|
/// Writes value to the bit position `bitPos` in array `rawData`.
|
||||||
|
@ -728,22 +766,6 @@ DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
|
||||||
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
|
return DenseElementsAttr::get(type, values).cast<DenseIntElementsAttr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Constructs a dense integer elements attribute from an array of integer
|
|
||||||
/// values. Each value is expected to be within the bitwidth of the element
|
|
||||||
/// type of 'type'.
|
|
||||||
DenseIntElementsAttr DenseIntElementsAttr::get(ShapedType type,
|
|
||||||
ArrayRef<int64_t> values) {
|
|
||||||
auto eltType = type.getElementType();
|
|
||||||
size_t bitWidth = eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
|
||||||
|
|
||||||
// Convert the raw integer values to APInt.
|
|
||||||
SmallVector<APInt, 8> apIntValues;
|
|
||||||
apIntValues.reserve(values.size());
|
|
||||||
for (auto value : values)
|
|
||||||
apIntValues.emplace_back(APInt(bitWidth, value));
|
|
||||||
return get(type, apIntValues);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
|
void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
|
||||||
values.reserve(size());
|
values.reserve(size());
|
||||||
values.assign(raw_begin(), raw_end());
|
values.assign(raw_begin(), raw_end());
|
||||||
|
@ -786,7 +808,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
|
||||||
auto newArrayType =
|
auto newArrayType =
|
||||||
mappingHelper(mapping, *this, getType(), newElementType, elementData);
|
mappingHelper(mapping, *this, getType(), newElementType, elementData);
|
||||||
|
|
||||||
return get(newArrayType, elementData);
|
return getRaw(newArrayType, elementData);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||||
|
@ -829,7 +851,7 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
|
||||||
auto newArrayType =
|
auto newArrayType =
|
||||||
mappingHelper(mapping, *this, getType(), newElementType, elementData);
|
mappingHelper(mapping, *this, getType(), newElementType, elementData);
|
||||||
|
|
||||||
return get(newArrayType, elementData);
|
return getRaw(newArrayType, elementData);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Iterator access to the float element values.
|
/// Iterator access to the float element values.
|
||||||
|
|
|
@ -188,11 +188,6 @@ ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
|
||||||
return SplatElementsAttr::get(type, elt);
|
return SplatElementsAttr::get(type, elt);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
|
|
||||||
ArrayRef<char> data) {
|
|
||||||
return DenseElementsAttr::get(type, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
|
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
|
||||||
ArrayRef<Attribute> values) {
|
ArrayRef<Attribute> values) {
|
||||||
return DenseElementsAttr::get(type, values);
|
return DenseElementsAttr::get(type, values);
|
||||||
|
|
Loading…
Reference in New Issue