forked from OSchip/llvm-project
[mlir][linalg] Make fusion on tensor rewriter friendly (NFC).
Let the calling pass or pattern replace the uses of the original root operation. Internally, the tileAndFuse still replaces uses and updates operands but only of newly created operations. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D110169
This commit is contained in:
parent
d5629b5d4d
commit
e158b5634a
|
@ -199,9 +199,11 @@ public:
|
|||
|
||||
/// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the
|
||||
/// fused producer of fails if fusion is not possible.
|
||||
// TODO: add replace uses callback to support passes and patterns.
|
||||
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *rootOpOperand);
|
||||
|
||||
/// Returns the replacement results for the original untiled root operation.
|
||||
ValueRange getRootOpReplacementResults();
|
||||
|
||||
/// Returns the tiled root operation.
|
||||
LinalgOp getRootOp() { return rootOp; }
|
||||
|
||||
|
|
|
@ -245,10 +245,15 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
|
|||
.setLoopType(LinalgTilingLoopType::Loops);
|
||||
Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
|
||||
|
||||
// Replace all uses of the root operation.
|
||||
// Exit if tiling the root operation fails.
|
||||
if (!tiledRootOp.hasValue())
|
||||
return failure();
|
||||
rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
|
||||
|
||||
// Replace all uses of the root operation if it has been tiled before. All
|
||||
// uses of the original untiled root operation are updated by the calling pass
|
||||
// or pattern.
|
||||
if (!isEmpty())
|
||||
rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
|
||||
|
||||
// Update the root operation and append the loops and tile loop dimensions.
|
||||
rootOp = tiledRootOp->op;
|
||||
|
@ -323,6 +328,11 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
|
|||
return clonedOp;
|
||||
}
|
||||
|
||||
ValueRange TileLoopNest::getRootOpReplacementResults() {
|
||||
assert(!isEmpty() && "expect tile loop nest to be non-empty");
|
||||
return loopOps.front()->getOpResults();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tile and fuse entry-points.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -433,9 +443,13 @@ struct LinalgTileAndFuseTensorOps
|
|||
"expect the tile interchange permutes the root loops");
|
||||
|
||||
// Tile `rootOp` and fuse its producers.
|
||||
if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes,
|
||||
rootInterchange)))
|
||||
FailureOr<TileLoopNest> tileLoopNest =
|
||||
tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange);
|
||||
if (failed(tileLoopNest))
|
||||
return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
|
||||
|
||||
// Replace all uses of the tiled loop operation.
|
||||
rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
Loading…
Reference in New Issue