From d36a15de1ff4d24e772233406d602c5f0b370f54 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo <stellaraccident@gmail.com>
Date: Fri, 26 Feb 2021 13:11:02 -0800
Subject: [PATCH] [mlir][linalg] Memoize indexing map generation.

Differential Revision: https://reviews.llvm.org/D97602
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp                 | 7 +++++--
 .../mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp     | 9 ++++++++-
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index acc8ff1807c1..46e5780e151f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2333,8 +2333,11 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
   p << op.getOperationName();
-  p.printOptionalAttrDict(op->getAttrs(),
-                          /*elidedAttrs=*/{"operand_segment_sizes"});
+  p.printOptionalAttrDict(
+      op->getAttrs(),
+      /*elidedAttrs=*/{"operand_segment_sizes",
+                       // See generated code in mlir-linalg-yaml-gen.cpp
+                       "linalg.memoized_indexing_maps"});
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 5578ff52d477..1dddc57f25d3 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -651,11 +651,18 @@ static SmallVector<AffineExpr> getSymbolBindings({0} self) {
       // {2}: Statements
       static const char structuredOpIndexingMapsFormat[] = R"FMT(
 ArrayAttr {0}::indexing_maps() {
+  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
+  if (cached)
+    return cached;
+
   MLIRContext *context = getContext();
   auto symbolBindings = getSymbolBindings(*this);
   SmallVector<AffineMap> maps;
   {2}
-  return Builder(context).getAffineMapArrayAttr(maps);
+  cached = Builder(context).getAffineMapArrayAttr(maps);
+  getOperation()->setAttr(memoizeAttr, cached);
+  return cached;
 }
 )FMT";