[mlir][linalg] Fold tensor.pad(linalg.fill) with the same value

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D119160
This commit is contained in:
Lei Zhang 2022-02-10 08:18:00 -05:00
parent 9b5a3d14b2
commit 06a0385142
2 changed files with 107 additions and 2 deletions

View File

@ -441,12 +441,52 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
}
};
/// Canonicalization pattern: rewrite a tensor.pad whose source is produced by
/// a linalg.fill into a single linalg.fill over a tensor of the padded shape.
/// The rewrite only fires when the pad has a constant padding value and that
/// value is the one the producing fill wrote.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // The padded tensor must come straight out of a linalg.fill.
    auto producerFill = padOp.source().getDefiningOp<linalg::FillOp>();
    if (!producerFill)
      return failure();

    // Bail out unless the pad uses a constant padding value ...
    Value paddingValue = padOp.getConstantPaddingValue();
    if (!paddingValue)
      return failure();
    // ... and that value is the same Value the fill was given.
    if (producerFill.value() != paddingValue)
      return failure();

    // Ask the pad op itself for the dimensions of its result.
    ReifiedRankedShapedTypeDims paddedDims;
    if (failed(cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
                   .reifyResultShapes(rewriter, paddedDims))) {
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");
    }

    // Build the replacement init tensor with every dimension marked dynamic
    // and sized by the reified values; the tensor.cast created below restores
    // the pad's original result type.
    auto padResultType = padOp.getResultType();
    SmallVector<int64_t, 4> allDynamic(padResultType.getRank(),
                                       ShapedType::kDynamicSize);
    auto initTensor = rewriter.create<InitTensorOp>(
        padOp.getLoc(), paddedDims.front(), allDynamic,
        padResultType.getElementType());

    // Fill the larger tensor with the shared value and swap it in for the pad.
    auto paddedFill = rewriter.create<FillOp>(producerFill.getLoc(),
                                              paddingValue, initTensor);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, padResultType,
                                                paddedFill.result());
    return success();
  }
};
} // namespace
/// Populates `results` with the canonicalization patterns attached to
/// linalg.fill: folding a tensor.pad of a fill (FoldFillWithPad) and folding
/// tensor.collapse_shape / tensor.expand_shape of a fill
/// (FoldFillWithTensorReshape). Each pattern is registered exactly once; a
/// second `add` call repeating the reshape patterns would only insert
/// duplicate entries into the set.
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -585,3 +585,68 @@ func @fold_self_copy(%0 : memref<4x16xf32>) {
}
return
}
// -----
// Static-shape case: the pad region yields %f0, the same value the fill
// writes, so the pad is expected to disappear and leave a single fill over an
// init tensor of the padded shape (400 + 4 + 8 = 412, 273 + 1 + 2 = 276).
// CHECK-LABEL: func @fold_static_pad_fill
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[INIT:.+]] = linalg.init_tensor [412, 276] : tensor<412x276xf32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
// CHECK: return %[[FILL]]
func @fold_static_pad_fill() -> tensor<412x276xf32> {
%f0 = arith.constant 0.0 : f32
%init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
%fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
%pad = tensor.pad %fill low[4, 1] high[8, 2] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %f0 : f32
} : tensor<400x273xf32> to tensor<412x276xf32>
return %pad : tensor<412x276xf32>
}
// -----
// Dynamic-shape case: source sizes and padding amounts mix static and dynamic
// values, so the padded sizes are expected to be computed with affine.apply.
// Each expected map is "source dim + low + high" with the static parts folded
// into the constant:
//   dim 0: 8 + %low0 + 1           -> s0 + 9
//   dim 1: dim(%fill, 1) + 8 + 2   -> s0 + 10
//   dim 2: 16 + 7 + %high2         -> s0 + 23
//   dim 3: 32 + %low3 + %high3     -> s0 + s1 + 32
// The original fill of the function argument is still expected in the output
// because the dynamic dimension is queried from its result.
// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 9)>
// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 10)>
// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 + 23)>
// CHECK: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 32)>
// CHECK: func @fold_dynamic_pad_fill
// CHECK-SAME: %[[SRC:.+]]: tensor<8x?x16x32xf32>, %[[LOW0:.+]]: index, %[[LOW3:.+]]: index, %[[HIGH2:.+]]: index, %[[HIGH3:.+]]: index
// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[OF:.+]] = linalg.fill(%[[F0]], %[[SRC]]) : f32, tensor<8x?x16x32xf32>
// CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
// CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
// CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
// CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
// CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
// CHECK: return %[[FILL]]
func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor<?x?x?x?xf32> {
%f0 = arith.constant 0.0 : f32
%fill = linalg.fill(%f0, %init) : f32, tensor<8x?x16x32xf32> -> tensor<8x?x16x32xf32>
%pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
tensor.yield %f0 : f32
} : tensor<8x?x16x32xf32> to tensor<?x?x?x?xf32>
return %pad : tensor<?x?x?x?xf32>
}
// -----
// Negative case: the fill writes %f0 (0.0) but the pad region yields %f1
// (1.0). The values differ, so the tensor.pad is expected to remain in the
// output.
// CHECK-LABEL: func @no_fold_pad_fill_value_mismatch
func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
%f0 = arith.constant 0.0 : f32
%f1 = arith.constant 1.0 : f32
%init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
%fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
// CHECK: tensor.pad
%pad = tensor.pad %fill low[4, 1] high[8, 2] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %f1 : f32
} : tensor<400x273xf32> to tensor<412x276xf32>
return %pad : tensor<412x276xf32>
}