[mlir][linalg] Adapt hoistPaddingOnTensors signature to support patterns (NFC).

Adapt hoistPaddingOnTensors to leave replacing and erasing the old pad tensor operation to the caller. This change makes the function pattern-friendly.

Depends On D112003

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D112255
Tobias Gysi 2021-10-29 06:42:32 +00:00
parent 95e6e1cc92
commit e83d8466fb
3 changed files with 28 additions and 26 deletions
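For context, a minimal sketch of how a pattern-style caller could consume the new return value, mirroring the updated test driver further down. `HoistPaddingPattern` and its `numLoops` member are hypothetical names introduced only for illustration; they are not part of this commit.

// Hypothetical example only: shows the calling convention implied by the new
// hoistPaddingOnTensors signature when driven from a rewrite pattern.
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct HoistPaddingPattern : public OpRewritePattern<linalg::PadTensorOp> {
  HoistPaddingPattern(MLIRContext *ctx, int numLoops)
      : OpRewritePattern<linalg::PadTensorOp>(ctx), numLoops(numLoops) {}

  LogicalResult matchAndRewrite(linalg::PadTensorOp padOp,
                                PatternRewriter &rewriter) const override {
    // The hoisting itself no longer replaces or erases `padOp`; it returns the
    // replacement value and hands back the hoisted clone via `hoistedOp`.
    linalg::PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        linalg::hoistPaddingOnTensors(padOp, numLoops, hoistedOp);
    if (failed(newResult))
      return failure();
    // Replacement (and the implied erasure) now happens through the rewriter.
    rewriter.replaceOp(padOp, newResult.getValue());
    return success();
  }

  int numLoops;
};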


@@ -9,8 +9,10 @@
 #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_HOIST_PADDING_H_
 #define MLIR_DIALECT_LINALG_TRANSFORMS_HOIST_PADDING_H_
+#include "mlir/Support/LogicalResult.h"
 namespace mlir {
-struct LogicalResult;
 class Value;
 namespace linalg {
 class PadTensorOp;
@@ -57,7 +59,8 @@ class PadTensorOp;
 /// }
 /// }
 /// ```
-LogicalResult hoistPaddingOnTensors(PadTensorOp &padTensorOp, int nLoops);
+FailureOr<Value> hoistPaddingOnTensors(PadTensorOp opToHoist, int numLoops,
+                                       PadTensorOp &hoistedOp);
 } // namespace linalg
 } // namespace mlir


@@ -355,11 +355,12 @@ static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
                                        ValueRange{ivVal, lbVal, stepVal});
 }
-LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
-                                                  int nLoops) {
-  LLVM_DEBUG(DBGS() << "Try to hoist " << *(padTensorOp) << " by " << nLoops
+FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(PadTensorOp opToHoist,
+                                                     int numLoops,
+                                                     PadTensorOp &hoistedOp) {
+  LLVM_DEBUG(DBGS() << "Try to hoist " << *(opToHoist) << " by " << numLoops
                     << " loops\n");
-  HoistingAnalysis analysis(padTensorOp, nLoops);
+  HoistingAnalysis analysis(opToHoist, numLoops);
   if (!analysis.isValid()) {
     LLVM_DEBUG(DBGS() << "Analysis failed -> Skip\n");
     return failure();
@@ -376,8 +377,8 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   // Update actual number of loops, which may be smaller.
   int nPackedLoops = analysis.packingLoops.size();
-  Location loc = padTensorOp->getLoc();
-  RankedTensorType paddedTensorType = padTensorOp.getResultType();
+  Location loc = opToHoist->getLoc();
+  RankedTensorType paddedTensorType = opToHoist.getResultType();
   int paddedRank = paddedTensorType.getRank();
   // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
@@ -404,8 +405,8 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   clonedLoopIvs.reserve(nPackedLoops);
   leadingPackedTensorIndexings.reserve(nPackedLoops);
   BlockAndValueMapping bvm;
-  // Insert `padTensorOp` into the backwardSlice so we clone it too.
-  analysis.backwardSlice.insert(padTensorOp);
+  // Insert `opToHoist` into the backwardSlice so we clone it too.
+  analysis.backwardSlice.insert(opToHoist);
   // Stack step 1. iteratively clone loops and push `packedTensor`.
   for (Operation *op : analysis.backwardSlice) {
     // Specifically sit out in the extract_slice(packedTensor) case: this is the
@@ -466,7 +467,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                   b.getIndexAttr(1));
   Value inserted =
-      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(padTensorOp.result()),
+      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(opToHoist.result()),
                                       packedTensor, offsets, sizes, strides);
   // Stack step 3. iteratively pop the stack and propagate the yield.
@@ -480,7 +481,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   // Now the packed tensor is ready, replace the original padding op by a
   // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
-  b.setInsertionPoint(padTensorOp);
+  b.setInsertionPoint(opToHoist);
   SmallVector<Value> loopIterationCounts = llvm::to_vector<4>(
       llvm::map_range(analysis.packingLoops, [&](Operation *loop) {
         return buildLoopIterationCount(b, outer, cast<scf::ForOp>(loop));
@@ -495,18 +496,10 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
   // strides = [1 .. 1] (defined above)
   packedTensor =
       scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
-  padTensorOp.replaceAllUsesWith(
-      b.create<tensor::ExtractSliceOp>(loc, padTensorOp.getResultType(),
-                                       packedTensor, offsets, sizes, strides)
-          ->getResult(0));
-  Operation *toErase = padTensorOp;
-  // Make the newly cloned `padTensorOp` available to the caller.
-  padTensorOp =
-      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());
-  toErase->erase();
-  return success();
+  Value newResult = b.create<tensor::ExtractSliceOp>(
+      loc, opToHoist.getResultType(), packedTensor, offsets, sizes, strides);
+  // Make the newly cloned `opToHoist` available to the caller.
+  hoistedOp = cast<PadTensorOp>(bvm.lookup(opToHoist.result()).getDefiningOp());
+  return newResult;
 }


@@ -771,7 +771,13 @@ void TestLinalgTransforms::runOnFunction() {
                            /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testHoistPadding) {
     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
-      (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
+      PadTensorOp hoistedOp;
+      FailureOr<Value> newResult = linalg::hoistPaddingOnTensors(
+          padTensorOp, testHoistPadding, hoistedOp);
+      if (succeeded(newResult)) {
+        padTensorOp.getResult().replaceAllUsesWith(newResult.getValue());
+        padTensorOp->erase();
+      }
     });
   }
   if (testInterchangePattern.hasValue())