forked from OSchip/llvm-project
Added the ability to run a mapping function across the values of an elements
attr. This supports both the SplatElementsAttr and DenseElementsAttr for both float and integer inputs / outputs. -- PiperOrigin-RevId: 249681056
This commit is contained in:
parent
4958ec2414
commit
e2b715fe41
|
@ -405,6 +405,20 @@ public:
|
|||
/// element, then a null attribute is returned.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||
|
||||
/// Generates a new ElementsAttr by mapping each int value to a new
|
||||
/// underlying APInt. The new values can represent either a integer or float.
|
||||
/// This ElementsAttr should contain integers.
|
||||
ElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const;
|
||||
|
||||
/// Generates a new ElementsAttr by mapping each float value to a new
|
||||
/// underlying APInt. The new values can represent either a integer or float.
|
||||
/// This ElementsAttr should contain floats.
|
||||
ElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(Attribute attr) {
|
||||
return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
|
||||
|
@ -424,6 +438,20 @@ public:
|
|||
static SplatElementsAttr get(ShapedType type, Attribute elt);
|
||||
Attribute getValue() const;
|
||||
|
||||
/// Generates a new SplatElementsAttr by mapping each int value to a new
|
||||
/// underlying APInt. The new values can represent either a integer or float.
|
||||
/// This ElementsAttr should contain integers.
|
||||
SplatElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const;
|
||||
|
||||
/// Generates a new SplatElementsAttr by mapping each float value to a new
|
||||
/// underlying APInt. The new values can represent either a integer or float.
|
||||
/// This ElementsAttr should contain floats.
|
||||
SplatElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardAttributes::SplatElements;
|
||||
|
@ -454,6 +482,20 @@ public:
|
|||
|
||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||
|
||||
/// Generates a new DenseElementsAttr by mapping each int value to a new
|
||||
/// underlying APInt. The new values can represent either a integer or float.
|
||||
/// This underlying type must be an DenseIntElementsAttr.
|
||||
DenseElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const;
|
||||
|
||||
/// Generates a new DenseElementsAttr by mapping each float value to a new
|
||||
/// underlying APInt. the new values can represent either a integer or float.
|
||||
/// This underlying type must be an DenseFPElementsAttr.
|
||||
DenseElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const;
|
||||
|
||||
ArrayRef<char> getRawData() const;
|
||||
|
||||
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
|
||||
|
@ -550,6 +592,12 @@ public:
|
|||
/// type of 'type'.
|
||||
static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
|
||||
|
||||
/// Generates a new DenseElementsAttr by mapping each value attribute, and
|
||||
/// constructing the DenseElementsAttr given the new element type.
|
||||
DenseElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const;
|
||||
|
||||
/// Gets the integer value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APInt> &values) const;
|
||||
|
||||
|
@ -593,6 +641,12 @@ public:
|
|||
/// Gets the float value of each of the dense elements.
|
||||
void getValues(SmallVectorImpl<APFloat> &values) const;
|
||||
|
||||
/// Generates a new DenseElementsAttr by mapping each value attribute, and
|
||||
/// constructing the DenseElementsAttr given the new element type.
|
||||
DenseElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const;
|
||||
|
||||
/// Iterator access to the float element values.
|
||||
iterator begin() const;
|
||||
iterator end() const;
|
||||
|
|
|
@ -287,6 +287,34 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
}
|
||||
}
|
||||
|
||||
ElementsAttr ElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
||||
switch (getKind()) {
|
||||
case StandardAttributes::DenseIntElements:
|
||||
case StandardAttributes::DenseFPElements:
|
||||
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
|
||||
case StandardAttributes::SplatElements:
|
||||
return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
|
||||
default:
|
||||
llvm_unreachable("unsupported ElementsAttr subtype");
|
||||
}
|
||||
}
|
||||
|
||||
ElementsAttr ElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const {
|
||||
switch (getKind()) {
|
||||
case StandardAttributes::DenseIntElements:
|
||||
case StandardAttributes::DenseFPElements:
|
||||
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
|
||||
case StandardAttributes::SplatElements:
|
||||
return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
|
||||
default:
|
||||
llvm_unreachable("unsupported ElementsAttr subtype");
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplatElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -300,6 +328,73 @@ SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) {
|
|||
|
||||
Attribute SplatElementsAttr::getValue() const { return getImpl()->elt; }
|
||||
|
||||
SplatElementsAttr SplatElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
||||
ShapedType inType = getType();
|
||||
|
||||
ShapedType newArrayType;
|
||||
if (inType.isa<RankedTensorType>())
|
||||
newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
|
||||
else if (inType.isa<UnrankedTensorType>())
|
||||
newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
|
||||
else if (inType.isa<VectorType>())
|
||||
newArrayType = VectorType::get(inType.getShape(), newElementType);
|
||||
else
|
||||
assert(false && "Unhandled tensor type");
|
||||
|
||||
assert(getType().getElementType().isa<IntegerType>() &&
|
||||
"Attempting to map non-integer array as integers");
|
||||
|
||||
if (newElementType.isa<IntegerType>()) {
|
||||
APInt newValue = mapping(getValue().cast<IntegerAttr>().getValue());
|
||||
auto newAttr = IntegerAttr::get(newElementType, newValue);
|
||||
return get(newArrayType, newAttr);
|
||||
}
|
||||
|
||||
if (newElementType.isa<FloatType>()) {
|
||||
APFloat newValue(newElementType.cast<FloatType>().getFloatSemantics(),
|
||||
mapping(getValue().cast<IntegerAttr>().getValue()));
|
||||
auto newAttr = FloatAttr::get(newElementType, newValue);
|
||||
return get(newArrayType, newAttr);
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown output splat type");
|
||||
}
|
||||
|
||||
SplatElementsAttr SplatElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const {
|
||||
Type inType = getType();
|
||||
|
||||
ShapedType newArrayType;
|
||||
if (inType.isa<RankedTensorType>()) {
|
||||
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
|
||||
} else if (inType.isa<UnrankedTensorType>()) {
|
||||
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
|
||||
}
|
||||
|
||||
assert(newArrayType && "Unhandled tensor type");
|
||||
assert(getType().getElementType().isa<FloatType>() &&
|
||||
"mapping function expects float tensor");
|
||||
|
||||
Attribute newAttr;
|
||||
if (newElementType.isa<IntegerType>()) {
|
||||
APInt newValue = mapping(getValue().cast<FloatAttr>().getValue());
|
||||
newAttr = IntegerAttr::get(newElementType, newValue);
|
||||
return get(newArrayType, newAttr);
|
||||
}
|
||||
|
||||
if (newElementType.isa<FloatType>()) {
|
||||
APFloat newValue(newElementType.cast<FloatType>().getFloatSemantics(),
|
||||
mapping(getValue().cast<FloatAttr>().getValue()));
|
||||
newAttr = FloatAttr::get(newElementType, newValue);
|
||||
return get(newArrayType, newAttr);
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown output splat type");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RawElementIterator
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -459,6 +554,18 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
|||
}
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
||||
return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const {
|
||||
return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
|
||||
}
|
||||
|
||||
ArrayRef<char> DenseElementsAttr::getRawData() const {
|
||||
return static_cast<ImplType *>(impl)->data;
|
||||
}
|
||||
|
@ -562,6 +669,46 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
|
|||
values.assign(raw_begin(), raw_end());
|
||||
}
|
||||
|
||||
template<typename Fn, typename Attr>
|
||||
static ShapedType mappingHelper(
|
||||
Fn mapping, Attr& attr, ShapedType inType, Type newElementType,
|
||||
llvm::SmallVectorImpl<char>& data) {
|
||||
size_t bitWidth = getDenseElementBitwidth(newElementType);
|
||||
|
||||
ShapedType newArrayType;
|
||||
if (inType.isa<RankedTensorType>())
|
||||
newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
|
||||
else if (inType.isa<UnrankedTensorType>())
|
||||
newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
|
||||
else if (inType.isa<VectorType>())
|
||||
newArrayType = VectorType::get(inType.getShape(), newElementType);
|
||||
else
|
||||
assert(newArrayType && "Unhandled tensor type");
|
||||
|
||||
data.resize(APInt::getNumWords(bitWidth * inType.getNumElements()) *
|
||||
APInt::APINT_WORD_SIZE);
|
||||
|
||||
uint64_t elementIdx = 0;
|
||||
for (auto value : attr) {
|
||||
auto newInt = mapping(value);
|
||||
assert(newInt.getBitWidth() == bitWidth);
|
||||
attr.writeBits(data.data(), elementIdx * bitWidth, newInt);
|
||||
++elementIdx;
|
||||
}
|
||||
|
||||
return newArrayType;
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseIntElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
||||
llvm::SmallVector<char, 8> elementData;
|
||||
auto newArrayType = mappingHelper(
|
||||
mapping, *this, getType(), newElementType, elementData);
|
||||
|
||||
return get(newArrayType, elementData);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseFPElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -589,6 +736,16 @@ void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
|
|||
values.assign(begin(), end());
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseFPElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const {
|
||||
llvm::SmallVector<char, 8> elementData;
|
||||
auto newArrayType = mappingHelper(
|
||||
mapping, *this, getType(), newElementType, elementData);
|
||||
|
||||
return get(newArrayType, elementData);
|
||||
}
|
||||
|
||||
/// Iterator access to the float element values.
|
||||
DenseFPElementsAttr::iterator DenseFPElementsAttr::begin() const {
|
||||
auto elementType = getType().getElementType().cast<FloatType>();
|
||||
|
|
Loading…
Reference in New Issue