[mlir][linalg] Fold tensor.pad when inserting into linalg.fill

Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
tensor.insert_slice(<input>, linalg.fill) if the padding value and the
filling value are the same.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D120410
commit 5d47332783 (parent 1521162d78)
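For illustration, a minimal before/after sketch of the fold, hand-written with static shapes for readability (the function name, value names, and shapes are made up for this example; the committed regression test is further below):

// Before canonicalization: the pad value and the fill value are the same %f0.
func @insert_pad_into_fill_example(%input: tensor<4x5xf32>) -> tensor<16x16xf32> {
  %f0 = arith.constant 0.0 : f32
  %pad = tensor.pad %input low[2, 0] high[0, 3] {
  ^bb0(%i: index, %j: index):
    tensor.yield %f0 : f32
  } : tensor<4x5xf32> to tensor<6x8xf32>
  %init = linalg.init_tensor [16, 16] : tensor<16x16xf32>
  %fill = linalg.fill(%f0, %init) : f32, tensor<16x16xf32> -> tensor<16x16xf32>
  %0 = tensor.insert_slice %pad into %fill[1, 1] [6, 8] [1, 1]
      : tensor<6x8xf32> into tensor<16x16xf32>
  return %0 : tensor<16x16xf32>
}

// After canonicalization the padded border is already %f0 in the filled
// destination, so the unpadded source is inserted directly: the offsets grow
// by the low padding (1 + 2, 1 + 0) and the sizes shrink to the source shape.
//   %0 = tensor.insert_slice %input into %fill[3, 1] [4, 5] [1, 1]
//       : tensor<4x5xf32> into tensor<16x16xf32>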
@@ -444,13 +444,71 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.source().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    auto dstFillOp = insertOp.dest().getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      Value padValue = getValueOrCreateConstantIndexOp(
          rewriter, srcPadOp.getLoc(), std::get<0>(p));
      Value offsetValue = getValueOrCreateConstantIndexOp(
          rewriter, insertOp.getLoc(), std::get<1>(p));
      newOffsets.push_back(
          applyMapToValues(rewriter, loc, addMap, {offsetValue, padValue})[0]);
    }

    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
      newSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, srcPadOp.source(), i).result());
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.source(), insertOp.dest(), newOffsets, newSizes,
        insertOp.getMixedStrides());
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill>(context);
}

//===----------------------------------------------------------------------===//
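The new offsets computed in the loop above materialize as an affine.apply over a two-symbol add map, which later folds whenever the old offset or the low pad is static; schematically (value names are illustrative):

#add = affine_map<()[s0, s1] -> (s0 + s1)>
// s0 is the old insert_slice offset, s1 the low padding along that dimension;
// with a constant low pad this composes to the form seen in the test below,
// e.g. affine_map<()[s0] -> (s0 + 1)>.
%new_offset = affine.apply #add()[%old_offset, %low_pad]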
@@ -613,3 +613,32 @@ func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: index
// CHECK-SAME: outs(%{{.*}} : tensor<1x?x?xf32>)
// CHECK: tensor.cast %[[GENERIC_OP]] : tensor<1x?x?xf32> to tensor<?x?x?xf32>
}

// -----

// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
// CHECK-LABEL: func @insert_pad_into_fill
// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[INIT:.+]] = linalg.init_tensor [8, 384, 384]
// CHECK: %[[FILL:.+]] = linalg.fill(%[[F0]], %[[INIT]])
// CHECK: %[[OFFSET1:.+]] = affine.apply #[[$MAP]]()[%[[LOW1]]]
// CHECK: %[[D0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[D1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[D2:.+]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: tensor.insert_slice %[[INPUT]] into %[[FILL]][%[[LOW0]], %[[OFFSET1]], 2] [%[[D0]], %[[D1]], %[[D2]]] [1, 1, 1]
func @insert_pad_into_fill(%input: tensor<?x?x?xf32>, %low0: index, %low1: index, %high1: index, %high2: index) -> tensor<8x384x384xf32> {
  %f0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %pad = tensor.pad %input low[%low0, %low1, %c0] high[%c0, %high1, %high2] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %f0 : f32
  } : tensor<?x?x?xf32> to tensor<8x128x128xf32>
  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
  %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  return %0: tensor<8x384x384xf32>
}
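Assuming the standard setup of this canonicalization test file, the case above is driven by a RUN line along these lines (shown for orientation only, not copied from the file); the // ----- separator is what -split-input-file keys on to isolate each case:

// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s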