forked from OSchip/llvm-project
[mlir][linalg] Fold tensor.pad(linalg.fill) with the same value
Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D119160
This commit is contained in:
parent
9b5a3d14b2
commit
06a0385142
|
@ -441,12 +441,52 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
|
||||
/// filling value are the same.
|
||||
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::PadOp padOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto fillOp = padOp.source().getDefiningOp<linalg::FillOp>();
|
||||
if (!fillOp)
|
||||
return failure();
|
||||
|
||||
// We can only fold if the padding value is the same as the original
|
||||
// filling value.
|
||||
Value padValue = padOp.getConstantPaddingValue();
|
||||
if (!padValue || fillOp.value() != padValue)
|
||||
return failure();
|
||||
|
||||
ReifiedRankedShapedTypeDims reifiedShape;
|
||||
ReifyRankedShapedTypeOpInterface interface =
|
||||
cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
|
||||
if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
padOp, "failed to reify tensor.pad op result shape");
|
||||
|
||||
auto oldResultType = padOp.getResultType();
|
||||
SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
|
||||
ShapedType::kDynamicSize);
|
||||
auto newInitOp = rewriter.create<InitTensorOp>(
|
||||
padOp.getLoc(), reifiedShape.front(), staticShape,
|
||||
oldResultType.getElementType());
|
||||
auto newFillOp =
|
||||
rewriter.create<FillOp>(fillOp.getLoc(), padValue, newInitOp);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
|
||||
newFillOp.result());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
|
||||
FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
|
||||
results
|
||||
.add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
|
||||
FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -585,3 +585,68 @@ func @fold_self_copy(%0 : memref<4x16xf32>) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_static_pad_fill
|
||||
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [412, 276] : tensor<412x276xf32>
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
|
||||
// CHECK: return %[[FILL]]
|
||||
func @fold_static_pad_fill() -> tensor<412x276xf32> {
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
|
||||
%fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
|
||||
%pad = tensor.pad %fill low[4, 1] high[8, 2] {
|
||||
^bb0(%arg1: index, %arg2: index):
|
||||
tensor.yield %f0 : f32
|
||||
} : tensor<400x273xf32> to tensor<412x276xf32>
|
||||
return %pad : tensor<412x276xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 9)>
|
||||
// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 10)>
|
||||
// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 + 23)>
|
||||
// CHECK: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 32)>
|
||||
|
||||
// CHECK: func @fold_dynamic_pad_fill
|
||||
// CHECK-SAME: %[[SRC:.+]]: tensor<8x?x16x32xf32>, %[[LOW0:.+]]: index, %[[LOW3:.+]]: index, %[[HIGH2:.+]]: index, %[[HIGH3:.+]]: index
|
||||
|
||||
// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[OF:.+]] = linalg.fill(%[[F0]], %[[SRC]]) : f32, tensor<8x?x16x32xf32>
|
||||
// CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
|
||||
// CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
|
||||
// CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
|
||||
// CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
|
||||
// CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
|
||||
// CHECK: return %[[FILL]]
|
||||
func @fold_dynamic_pad_fill(%init: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor<?x?x?x?xf32> {
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%fill = linalg.fill(%f0, %init) : f32, tensor<8x?x16x32xf32> -> tensor<8x?x16x32xf32>
|
||||
%pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] {
|
||||
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
|
||||
tensor.yield %f0 : f32
|
||||
} : tensor<8x?x16x32xf32> to tensor<?x?x?x?xf32>
|
||||
return %pad : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @no_fold_pad_fill_value_mismatch
|
||||
func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%f1 = arith.constant 1.0 : f32
|
||||
%init = linalg.init_tensor [400, 273] : tensor<400x273xf32>
|
||||
%fill = linalg.fill(%f0, %init) : f32, tensor<400x273xf32> -> tensor<400x273xf32>
|
||||
// CHECK: tensor.pad
|
||||
%pad = tensor.pad %fill low[4, 1] high[8, 2] {
|
||||
^bb0(%arg1: index, %arg2: index):
|
||||
tensor.yield %f1 : f32
|
||||
} : tensor<400x273xf32> to tensor<412x276xf32>
|
||||
return %pad : tensor<412x276xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue