From e158b5634aa67ea3039a62c3d8bda79b77b3b21c Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Mon, 27 Sep 2021 10:07:44 +0000 Subject: [PATCH] [mlir][linalg] Make fusion on tensor rewriter friendly (NFC). Let the calling pass or pattern replace the uses of the original root operation. Internally, the tileAndFuse still replaces uses and updates operands but only of newly created operations. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D110169 --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 4 +++- .../Linalg/Transforms/FusionOnTensors.cpp | 22 +++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 8d01b333e311..24c5784f2a9d 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -199,9 +199,11 @@ public: /// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the /// fused producer of fails if fusion is not possible. - // TODO: add replace uses callback to support passes and patterns. FailureOr fuseProducer(OpBuilder &b, OpOperand *rootOpOperand); + /// Returns the replacement results for the original untiled root operation. + ValueRange getRootOpReplacementResults(); + /// Returns the tiled root operation. LinalgOp getRootOp() { return rootOp; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index 31e53f7cf93d..448e677e7ac1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -245,10 +245,15 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b, .setLoopType(LinalgTilingLoopType::Loops); Optional tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions); - // Replace all uses of the root operation. + // Exit if tiling the root operation fails. if (!tiledRootOp.hasValue()) return failure(); - rootOp->replaceAllUsesWith(tiledRootOp->tensorResults); + + // Replace all uses of the root operation if it has been tiled before. All + // uses of the original untiled root operation are updated by the calling pass + // or pattern. + if (!isEmpty()) + rootOp->replaceAllUsesWith(tiledRootOp->tensorResults); // Update the root operation and append the loops and tile loop dimensions. rootOp = tiledRootOp->op; @@ -323,6 +328,11 @@ FailureOr TileLoopNest::fuseProducer(OpBuilder &b, return clonedOp; } +ValueRange TileLoopNest::getRootOpReplacementResults() { + assert(!isEmpty() && "expect tile loop nest to be non-empty"); + return loopOps.front()->getOpResults(); +} + //===----------------------------------------------------------------------===// // Tile and fuse entry-points. //===----------------------------------------------------------------------===// @@ -433,9 +443,13 @@ struct LinalgTileAndFuseTensorOps "expect the tile interchange permutes the root loops"); // Tile `rootOp` and fuse its producers. - if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, - rootInterchange))) + FailureOr tileLoopNest = + tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange); + if (failed(tileLoopNest)) return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly"); + + // Replace all uses of the tiled loop operation. + rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults()); } }; } // namespace