forked from OSchip/llvm-project
[linalg] Expose `rewriteAsPaddedOp` function.
Differential Revision: https://reviews.llvm.org/D107629
This commit is contained in:
parent
a5a2f05dcc
commit
aa2210a830
|
@ -886,6 +886,13 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
|
|||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Try to create a static bounding box around each operand of `opToPad`.
|
||||
/// If successful, `paddedOp` will be updated to the cloned static form.
|
||||
LogicalResult
|
||||
rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
|
||||
const PaddingValueComputationFunction &paddingFunc,
|
||||
LinalgOp &paddedOp);
|
||||
|
||||
using OptimizeCopyFn =
|
||||
std::function<LogicalResult(PatternRewriter &, PadTensorOp, Value)>;
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
|
|||
/// Return failure if the operand cannot be padded to a static shape.
|
||||
static LogicalResult padOperandToSmallestStaticBoundingBox(
|
||||
PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
|
||||
const LinalgTilingOptions &options, Value &result) {
|
||||
const PaddingValueComputationFunction &paddingFunc, Value &result) {
|
||||
// Already static shape, no need to pad.
|
||||
if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
|
||||
return success();
|
||||
|
@ -148,7 +148,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
|
|||
opToPad, "No constant bounding box can be found for padding");
|
||||
staticSizes.push_back(indexAttr.getInt());
|
||||
}
|
||||
Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
|
||||
Value pad = paddingFunc(rewriter, *opOperand);
|
||||
auto staticTensorType = RankedTensorType::get(
|
||||
staticSizes, getElementTypeOrSelf(opOperand->get()));
|
||||
result = linalg::PadTensorOp::createPadHighOp(
|
||||
|
@ -156,13 +156,10 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Try to create a static bounding box around each operand of `res.op`.
|
||||
// If successful, `res.op` is rewritten in static form with padded operands.
|
||||
// `res.op` is updated to the cloned static form of the op on success.
|
||||
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
|
||||
TiledLinalgOp &res,
|
||||
const LinalgTilingOptions &options) {
|
||||
LinalgOp opToPad = res.op;
|
||||
LogicalResult
|
||||
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
|
||||
const PaddingValueComputationFunction &paddingFunc,
|
||||
LinalgOp &paddedOp) {
|
||||
Location loc = opToPad->getLoc();
|
||||
|
||||
// If the op is fully static, it does not need padding.
|
||||
|
@ -183,7 +180,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
|
|||
// If padding was requested but the shape cannot be bounded statically then
|
||||
// the pattern fails to apply.
|
||||
if (failed(padOperandToSmallestStaticBoundingBox(
|
||||
rewriter, opToPad, opOperand, options, paddedOperand)))
|
||||
rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
|
||||
return failure();
|
||||
newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
|
||||
}
|
||||
|
@ -191,8 +188,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
|
|||
// Clone `opToPad` to operate on the statically padded shapes.
|
||||
auto resultTensorTypes =
|
||||
ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
|
||||
linalg::LinalgOp paddedOp =
|
||||
opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
|
||||
paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
|
||||
|
||||
// Recover the slice out of the new static results. This keeps the original
|
||||
// linalg op around because it uses the dims of the original results.
|
||||
|
@ -218,8 +214,6 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
|
|||
rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
|
||||
return !newUsersOfOpToPad.contains(opOp.getOwner());
|
||||
});
|
||||
|
||||
res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -265,15 +259,19 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
|
|||
!linalgOp.hasTensorSemantics())
|
||||
return success();
|
||||
|
||||
// Try to pad on the fly by rewriting res->op as a padded op.
|
||||
if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
|
||||
// Set so RAII guard does not propagate TiledLinalgOp to `result`.
|
||||
return failure();
|
||||
// Try to pad on the fly by rewriting res->op as a padded op. If successful,
|
||||
// `res.op` is rewritten in static form with padded operands.
|
||||
LinalgOp paddedOp;
|
||||
if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
|
||||
options.paddingValueComputationFunction,
|
||||
paddedOp))) {
|
||||
res->op = paddedOp;
|
||||
// Do not perform replacement of `linalgOp`, let the derived patterns
|
||||
// do this as they see fit, from the resulting TiledLinalgOp.
|
||||
return success();
|
||||
}
|
||||
|
||||
// Do not perform replacement of `linalgOp`, let the derived patterns
|
||||
// do this as they see fit, from the resulting TiledLinalgOp.
|
||||
return success();
|
||||
// Set so RAII guard does not propagate TiledLinalgOp to `result`.
|
||||
return failure();
|
||||
}
|
||||
|
||||
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
|
||||
|
|
Loading…
Reference in New Issue