From ca9bd9d3af16c96fafb5d64fccea6c49c9dcb6eb Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 22 May 2019 15:55:17 -0700 Subject: [PATCH] Added the ability to run a mapping function across the values of an elements attr. This supports both the SplatElementsAttr and DenseElementsAttr for both float and integer inputs / outputs. -- PiperOrigin-RevId: 249538085 --- mlir/include/mlir/IR/Attributes.h | 54 ++++++++++ mlir/lib/IR/Attributes.cpp | 167 ++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 39ca337c0bfe..106a9f5366da 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -405,6 +405,20 @@ public: /// element, then a null attribute is returned. Attribute getValue(ArrayRef index) const; + /// Generates a new ElementsAttr by mapping each int value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain integers. + ElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + + /// Generates a new ElementsAttr by mapping each float value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain floats. + ElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr) { return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR && @@ -424,6 +438,20 @@ public: static SplatElementsAttr get(ShapedType type, Attribute elt); Attribute getValue() const; + /// Generates a new SplatElementsAttr by mapping each int value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain integers. + SplatElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + + /// Generates a new SplatElementsAttr by mapping each float value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain floats. + SplatElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + /// Method for support type inquiry through isa, cast and dyn_cast. static bool kindof(unsigned kind) { return kind == StandardAttributes::SplatElements; @@ -454,6 +482,20 @@ public: void getValues(SmallVectorImpl &values) const; + /// Generates a new DenseElementsAttr by mapping each int value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This underlying type must be an DenseIntElementsAttr. + DenseElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + + /// Generates a new DenseElementsAttr by mapping each float value to a new + /// underlying APInt. the new values can represent either a integer or float. + /// This underlying type must be an DenseFPElementsAttr. + DenseElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + ArrayRef getRawData() const; /// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is @@ -550,6 +592,12 @@ public: /// type of 'type'. static DenseIntElementsAttr get(ShapedType type, ArrayRef values); + /// Generates a new DenseElementsAttr by mapping each value attribute, and + /// constructing the DenseElementsAttr given the new element type. + DenseElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + /// Gets the integer value of each of the dense elements. void getValues(SmallVectorImpl &values) const; @@ -593,6 +641,12 @@ public: /// Gets the float value of each of the dense elements. void getValues(SmallVectorImpl &values) const; + /// Generates a new DenseElementsAttr by mapping each value attribute, and + /// constructing the DenseElementsAttr given the new element type. + DenseElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + /// Iterator access to the float element values. iterator begin() const; iterator end() const; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index add8a854742c..9b0d744b6a10 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -287,6 +287,36 @@ Attribute ElementsAttr::getValue(ArrayRef index) const { } } +ElementsAttr ElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + switch (getKind()) { + case StandardAttributes::DenseIntElements: + case StandardAttributes::DenseFPElements: + return cast().mapValues(newElementType, mapping); + case StandardAttributes::SplatElements: + return cast().mapValues(newElementType, mapping); + default: + llvm_unreachable("unsupported ElementsAttr subtype"); + } +} + +ElementsAttr ElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + switch (getKind()) { + case StandardAttributes::DenseIntElements: + case StandardAttributes::DenseFPElements: + return cast().mapValues(newElementType, mapping); + case StandardAttributes::SplatElements: + return cast().mapValues(newElementType, mapping); + default: + break; + } + + llvm_unreachable("unsupported ElementsAttr subtype"); +} + //===----------------------------------------------------------------------===// // SplatElementsAttr //===----------------------------------------------------------------------===// @@ -300,6 +330,74 @@ SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) { Attribute SplatElementsAttr::getValue() const { return getImpl()->elt; } +SplatElementsAttr SplatElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + Type inType = getType(); + auto inElementType = getType().getElementType(); + + ShapedType newArrayType; + if (inType.isa()) + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + else if (inType.isa()) + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + else + assert(false && "Unhandled tensor type"); + + assert(inElementType.isa() && + "Attempting to map non-integer array as integers"); + + if (newElementType.isa()) { + APInt newValue = mapping(getValue().cast().getValue()); + auto newAttr = IntegerAttr::get(newElementType, newValue); + return get(newArrayType, newAttr); + } + + if (newElementType.isa()) { + APFloat newValue(newElementType.cast().getFloatSemantics(), + mapping(getValue().cast().getValue())); + auto newAttr = FloatAttr::get(newElementType, newValue); + return get(newArrayType, newAttr); + } + + llvm_unreachable("unknown output splat type"); +} + +SplatElementsAttr SplatElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + Type inType = getType(); + auto inElementType = getType().getElementType(); + + ShapedType newArrayType; + if (inType.isa()) { + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + } else if (inType.isa()) { + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + } + + assert(newArrayType && "Unhandled tensor type"); + assert(inElementType.isa() && + "mapping function expects float tensor"); + + Attribute newAttr; + if (newElementType.isa()) { + APInt newValue = mapping(getValue().cast().getValue()); + newAttr = IntegerAttr::get(newElementType, newValue); + return get(newArrayType, newAttr); + } + + if (newElementType.isa()) { + APFloat newValue = + APFloat(newElementType.cast().getFloatSemantics(), + mapping(getValue().cast().getValue())); + newAttr = FloatAttr::get(newElementType, newValue); + return get(newArrayType, newAttr); + } + + llvm_unreachable("unknown output splat type"); +} + //===----------------------------------------------------------------------===// // RawElementIterator //===----------------------------------------------------------------------===// @@ -459,6 +557,18 @@ void DenseElementsAttr::getValues(SmallVectorImpl &values) const { } } +DenseElementsAttr DenseElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + return cast().mapValues(newElementType, mapping); +} + +DenseElementsAttr DenseElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + return cast().mapValues(newElementType, mapping); +} + ArrayRef DenseElementsAttr::getRawData() const { return static_cast(impl)->data; } @@ -562,6 +672,35 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl &values) const { values.assign(raw_begin(), raw_end()); } +DenseElementsAttr DenseIntElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + Type inType = getType(); + size_t bitWidth = getDenseElementBitwidth(newElementType); + + ShapedType newArrayType; + if (inType.isa()) { + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + } else if (inType.isa()) { + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + } + + assert(newArrayType && "Unhandled tensor type"); + + llvm::SmallVector elementData(APInt::getNumWords(bitWidth * size()) * + APInt::APINT_WORD_SIZE); + + uint64_t elementIdx = 0; + for (auto value : *this) { + auto newInt = mapping(value); + assert(newInt.getBitWidth() == bitWidth); + writeBits(elementData.data(), elementIdx * bitWidth, newInt); + ++elementIdx; + } + + return get(newArrayType, elementData); +} + //===----------------------------------------------------------------------===// // DenseFPElementsAttr //===----------------------------------------------------------------------===// @@ -589,6 +728,34 @@ void DenseFPElementsAttr::getValues(SmallVectorImpl &values) const { values.assign(begin(), end()); } +DenseElementsAttr DenseFPElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + Type inType = getType(); + size_t bitWidth = getDenseElementBitwidth(newElementType); + + ShapedType newArrayType; + if (inType.isa()) + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + else if (inType.isa()) + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + else + assert(false && "Unhandled tensor type"); + + llvm::SmallVector elementData( + APInt::getNumWords(bitWidth * size()) * APInt::APINT_WORD_SIZE); + + uint64_t elementIdx = 0; + for (auto value : *this) { + auto newInt = mapping(value); + assert(newInt.getBitWidth() == bitWidth); + writeBits(elementData.data(), elementIdx * bitWidth, newInt); + ++elementIdx; + } + + return get(newArrayType, elementData); +} + /// Iterator access to the float element values. DenseFPElementsAttr::iterator DenseFPElementsAttr::begin() const { auto elementType = getType().getElementType().cast();