[mlir][linalg] Adapt hoistPaddingOnTensors signature to support patterns (NFC).
Adapt hoistPaddingOnTensors to leave replacing and erasing the old pad tensor operation to the caller. This change makes the function pattern friendly.

Depends On D112003

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D112255
commit e83d8466fb (parent 95e6e1cc92)
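For illustration, here is a minimal sketch of how the new signature could be driven from a rewrite pattern, which is the use case the commit message refers to. This is not part of the commit: the pattern class, its constructor, and the numLoops field are assumptions; only the hoistPaddingOnTensors call, its FailureOr<Value> result, and the hoistedOp out-parameter come from the diff below (the test-driver hunk at the end shows the same calling convention outside a pattern).

// Hypothetical pattern wrapping the new interface; needs the Linalg dialect
// header that declares PadTensorOp in addition to the includes shown here.
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct HoistPadTensorPattern
    : public mlir::OpRewritePattern<mlir::linalg::PadTensorOp> {
  HoistPadTensorPattern(mlir::MLIRContext *ctx, int numLoops)
      : mlir::OpRewritePattern<mlir::linalg::PadTensorOp>(ctx),
        numLoops(numLoops) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::linalg::PadTensorOp padOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::linalg::PadTensorOp hoistedOp;
    mlir::FailureOr<mlir::Value> newResult =
        mlir::linalg::hoistPaddingOnTensors(padOp, numLoops, hoistedOp);
    if (mlir::failed(newResult))
      return mlir::failure();
    // Replacement (and thus erasure) of the original pad op is now the
    // caller's responsibility, which is what makes the function usable
    // from inside a pattern.
    rewriter.replaceOp(padOp, newResult.getValue());
    return mlir::success();
  }

  int numLoops;
};
} // namespace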
@@ -9,8 +9,10 @@
 #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_HOIST_PADDING_H_
 #define MLIR_DIALECT_LINALG_TRANSFORMS_HOIST_PADDING_H_
 
+#include "mlir/Support/LogicalResult.h"
+
 namespace mlir {
 struct LogicalResult;
 class Value;
 
 namespace linalg {
 class PadTensorOp;
@@ -57,7 +59,8 @@ class PadTensorOp;
 /// }
 /// }
 /// ```
-LogicalResult hoistPaddingOnTensors(PadTensorOp &padTensorOp, int nLoops);
+FailureOr<Value> hoistPaddingOnTensors(PadTensorOp opToHoist, int numLoops,
+                                       PadTensorOp &hoistedOp);
 
 } // namespace linalg
 } // namespace mlir

@@ -355,11 +355,12 @@ static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
                            ValueRange{ivVal, lbVal, stepVal});
 }
 
-LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
-                                                  int nLoops) {
-  LLVM_DEBUG(DBGS() << "Try to hoist " << *(padTensorOp) << " by " << nLoops
+FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(PadTensorOp opToHoist,
+                                                     int numLoops,
+                                                     PadTensorOp &hoistedOp) {
+  LLVM_DEBUG(DBGS() << "Try to hoist " << *(opToHoist) << " by " << numLoops
                     << " loops\n");
-  HoistingAnalysis analysis(padTensorOp, nLoops);
+  HoistingAnalysis analysis(opToHoist, numLoops);
   if (!analysis.isValid()) {
     LLVM_DEBUG(DBGS() << "Analysis failed -> Skip\n");
     return failure();
@@ -376,8 +377,8 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   // Update actual number of loops, which may be smaller.
   int nPackedLoops = analysis.packingLoops.size();
 
-  Location loc = padTensorOp->getLoc();
-  RankedTensorType paddedTensorType = padTensorOp.getResultType();
+  Location loc = opToHoist->getLoc();
+  RankedTensorType paddedTensorType = opToHoist.getResultType();
   int paddedRank = paddedTensorType.getRank();
 
   // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
@@ -404,8 +405,8 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   clonedLoopIvs.reserve(nPackedLoops);
   leadingPackedTensorIndexings.reserve(nPackedLoops);
   BlockAndValueMapping bvm;
-  // Insert `padTensorOp` into the backwardSlice so we clone it too.
-  analysis.backwardSlice.insert(padTensorOp);
+  // Insert `opToHoist` into the backwardSlice so we clone it too.
+  analysis.backwardSlice.insert(opToHoist);
   // Stack step 1. iteratively clone loops and push `packedTensor`.
   for (Operation *op : analysis.backwardSlice) {
     // Specifically sit out in the extract_slice(packedTensor) case: this is the
@@ -466,7 +467,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                  b.getIndexAttr(1));
 
   Value inserted =
-      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(padTensorOp.result()),
+      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(opToHoist.result()),
                                       packedTensor, offsets, sizes, strides);
 
   // Stack step 3. iteratively pop the stack and propagate the yield.
@@ -480,7 +481,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
 
   // Now the packed tensor is ready, replace the original padding op by a
   // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
-  b.setInsertionPoint(padTensorOp);
+  b.setInsertionPoint(opToHoist);
   SmallVector<Value> loopIterationCounts = llvm::to_vector<4>(
       llvm::map_range(analysis.packingLoops, [&](Operation *loop) {
         return buildLoopIterationCount(b, outer, cast<scf::ForOp>(loop));
@@ -495,18 +496,10 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   // strides = [1 .. 1] (defined above)
   packedTensor =
       scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
-  padTensorOp.replaceAllUsesWith(
-      b.create<tensor::ExtractSliceOp>(loc, padTensorOp.getResultType(),
-                                       packedTensor, offsets, sizes, strides)
-          ->getResult(0));
-
-  Operation *toErase = padTensorOp;
-
-  // Make the newly cloned `padTensorOp` available to the caller.
-  padTensorOp =
-      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());
-
-  toErase->erase();
-
-  return success();
+  Value newResult = b.create<tensor::ExtractSliceOp>(
+      loc, opToHoist.getResultType(), packedTensor, offsets, sizes, strides);
+
+  // Make the newly cloned `opToHoist` available to the caller.
+  hoistedOp = cast<PadTensorOp>(bvm.lookup(opToHoist.result()).getDefiningOp());
+  return newResult;
 }

@@ -771,7 +771,13 @@ void TestLinalgTransforms::runOnFunction() {
         /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testHoistPadding) {
     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
-      (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
+      PadTensorOp hoistedOp;
+      FailureOr<Value> newResult = linalg::hoistPaddingOnTensors(
+          padTensorOp, testHoistPadding, hoistedOp);
+      if (succeeded(newResult)) {
+        padTensorOp.getResult().replaceAllUsesWith(newResult.getValue());
+        padTensorOp->erase();
+      }
     });
   }
   if (testInterchangePattern.hasValue())