From a51d21538c01e3e8d1c9f233f29a1c4b7f0d6b1b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 25 Feb 2019 08:21:41 -0800 Subject: [PATCH] Add constant folding for ExtractElementOp when the aggregate is an OpaqueElementsAttr. PiperOrigin-RevId: 235533283 --- mlir/include/mlir/IR/Attributes.h | 4 ++++ mlir/lib/IR/Attributes.cpp | 8 ++++++++ mlir/lib/StandardOps/StandardOps.cpp | 2 ++ 3 files changed, 14 insertions(+) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index e961e6e1f5b4..5d083f834f19 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -438,6 +438,10 @@ public: StringRef getValue() const; + /// Return the value at the given index. If index does not refer to a valid + /// element, then a null attribute is returned. + Attribute getValue(ArrayRef index) const; + /// Decodes the attribute value using dialect-specific decoding hook. /// Returns false if decoding is successful. If not, returns true and leaves /// 'result' argument unspecified. diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 5b2b02c1c556..7a4a52e809eb 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -342,6 +342,14 @@ StringRef OpaqueElementsAttr::getValue() const { return static_cast(attr)->bytes; } +/// Return the value at the given index. If index does not refer to a valid +/// element, then a null attribute is returned. +Attribute OpaqueElementsAttr::getValue(ArrayRef index) const { + if (Dialect *dialect = getDialect()) + return dialect->extractElementHook(*this, index); + return Attribute(); +} + Dialect *OpaqueElementsAttr::getDialect() const { return static_cast(attr)->dialect; } diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index ccdee04fbcd5..6ae1e6b75a69 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -1186,6 +1186,8 @@ Attribute ExtractElementOp::constantFold(ArrayRef operands, case Attribute::Kind::DenseFPElements: case Attribute::Kind::DenseIntElements: return aggregate.cast().getValue(indices); + case Attribute::Kind::OpaqueElements: + return aggregate.cast().getValue(indices); case Attribute::Kind::SparseElements: return aggregate.cast().getValue(indices); default: