[mlir][Linalg] Use reify for padded op shape derivation.

Previously, we would insert a DimOp and rely on later canonicalizations.
Unfortunately, reifyShape-style rewrites are no longer canonicalizations,
which introduces undesirable pass dependencies.

Instead, immediately reify the result shape and avoid the DimOp altogether.
This is akin to a local folding that avoids introducing further reliance on
`-resolve-shaped-type-result-dims` (much as `affine.apply` ops are composed by
construction to avoid chains of size > 1).
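
For illustration, the effect on the produced IR is roughly the following
(a hand-written sketch in current upstream syntax; the function and value
names are invented and not taken from the actual tests):

  // Before (sketch): the slice sizes are tensor.dim ops on the transient
  // op's results, to be resolved by a later pass.
  func.func @before(%unpadded: tensor<?x?xf32>, %padded: tensor<4x4xf32>)
      -> tensor<?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %d0 = tensor.dim %unpadded, %c0 : tensor<?x?xf32>
    %d1 = tensor.dim %unpadded, %c1 : tensor<?x?xf32>
    %s = tensor.extract_slice %padded[0, 0] [%d0, %d1] [1, 1]
        : tensor<4x4xf32> to tensor<?x?xf32>
    return %s : tensor<?x?xf32>
  }

  // After (sketch): the reified sizes feed tensor.extract_slice directly,
  // so no tensor.dim on the transient op is created.
  func.func @after(%padded: tensor<4x4xf32>, %sz0: index, %sz1: index)
      -> tensor<?x?xf32> {
    %s = tensor.extract_slice %padded[0, 0] [%sz0, %sz1] [1, 1]
        : tensor<4x4xf32> to tensor<?x?xf32>
    return %s : tensor<?x?xf32>
  }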

This does not completely remove the reliance on the pass, since the rewrite is
merely local: running the pass may still be necessary for its global effects.
Indeed, one of the tests still requires it.
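
To sketch that global effect (again with invented names): a tensor.dim that
queries another op's result is out of reach of this local rewrite, and only
`-resolve-shaped-type-result-dims` folds it through the producer's
ReifyRankedShapedTypeOpInterface implementation:

  func.func @global_use(%a: tensor<?x4xf32>, %b: tensor<4x?xf32>,
                        %c: tensor<?x?xf32>) -> index {
    %c0 = arith.constant 0 : index
    %r = linalg.matmul ins(%a, %b : tensor<?x4xf32>, tensor<4x?xf32>)
                       outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
    // The pass rewrites this dim of the matmul result into a dim of the
    // matmul's output operand %c.
    %dim = tensor.dim %r, %c0 : tensor<?x?xf32>
    return %dim : index
  }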

Differential Revision: https://reviews.llvm.org/D109571
commit b01d223faf (parent 0213d7ec0c)
Author: Nicolas Vasilache
Date: 2021-09-10 07:12:14 +0000

2 changed files with 23 additions and 22 deletions


@@ -185,6 +185,13 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
     newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
   }
 
+  SmallVector<SmallVector<Value>> reifiedResultShapes;
+  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
+                 .reifyResultShapes(rewriter, reifiedResultShapes)))
+    return failure();
+  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+         "expected same number of results");
+
   // Clone `opToPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
@@ -192,28 +199,21 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
 
   // Recover the slice out of the new static results. This keeps the original
   // linalg op around because it uses the dims of the original results.
-  // This later folds away.
   SmallVector<Value> paddedSubviewResults;
   paddedSubviewResults.reserve(opToPad->getNumResults());
-  SetVector<Operation *> newUsersOfOpToPad;
-  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
-    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
+  for (auto en : llvm::enumerate(paddedOp->getResults())) {
+    Value paddedResult = en.value();
+    int64_t resultNumber = en.index();
+    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    auto sizes = llvm::to_vector<4>(llvm::map_range(
-        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
-          auto dimOp = rewriter.create<tensor::DimOp>(loc, std::get<0>(it), d);
-          newUsersOfOpToPad.insert(dimOp);
-          return dimOp.getResult();
-        }));
+    SmallVector<OpFoldResult> sizes;
+    for (Value v : reifiedResultShapes[resultNumber])
+      sizes.push_back(v);
     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
     paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
-        loc, std::get<1>(it), offsets, sizes, strides));
+        loc, paddedResult, offsets, sizes, strides));
   }
-  // Replace the transient `opToPad` locally, except for uses that we just
-  // created for the purpose of extracting the dims.
-  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
-    return !newUsersOfOpToPad.contains(opOp.getOwner());
-  });
+  rewriter.replaceOp(opToPad, paddedSubviewResults);
   return success();
 }
 
@@ -244,14 +244,16 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     return failure();
 
   // Setup RAII guard to return properly.
+  LinalgOp paddedOp;
   LinalgOp tiledOp = res->op;
   auto guard = llvm::make_scope_exit([&]() {
     // Return relevant information to derived pattern.
     result = *res;
-    // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
-    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
-    if (tiledOp != res->op)
-      filter.replaceLinalgTransformationFilter(rewriter, res->op);
+    // Update filter.
+    if (paddedOp)
+      filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
+    else
+      filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
   });
 
   // Consider padding on the fly only if the op has tensor semantics.
@@ -261,7 +263,6 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
 
   // Try to pad on the fly by rewriting res->op as a padded op. If successful,
   // `res.op` is rewritten in static form with padded operands.
-  LinalgOp paddedOp;
   if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                   options.paddingValueComputationFunction,
                                   paddedOp))) {


@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -cse -split-input-file | \
 // RUN: FileCheck %s -check-prefix=TILE2
 // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
 // RUN: FileCheck %s -check-prefix=TILE1