diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 0a556b3d99eb..e0ce70464836 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -444,13 +444,71 @@ struct FoldFillWithPad final : public OpRewritePattern { } }; +/// Fold tensor.insert_slice(tensor.pad(), linalg.fill) into +/// tensor.insert_slice(, linalg.fill) if the padding value and the +/// filling value are the same. +struct FoldInsertPadIntoFill : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, + PatternRewriter &rewriter) const override { + auto srcPadOp = insertOp.source().getDefiningOp(); + if (!srcPadOp) + return failure(); + + auto dstFillOp = insertOp.dest().getDefiningOp(); + if (!dstFillOp) + return failure(); + + // We can only fold if the padding value is the same as the original + // filling value. + Value padValue = srcPadOp.getConstantPaddingValue(); + if (!padValue || dstFillOp.value() != padValue) + return failure(); + + SmallVector lowPads = srcPadOp.getMixedLowPad(); + SmallVector oldOffsets = insertOp.getMixedOffsets(); + + Location loc = insertOp.getLoc(); + MLIRContext *context = getContext(); + + AffineExpr sym0, sym1; + bindSymbols(context, sym0, sym1); + auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context); + + // Calculate the new offsets for the insert. It should be the old offsets + // plus low padding sizes. + SmallVector newOffsets; + for (const auto &p : llvm::zip(lowPads, oldOffsets)) { + Value padValue = getValueOrCreateConstantIndexOp( + rewriter, srcPadOp.getLoc(), std::get<0>(p)); + Value offsetValue = getValueOrCreateConstantIndexOp( + rewriter, insertOp.getLoc(), std::get<1>(p)); + newOffsets.push_back( + applyMapToValues(rewriter, loc, addMap, {offsetValue, padValue})[0]); + } + + SmallVector newSizes; + for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) { + newSizes.push_back( + rewriter.create(loc, srcPadOp.source(), i).result()); + } + + rewriter.replaceOpWithNewOp( + insertOp, srcPadOp.source(), insertOp.dest(), newOffsets, newSizes, + insertOp.getMixedStrides()); + return success(); + } +}; + } // namespace void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results .add, - FoldFillWithTensorReshape>(context); + FoldFillWithTensorReshape, + FoldInsertPadIntoFill>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index e3f213f8cd6e..ca5546bd2697 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -613,3 +613,32 @@ func @cast_dest(%arg0: tensor, %arg1: tensor<1x?x?xf32>, %arg2: index // CHECK-SAME: outs(%{{.*}} : tensor<1x?x?xf32>) // CHECK: tensor.cast %[[GENERIC_OP]] : tensor<1x?x?xf32> to tensor } + +// ----- + +// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)> +// CHECK-LABEL: func @insert_pad_into_fill +// CHECK-SAME: (%[[INPUT:.+]]: tensor, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[INIT:.+]] = linalg.init_tensor [8, 384, 384] +// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]]) +// CHECK: %[[OFFSET1:.+]] = affine.apply #[[$MAP]]()[%[[LOW1]]] +// CHECK: %[[D0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor +// CHECK: %[[D1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor +// CHECK: %[[D2:.+]] = tensor.dim %[[INPUT]], %[[C2]] : tensor +// CHECK: tensor.insert_slice %[[INPUT]] into %[[FILL]][%[[LOW0]], %[[OFFSET1]], 2] [%[[D0]], %[[D1]], %[[D2]]] [1, 1, 1] +func @insert_pad_into_fill(%input: tensor, %low0: index, %low1: index, %high1: index, %high2: index) -> tensor<8x384x384xf32> { + %f0 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %pad = tensor.pad %input low[%low0, %low1, %c0] high[%c0, %high1, %high2] { + ^bb0(%arg3: index, %arg4: index, %arg5: index): + tensor.yield %f0 : f32 + } : tensor to tensor<8x128x128xf32> + %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32> + %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32> + %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> + return %0: tensor<8x384x384xf32> +}