Added the ability to run a mapping function across the values of an elements

attr. This supports both the SplatElementsAttr and DenseElementsAttr for both
    float and integer inputs / outputs.

--

PiperOrigin-RevId: 249538085
This commit is contained in:
Rob Suderman 2019-05-22 15:55:17 -07:00 committed by Mehdi Amini
parent 5953d12b95
commit ca9bd9d3af
2 changed files with 221 additions and 0 deletions

View File

@ -405,6 +405,20 @@ public:
/// element, then a null attribute is returned.
Attribute getValue(ArrayRef<uint64_t> index) const;
/// Generates a new ElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This ElementsAttr should contain integers.
ElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const;
/// Generates a new ElementsAttr by mapping each float value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This ElementsAttr should contain floats.
ElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
@ -424,6 +438,20 @@ public:
static SplatElementsAttr get(ShapedType type, Attribute elt);
Attribute getValue() const;
/// Generates a new SplatElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This ElementsAttr should contain integers.
SplatElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const;
/// Generates a new SplatElementsAttr by mapping each float value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This ElementsAttr should contain floats.
SplatElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::SplatElements;
@ -454,6 +482,20 @@ public:
void getValues(SmallVectorImpl<Attribute> &values) const;
/// Generates a new DenseElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This underlying type must be an DenseIntElementsAttr.
DenseElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const;
/// Generates a new DenseElementsAttr by mapping each float value to a new
/// underlying APInt. the new values can represent either a integer or float.
/// This underlying type must be an DenseFPElementsAttr.
DenseElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const;
ArrayRef<char> getRawData() const;
/// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
@ -550,6 +592,12 @@ public:
/// type of 'type'.
static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
/// Generates a new DenseElementsAttr by mapping each value attribute, and
/// constructing the DenseElementsAttr given the new element type.
DenseElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const;
/// Gets the integer value of each of the dense elements.
void getValues(SmallVectorImpl<APInt> &values) const;
@ -593,6 +641,12 @@ public:
/// Gets the float value of each of the dense elements.
void getValues(SmallVectorImpl<APFloat> &values) const;
/// Generates a new DenseElementsAttr by mapping each value attribute, and
/// constructing the DenseElementsAttr given the new element type.
DenseElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const;
/// Iterator access to the float element values.
iterator begin() const;
iterator end() const;

View File

@ -287,6 +287,36 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
}
}
ElementsAttr ElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const {
switch (getKind()) {
case StandardAttributes::DenseIntElements:
case StandardAttributes::DenseFPElements:
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
case StandardAttributes::SplatElements:
return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
default:
llvm_unreachable("unsupported ElementsAttr subtype");
}
}
ElementsAttr ElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const {
switch (getKind()) {
case StandardAttributes::DenseIntElements:
case StandardAttributes::DenseFPElements:
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
case StandardAttributes::SplatElements:
return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
default:
break;
}
llvm_unreachable("unsupported ElementsAttr subtype");
}
//===----------------------------------------------------------------------===//
// SplatElementsAttr
//===----------------------------------------------------------------------===//
@ -300,6 +330,74 @@ SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) {
Attribute SplatElementsAttr::getValue() const { return getImpl()->elt; }
SplatElementsAttr SplatElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const {
Type inType = getType();
auto inElementType = getType().getElementType();
ShapedType newArrayType;
if (inType.isa<RankedTensorType>())
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
else if (inType.isa<UnrankedTensorType>())
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
else
assert(false && "Unhandled tensor type");
assert(inElementType.isa<IntegerType>() &&
"Attempting to map non-integer array as integers");
if (newElementType.isa<IntegerType>()) {
APInt newValue = mapping(getValue().cast<IntegerAttr>().getValue());
auto newAttr = IntegerAttr::get(newElementType, newValue);
return get(newArrayType, newAttr);
}
if (newElementType.isa<FloatType>()) {
APFloat newValue(newElementType.cast<FloatType>().getFloatSemantics(),
mapping(getValue().cast<IntegerAttr>().getValue()));
auto newAttr = FloatAttr::get(newElementType, newValue);
return get(newArrayType, newAttr);
}
llvm_unreachable("unknown output splat type");
}
SplatElementsAttr SplatElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const {
Type inType = getType();
auto inElementType = getType().getElementType();
ShapedType newArrayType;
if (inType.isa<RankedTensorType>()) {
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
} else if (inType.isa<UnrankedTensorType>()) {
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
}
assert(newArrayType && "Unhandled tensor type");
assert(inElementType.isa<FloatType>() &&
"mapping function expects float tensor");
Attribute newAttr;
if (newElementType.isa<IntegerType>()) {
APInt newValue = mapping(getValue().cast<FloatAttr>().getValue());
newAttr = IntegerAttr::get(newElementType, newValue);
return get(newArrayType, newAttr);
}
if (newElementType.isa<FloatType>()) {
APFloat newValue =
APFloat(newElementType.cast<FloatType>().getFloatSemantics(),
mapping(getValue().cast<FloatAttr>().getValue()));
newAttr = FloatAttr::get(newElementType, newValue);
return get(newArrayType, newAttr);
}
llvm_unreachable("unknown output splat type");
}
//===----------------------------------------------------------------------===//
// RawElementIterator
//===----------------------------------------------------------------------===//
@ -459,6 +557,18 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
}
}
DenseElementsAttr DenseElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const {
return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
}
DenseElementsAttr DenseElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const {
return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
}
ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<ImplType *>(impl)->data;
}
@ -562,6 +672,35 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
values.assign(raw_begin(), raw_end());
}
DenseElementsAttr DenseIntElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const {
Type inType = getType();
size_t bitWidth = getDenseElementBitwidth(newElementType);
ShapedType newArrayType;
if (inType.isa<RankedTensorType>()) {
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
} else if (inType.isa<UnrankedTensorType>()) {
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
}
assert(newArrayType && "Unhandled tensor type");
llvm::SmallVector<char, 8> elementData(APInt::getNumWords(bitWidth * size()) *
APInt::APINT_WORD_SIZE);
uint64_t elementIdx = 0;
for (auto value : *this) {
auto newInt = mapping(value);
assert(newInt.getBitWidth() == bitWidth);
writeBits(elementData.data(), elementIdx * bitWidth, newInt);
++elementIdx;
}
return get(newArrayType, elementData);
}
//===----------------------------------------------------------------------===//
// DenseFPElementsAttr
//===----------------------------------------------------------------------===//
@ -589,6 +728,34 @@ void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
values.assign(begin(), end());
}
DenseElementsAttr DenseFPElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const {
Type inType = getType();
size_t bitWidth = getDenseElementBitwidth(newElementType);
ShapedType newArrayType;
if (inType.isa<RankedTensorType>())
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
else if (inType.isa<UnrankedTensorType>())
newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
else
assert(false && "Unhandled tensor type");
llvm::SmallVector<char, 80> elementData(
APInt::getNumWords(bitWidth * size()) * APInt::APINT_WORD_SIZE);
uint64_t elementIdx = 0;
for (auto value : *this) {
auto newInt = mapping(value);
assert(newInt.getBitWidth() == bitWidth);
writeBits(elementData.data(), elementIdx * bitWidth, newInt);
++elementIdx;
}
return get(newArrayType, elementData);
}
/// Iterator access to the float element values.
DenseFPElementsAttr::iterator DenseFPElementsAttr::begin() const {
auto elementType = getType().getElementType().cast<FloatType>();