forked from OSchip/llvm-project
[mlir][Vector] Fold InsertStridedSliceOp of ExtractStridedSliceOp.
This patch supports to fold InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst) to dst. Differential Revision: https://reviews.llvm.org/D128903
This commit is contained in:
parent
91ab4d4231
commit
8f45c5862f
|
@ -2205,11 +2205,43 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
|
||||
/// to dst.
|
||||
class FoldInsertStridedSliceOfExtract final
|
||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto extractStridedSliceOp =
|
||||
insertStridedSliceOp.getSource()
|
||||
.getDefiningOp<vector::ExtractStridedSliceOp>();
|
||||
|
||||
if (!extractStridedSliceOp)
|
||||
return failure();
|
||||
|
||||
if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
|
||||
return failure();
|
||||
|
||||
// Check if have the same strides and offsets.
|
||||
if (extractStridedSliceOp.getStrides() !=
|
||||
insertStridedSliceOp.getStrides() ||
|
||||
extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results, MLIRContext *context) {
|
||||
results.add<FoldInsertStridedSliceSplat>(context);
|
||||
results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract>(
|
||||
context);
|
||||
}
|
||||
|
||||
OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
|
|
@ -1641,3 +1641,17 @@ func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
|
|||
: vector<4x4xf32> into vector<8x16xf32>
|
||||
return %0 : vector<8x16xf32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_extract_strided_slice
|
||||
// CHECK-SAME: (%[[ARG:.*]]: vector<8x16xf32>)
|
||||
// CHECK-NEXT: return %[[ARG]] : vector<8x16xf32>
|
||||
func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf32>) {
|
||||
%0 = vector.extract_strided_slice %x {offsets = [0, 8], sizes = [2, 4], strides = [1, 1]}
|
||||
: vector<8x16xf32> to vector<2x4xf32>
|
||||
%1 = vector.insert_strided_slice %0, %x {offsets = [0, 8], strides = [1, 1]}
|
||||
: vector<2x4xf32> into vector<8x16xf32>
|
||||
return %1 : vector<8x16xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue