[mlir][linalg] Make fusion on tensor rewriter friendly (NFC).

Let the calling pass or pattern replace the uses of the original root operation. Internally, the tileAndFuse still replaces uses and updates operands but only of newly created operations.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110169
This commit is contained in:
Tobias Gysi 2021-09-27 10:07:44 +00:00
parent d5629b5d4d
commit e158b5634a
2 changed files with 21 additions and 5 deletions

View File

@ -199,9 +199,11 @@ public:
/// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the
/// fused producer of fails if fusion is not possible.
// TODO: add replace uses callback to support passes and patterns.
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *rootOpOperand);
/// Returns the replacement results for the original untiled root operation.
ValueRange getRootOpReplacementResults();
/// Returns the tiled root operation.
LinalgOp getRootOp() { return rootOp; }

View File

@ -245,10 +245,15 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
.setLoopType(LinalgTilingLoopType::Loops);
Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
// Replace all uses of the root operation.
// Exit if tiling the root operation fails.
if (!tiledRootOp.hasValue())
return failure();
rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
// Replace all uses of the root operation if it has been tiled before. All
// uses of the original untiled root operation are updated by the calling pass
// or pattern.
if (!isEmpty())
rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
// Update the root operation and append the loops and tile loop dimensions.
rootOp = tiledRootOp->op;
@ -323,6 +328,11 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
return clonedOp;
}
ValueRange TileLoopNest::getRootOpReplacementResults() {
assert(!isEmpty() && "expect tile loop nest to be non-empty");
return loopOps.front()->getOpResults();
}
//===----------------------------------------------------------------------===//
// Tile and fuse entry-points.
//===----------------------------------------------------------------------===//
@ -433,9 +443,13 @@ struct LinalgTileAndFuseTensorOps
"expect the tile interchange permutes the root loops");
// Tile `rootOp` and fuse its producers.
if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes,
rootInterchange)))
FailureOr<TileLoopNest> tileLoopNest =
tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange);
if (failed(tileLoopNest))
return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
// Replace all uses of the tiled loop operation.
rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
}
};
} // namespace