diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 3c45b71e9952..52c5e4eb4953 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -326,6 +326,10 @@ public: VectorOrTensorType getType() const; + /// Return the value at the given index. If index does not refer to a valid + /// element, then a null attribute is returned. + Attribute getValue(ArrayRef index) const; + /// Method for support type inquiry through isa, cast and dyn_cast. static bool kindof(Kind kind) { return kind >= Kind::FIRST_ELEMENTS_ATTR && diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 86413ad23ecb..7168acf09a6b 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -164,6 +164,24 @@ VectorOrTensorType ElementsAttr::getType() const { return static_cast(attr)->type; } +/// Return the value at the given index. If index does not refer to a valid +/// element, then a null attribute is returned. +Attribute ElementsAttr::getValue(ArrayRef index) const { + switch (getKind()) { + case Attribute::Kind::SplatElements: + return cast().getValue(); + case Attribute::Kind::DenseFPElements: + case Attribute::Kind::DenseIntElements: + return cast().getValue(index); + case Attribute::Kind::OpaqueElements: + return cast().getValue(index); + case Attribute::Kind::SparseElements: + return cast().getValue(index); + default: + llvm_unreachable("unknown ElementsAttr kind"); + } +} + /// SplatElementsAttr Attribute SplatElementsAttr::getValue() const { diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 6ae1e6b75a69..45353e80a8ec 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -1180,19 +1180,10 @@ Attribute ExtractElementOp::constantFold(ArrayRef operands, indices.push_back(indice.cast().getInt()); } - // Get the element value of the aggregate attribute with the given constant - // indices. - switch (aggregate.getKind()) { - case Attribute::Kind::DenseFPElements: - case Attribute::Kind::DenseIntElements: - return aggregate.cast().getValue(indices); - case Attribute::Kind::OpaqueElements: - return aggregate.cast().getValue(indices); - case Attribute::Kind::SparseElements: - return aggregate.cast().getValue(indices); - default: - return Attribute(); - } + // If this is an elements attribute, query the value at the given indices. + if (auto elementsAttr = aggregate.dyn_cast()) + return elementsAttr.getValue(indices); + return Attribute(); } //===----------------------------------------------------------------------===//