[linalg] Expose `rewriteAsPaddedOp` function.

Differential Revision: https://reviews.llvm.org/D107629
This commit is contained in:
Alexander Belyaev 2021-08-06 11:28:11 +02:00
parent a5a2f05dcc
commit aa2210a830
2 changed files with 27 additions and 22 deletions

View File

@ -886,6 +886,13 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
PatternRewriter &rewriter) const override;
};
/// Try to create a static bounding box around each operand of `opToPad`.
/// If successful, `paddedOp` will be updated to the cloned static form.
LogicalResult
rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
const PaddingValueComputationFunction &paddingFunc,
LinalgOp &paddedOp);
using OptimizeCopyFn =
std::function<LogicalResult(PatternRewriter &, PadTensorOp, Value)>;

View File

@ -126,7 +126,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
const LinalgTilingOptions &options, Value &result) {
const PaddingValueComputationFunction &paddingFunc, Value &result) {
// Already static shape, no need to pad.
if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
return success();
@ -148,7 +148,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
opToPad, "No constant bounding box can be found for padding");
staticSizes.push_back(indexAttr.getInt());
}
Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
Value pad = paddingFunc(rewriter, *opOperand);
auto staticTensorType = RankedTensorType::get(
staticSizes, getElementTypeOrSelf(opOperand->get()));
result = linalg::PadTensorOp::createPadHighOp(
@ -156,13 +156,10 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
return success();
}
// Try to create a static bounding box around each operand of `res.op`.
// If successful, `res.op` is rewritten in static form with padded operands.
// `res.op` is updated to the cloned static form of the op on success.
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
TiledLinalgOp &res,
const LinalgTilingOptions &options) {
LinalgOp opToPad = res.op;
LogicalResult
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
const PaddingValueComputationFunction &paddingFunc,
LinalgOp &paddedOp) {
Location loc = opToPad->getLoc();
// If the op is fully static, it does not need padding.
@ -183,7 +180,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
// If padding was requested but the shape cannot be bounded statically then
// the pattern fails to apply.
if (failed(padOperandToSmallestStaticBoundingBox(
rewriter, opToPad, opOperand, options, paddedOperand)))
rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
return failure();
newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
}
@ -191,8 +188,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
// Clone `opToPad` to operate on the statically padded shapes.
auto resultTensorTypes =
ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
linalg::LinalgOp paddedOp =
opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
// Recover the slice out of the new static results. This keeps the original
// linalg op around because it uses the dims of the original results.
@ -218,8 +214,6 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
return !newUsersOfOpToPad.contains(opOp.getOwner());
});
res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
return success();
}
@ -265,15 +259,19 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
!linalgOp.hasTensorSemantics())
return success();
// Try to pad on the fly by rewriting res->op as a padded op.
if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
// Set so RAII guard does not propagate TiledLinalgOp to `result`.
return failure();
// Try to pad on the fly by rewriting res->op as a padded op. If successful,
// `res.op` is rewritten in static form with padded operands.
LinalgOp paddedOp;
if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
options.paddingValueComputationFunction,
paddedOp))) {
res->op = paddedOp;
// Do not perform replacement of `linalgOp`, let the derived patterns
// do this as they see fit, from the resulting TiledLinalgOp.
return success();
}
// Do not perform replacement of `linalgOp`, let the derived patterns
// do this as they see fit, from the resulting TiledLinalgOp.
return success();
// Set so RAII guard does not propagate TiledLinalgOp to `result`.
return failure();
}
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {