[mlir][Linalg] Use reify for padded op shape derivation.
Previously, we would insert a DimOp and rely on later canonicalizations. Unfortunately, reifyShape-style rewrites are no longer canonicalizations, which introduces undesirable pass dependencies. Instead, immediately reify the result shape and avoid the DimOp altogether.

This is akin to a local folding that avoids introducing more reliance on `-resolve-shaped-type-result-dims` (similar to composing `affine.apply` by construction to avoid chains of size > 1). It does not completely remove the reliance on the pass, as the process is merely local: running the pass may still be necessary for global effects. Indeed, one of the tests still requires it.

Differential Revision: https://reviews.llvm.org/D109571
commit b01d223faf (parent 0213d7ec0c)
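For readers skimming the diff below, the core of the change is where result sizes come from. Here is a minimal sketch of the two approaches, assuming an op that implements `ReifyRankedShapedTypeOpInterface`; the helper names are illustrative, not from the patch:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

// Before (sketch): each size was a fresh tensor.dim on the *original* op's
// result, and only -resolve-shaped-type-result-dims could later rewrite such
// DimOps in terms of the op's operands.
static OpFoldResult sizeViaDimOp(OpBuilder &b, Location loc, Value result,
                                 int64_t dim) {
  return b.create<tensor::DimOp>(loc, result, dim).getResult();
}

// After (sketch): ask the op itself for its result shapes, one vector of
// index Values per result, and use those Values directly as slice sizes.
static LogicalResult sizesViaReify(OpBuilder &b, Operation *op,
                                   SmallVector<SmallVector<Value>> &shapes) {
  return cast<ReifyRankedShapedTypeOpInterface>(op).reifyResultShapes(b,
                                                                      shapes);
}
```

For linalg ops, the reified sizes are typically expressed in terms of the op's operands rather than its results, which is why the replacement logic later in the patch also simplifies.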
@@ -185,6 +185,13 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
     newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
   }
+
+  SmallVector<SmallVector<Value>> reifiedResultShapes;
+  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
+                 .reifyResultShapes(rewriter, reifiedResultShapes)))
+    return failure();
+  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+         "expected same number of results");
 
   // Clone `opToPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
@@ -192,28 +199,21 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
 
   // Recover the slice out of the new static results. This keeps the original
   // linalg op around because it uses the dims of the original results.
-  // This later folds away.
   SmallVector<Value> paddedSubviewResults;
   paddedSubviewResults.reserve(opToPad->getNumResults());
-  SetVector<Operation *> newUsersOfOpToPad;
-  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
-    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
+  for (auto en : llvm::enumerate(paddedOp->getResults())) {
+    Value paddedResult = en.value();
+    int64_t resultNumber = en.index();
+    int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    auto sizes = llvm::to_vector<4>(llvm::map_range(
-        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
-          auto dimOp = rewriter.create<tensor::DimOp>(loc, std::get<0>(it), d);
-          newUsersOfOpToPad.insert(dimOp);
-          return dimOp.getResult();
-        }));
+    SmallVector<OpFoldResult> sizes;
+    for (Value v : reifiedResultShapes[resultNumber])
+      sizes.push_back(v);
     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
     paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
-        loc, std::get<1>(it), offsets, sizes, strides));
+        loc, paddedResult, offsets, sizes, strides));
   }
-  // Replace the transient `opToPad` locally, except for uses that we just
-  // created for the purpose of extracting the dims.
-  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
-    return !newUsersOfOpToPad.contains(opOp.getOwner());
-  });
+  rewriter.replaceOp(opToPad, paddedSubviewResults);
   return success();
 }
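Why the `replaceOpWithIf` bookkeeping disappears above: the old flow created new `tensor.dim` users of `opToPad` at the very moment `opToPad` was being replaced, so the replacement had to skip those fresh users. A hedged sketch of that defensive pattern, with an illustrative `keepUsers` set that is not code from the patch:

```cpp
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

// Replace every use of `op` except uses owned by operations in `keepUsers`,
// i.e. transient DimOps that still need to read the original results.
static void replaceExceptTransientUsers(
    PatternRewriter &rewriter, Operation *op, ValueRange newValues,
    const llvm::SetVector<Operation *> &keepUsers) {
  rewriter.replaceOpWithIf(op, newValues, [&](OpOperand &use) {
    return !keepUsers.contains(use.getOwner());
  });
}
```

With the sizes reified up front, no transient users are created, so the unconditional `rewriter.replaceOp(opToPad, paddedSubviewResults)` is safe.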
@@ -244,14 +244,16 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     return failure();
 
   // Setup RAII guard to return properly.
+  LinalgOp paddedOp;
   LinalgOp tiledOp = res->op;
   auto guard = llvm::make_scope_exit([&]() {
     // Return relevant information to derived pattern.
     result = *res;
-    // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
-    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
-    if (tiledOp != res->op)
-      filter.replaceLinalgTransformationFilter(rewriter, res->op);
+    // Update filter.
+    if (paddedOp)
+      filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
+    else
+      filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
   });
 
   // Consider padding on the fly only if the op has tensor semantics.
@@ -261,7 +263,6 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
 
   // Try to pad on the fly by rewriting res->op as a padded op. If successful,
   // `res.op` is rewritten in static form with padded operands.
-  LinalgOp paddedOp;
   if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
                                   options.paddingValueComputationFunction,
                                   paddedOp))) {
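The two hunks above work together: `LinalgOp paddedOp;` migrates above the `scope_exit` guard because the guard's lambda captures by reference and runs at scope exit, so anything the exit action inspects must be declared before the guard. A standalone toy example of that ordering constraint (not code from the patch):

```cpp
#include "llvm/ADT/ScopeExit.h"
#include <cstdio>

int main() {
  int padded = 0; // must be declared before the guard that reads it
  auto guard = llvm::make_scope_exit([&] {
    // Runs last: observes the final value of `padded`, whichever path set it.
    std::printf("padded = %d\n", padded);
  });
  padded = 42; // later logic (e.g., successful padding) updates it
  return 0;
}
```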
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -cse -split-input-file | \
 // RUN: FileCheck %s -check-prefix=TILE2
 // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
 // RUN: FileCheck %s -check-prefix=TILE1
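Note that only the TILE2 pipeline drops `-resolve-shaped-type-result-dims`; the TILE1 RUN line keeps it, matching the commit message: the reification is local, and one test still depends on the pass for its global effects.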