diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 0cffdc4d9315..215dff12d34d 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -547,6 +547,22 @@ public: // Iterators //===--------------------------------------------------------------------===// + /// A utility iterator that allows walking over the internal Attribute values + /// of a DenseElementsAttr. + class AttributeElementIterator + : public indexed_accessor_iterator { + public: + /// Accesses the Attribute value at this iterator position. + Attribute operator*() const; + + private: + friend DenseElementsAttr; + + /// Constructs a new iterator. + AttributeElementIterator(DenseElementsAttr attr, size_t index); + }; + /// A utility iterator that allows walking over the internal raw APInt values. class IntElementIterator : public indexed_accessor_iterator index) const; - /// Return the held element values as Attributes in 'values'. - void getValues(SmallVectorImpl &values) const; - /// Return the held element values as an array of integer or floating-point /// values. template getAttributeValues() const; + template ::value>::type> + llvm::iterator_range getValues() const { + return getAttributeValues(); + } + AttributeElementIterator attr_value_begin() const; + AttributeElementIterator attr_value_end() const; + /// Return the held element values as a range of APInts. The element type of /// this attribute must be of integer type. llvm::iterator_range getIntValues() const; @@ -621,6 +644,8 @@ public: llvm::iterator_range getValues() const { return getIntValues(); } + IntElementIterator int_value_begin() const; + IntElementIterator int_value_end() const; /// Return the held element values as a range of APFloat. The element type of /// this attribute must be of float type. @@ -630,6 +655,8 @@ public: llvm::iterator_range getValues() const { return getFloatValues(); } + FloatElementIterator float_value_begin() const; + FloatElementIterator float_value_end() const; //===--------------------------------------------------------------------===// // Mutation Utilities @@ -704,8 +731,8 @@ public: llvm::function_ref mapping) const; /// Iterator access to the float element values. - iterator begin() const { return getFloatValues().begin(); } - iterator end() const { return getFloatValues().end(); } + iterator begin() const { return float_value_begin(); } + iterator end() const { return float_value_end(); } /// Method for supporting type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr); diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index c8ebb27ebb04..37ed96b1c278 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -460,6 +460,31 @@ static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { (type.getNumElements() == static_cast(values.size())); } +//===----------------------------------------------------------------------===// +// DenseElementAttr Iterators +//===----------------------------------------------------------------------===// + +/// Constructs a new iterator. +DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( + DenseElementsAttr attr, size_t index) + : indexed_accessor_iterator( + attr.getAsOpaquePointer(), index) {} + +/// Accesses the Attribute value at this iterator position. +Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { + auto owner = getFromOpaquePointer(object).cast(); + Type eltTy = owner.getType().getElementType(); + if (eltTy.isa()) + return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); + if (auto floatEltTy = eltTy.dyn_cast()) { + IntElementIterator intIt(owner, index); + FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); + return FloatAttr::get(eltTy, *floatIt); + } + llvm_unreachable("unexpected element type"); +} + /// Constructs a new iterator. DenseElementsAttr::IntElementIterator::IntElementIterator( DenseElementsAttr attr, size_t index) @@ -636,16 +661,7 @@ bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } /// If this attribute corresponds to a splat, then get the splat value. /// Otherwise, return null. Attribute DenseElementsAttr::getSplatValue() const { - if (!isSplat()) - return Attribute(); - - auto elementType = getType().getElementType(); - if (elementType.isa()) - return IntegerAttr::get(elementType, *raw_int_begin()); - if (auto fType = elementType.dyn_cast()) - return FloatAttr::get(elementType, - APFloat(fType.getFloatSemantics(), *raw_int_begin())); - llvm_unreachable("unexpected element type"); + return isSplat() ? *attr_value_begin() : Attribute(); } /// Return the value at the given index. If index does not refer to a valid @@ -692,23 +708,16 @@ Attribute DenseElementsAttr::getValue(ArrayRef index) const { llvm_unreachable("unexpected element type"); } -void DenseElementsAttr::getValues(SmallVectorImpl &values) const { - values.reserve(rawSize()); - - auto elementType = getType().getElementType(); - if (elementType.isa()) { - // Convert each value to an IntegerAttr. - for (auto intVal : getIntValues()) - values.push_back(IntegerAttr::get(elementType, intVal)); - return; - } - if (elementType.isa()) { - // Convert each value to a FloatAttr. - for (auto floatVal : getFloatValues()) - values.push_back(FloatAttr::get(elementType, floatVal)); - return; - } - llvm_unreachable("unexpected element type"); +/// Return the held element values as a range of Attributes. +auto DenseElementsAttr::getAttributeValues() const + -> llvm::iterator_range { + return {attr_value_begin(), attr_value_end()}; +} +auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { + return AttributeElementIterator(*this, 0); +} +auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { + return AttributeElementIterator(*this, rawSize()); } /// Return the held element values as a range of APInts. The element type of @@ -719,6 +728,16 @@ auto DenseElementsAttr::getIntValues() const "expected integer type"); return {raw_int_begin(), raw_int_end()}; } +auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { + assert(getType().getElementType().isa() && + "expected integer type"); + return raw_int_begin(); +} +auto DenseElementsAttr::int_value_end() const -> IntElementIterator { + assert(getType().getElementType().isa() && + "expected integer type"); + return raw_int_end(); +} /// Return the held element values as a range of APFloat. The element type of /// this attribute must be of float type. @@ -730,6 +749,12 @@ auto DenseElementsAttr::getFloatValues() const return {FloatElementIterator(elementSemantics, raw_int_begin()), FloatElementIterator(elementSemantics, raw_int_end())}; } +auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { + return getFloatValues().begin(); +} +auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { + return getFloatValues().end(); +} /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 1cb9e6c7dcb5..ef286cb64fdb 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -96,9 +96,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, SmallVector constants; uint64_t numElements = vectorType->getNumElements(); constants.reserve(numElements); - SmallVector nested; - denseAttr.getValues(nested); - for (auto n : nested) { + for (auto n : denseAttr.getAttributeValues()) { constants.push_back( getLLVMConstant(vectorType->getElementType(), n, loc)); if (!constants.back())