diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 39ca337c0bfe..106a9f5366da 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -405,6 +405,20 @@ public: /// element, then a null attribute is returned. Attribute getValue(ArrayRef index) const; + /// Generates a new ElementsAttr by mapping each int value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain integers. + ElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + + /// Generates a new ElementsAttr by mapping each float value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain floats. + ElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr) { return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR && @@ -424,6 +438,20 @@ public: static SplatElementsAttr get(ShapedType type, Attribute elt); Attribute getValue() const; + /// Generates a new SplatElementsAttr by mapping each int value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain integers. + SplatElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + + /// Generates a new SplatElementsAttr by mapping each float value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This ElementsAttr should contain floats. + SplatElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + /// Method for support type inquiry through isa, cast and dyn_cast. static bool kindof(unsigned kind) { return kind == StandardAttributes::SplatElements; @@ -454,6 +482,20 @@ public: void getValues(SmallVectorImpl &values) const; + /// Generates a new DenseElementsAttr by mapping each int value to a new + /// underlying APInt. The new values can represent either a integer or float. + /// This underlying type must be an DenseIntElementsAttr. + DenseElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + + /// Generates a new DenseElementsAttr by mapping each float value to a new + /// underlying APInt. the new values can represent either a integer or float. + /// This underlying type must be an DenseFPElementsAttr. + DenseElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + ArrayRef getRawData() const; /// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is @@ -550,6 +592,12 @@ public: /// type of 'type'. static DenseIntElementsAttr get(ShapedType type, ArrayRef values); + /// Generates a new DenseElementsAttr by mapping each value attribute, and + /// constructing the DenseElementsAttr given the new element type. + DenseElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + /// Gets the integer value of each of the dense elements. void getValues(SmallVectorImpl &values) const; @@ -593,6 +641,12 @@ public: /// Gets the float value of each of the dense elements. void getValues(SmallVectorImpl &values) const; + /// Generates a new DenseElementsAttr by mapping each value attribute, and + /// constructing the DenseElementsAttr given the new element type. + DenseElementsAttr + mapValues(Type newElementType, + llvm::function_ref mapping) const; + /// Iterator access to the float element values. iterator begin() const; iterator end() const; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index add8a854742c..ac1f180a9870 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -287,6 +287,34 @@ Attribute ElementsAttr::getValue(ArrayRef index) const { } } +ElementsAttr ElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + switch (getKind()) { + case StandardAttributes::DenseIntElements: + case StandardAttributes::DenseFPElements: + return cast().mapValues(newElementType, mapping); + case StandardAttributes::SplatElements: + return cast().mapValues(newElementType, mapping); + default: + llvm_unreachable("unsupported ElementsAttr subtype"); + } +} + +ElementsAttr ElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + switch (getKind()) { + case StandardAttributes::DenseIntElements: + case StandardAttributes::DenseFPElements: + return cast().mapValues(newElementType, mapping); + case StandardAttributes::SplatElements: + return cast().mapValues(newElementType, mapping); + default: + llvm_unreachable("unsupported ElementsAttr subtype"); + } +} + //===----------------------------------------------------------------------===// // SplatElementsAttr //===----------------------------------------------------------------------===// @@ -300,6 +328,73 @@ SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) { Attribute SplatElementsAttr::getValue() const { return getImpl()->elt; } +SplatElementsAttr SplatElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + ShapedType inType = getType(); + + ShapedType newArrayType; + if (inType.isa()) + newArrayType = RankedTensorType::get(inType.getShape(), newElementType); + else if (inType.isa()) + newArrayType = RankedTensorType::get(inType.getShape(), newElementType); + else if (inType.isa()) + newArrayType = VectorType::get(inType.getShape(), newElementType); + else + assert(false && "Unhandled tensor type"); + + assert(getType().getElementType().isa() && + "Attempting to map non-integer array as integers"); + + if (newElementType.isa()) { + APInt newValue = mapping(getValue().cast().getValue()); + auto newAttr = IntegerAttr::get(newElementType, newValue); + return get(newArrayType, newAttr); + } + + if (newElementType.isa()) { + APFloat newValue(newElementType.cast().getFloatSemantics(), + mapping(getValue().cast().getValue())); + auto newAttr = FloatAttr::get(newElementType, newValue); + return get(newArrayType, newAttr); + } + + llvm_unreachable("unknown output splat type"); +} + +SplatElementsAttr SplatElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + Type inType = getType(); + + ShapedType newArrayType; + if (inType.isa()) { + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + } else if (inType.isa()) { + newArrayType = RankedTensorType::get(getType().getShape(), newElementType); + } + + assert(newArrayType && "Unhandled tensor type"); + assert(getType().getElementType().isa() && + "mapping function expects float tensor"); + + Attribute newAttr; + if (newElementType.isa()) { + APInt newValue = mapping(getValue().cast().getValue()); + newAttr = IntegerAttr::get(newElementType, newValue); + return get(newArrayType, newAttr); + } + + if (newElementType.isa()) { + APFloat newValue(newElementType.cast().getFloatSemantics(), + mapping(getValue().cast().getValue())); + newAttr = FloatAttr::get(newElementType, newValue); + return get(newArrayType, newAttr); + } + + llvm_unreachable("unknown output splat type"); +} + //===----------------------------------------------------------------------===// // RawElementIterator //===----------------------------------------------------------------------===// @@ -459,6 +554,18 @@ void DenseElementsAttr::getValues(SmallVectorImpl &values) const { } } +DenseElementsAttr DenseElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + return cast().mapValues(newElementType, mapping); +} + +DenseElementsAttr DenseElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + return cast().mapValues(newElementType, mapping); +} + ArrayRef DenseElementsAttr::getRawData() const { return static_cast(impl)->data; } @@ -562,6 +669,46 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl &values) const { values.assign(raw_begin(), raw_end()); } +template +static ShapedType mappingHelper( + Fn mapping, Attr& attr, ShapedType inType, Type newElementType, + llvm::SmallVectorImpl& data) { + size_t bitWidth = getDenseElementBitwidth(newElementType); + + ShapedType newArrayType; + if (inType.isa()) + newArrayType = RankedTensorType::get(inType.getShape(), newElementType); + else if (inType.isa()) + newArrayType = RankedTensorType::get(inType.getShape(), newElementType); + else if (inType.isa()) + newArrayType = VectorType::get(inType.getShape(), newElementType); + else + assert(newArrayType && "Unhandled tensor type"); + + data.resize(APInt::getNumWords(bitWidth * inType.getNumElements()) * + APInt::APINT_WORD_SIZE); + + uint64_t elementIdx = 0; + for (auto value : attr) { + auto newInt = mapping(value); + assert(newInt.getBitWidth() == bitWidth); + attr.writeBits(data.data(), elementIdx * bitWidth, newInt); + ++elementIdx; + } + + return newArrayType; +} + +DenseElementsAttr DenseIntElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + llvm::SmallVector elementData; + auto newArrayType = mappingHelper( + mapping, *this, getType(), newElementType, elementData); + + return get(newArrayType, elementData); +} + //===----------------------------------------------------------------------===// // DenseFPElementsAttr //===----------------------------------------------------------------------===// @@ -589,6 +736,16 @@ void DenseFPElementsAttr::getValues(SmallVectorImpl &values) const { values.assign(begin(), end()); } +DenseElementsAttr DenseFPElementsAttr::mapValues( + Type newElementType, + llvm::function_ref mapping) const { + llvm::SmallVector elementData; + auto newArrayType = mappingHelper( + mapping, *this, getType(), newElementType, elementData); + + return get(newArrayType, elementData); +} + /// Iterator access to the float element values. DenseFPElementsAttr::iterator DenseFPElementsAttr::begin() const { auto elementType = getType().getElementType().cast();