forked from OSchip/llvm-project
[mlir][Vector] Fold InsertStridedSliceOp of two splat with the same input to splat.
This patch folds InsertStridedSliceOp(SplatOp(X):src_type, SplatOp(X):dst_type) to SplatOp(X):dst_type. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D128891
This commit is contained in:
parent
2ceb9c347f
commit
91ab4d4231
|
@ -886,6 +886,7 @@ def Vector_InsertStridedSliceOp :
|
|||
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Vector_OuterProductOp :
|
||||
|
|
|
@ -2180,6 +2180,38 @@ LogicalResult InsertStridedSliceOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
|
||||
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
|
||||
class FoldInsertStridedSliceSplat final
|
||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcSplatOp =
|
||||
insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
|
||||
auto destSplatOp =
|
||||
insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
|
||||
|
||||
if (!srcSplatOp || !destSplatOp)
|
||||
return failure();
|
||||
|
||||
if (srcSplatOp.getInput() != destSplatOp.getInput())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results, MLIRContext *context) {
|
||||
results.add<FoldInsertStridedSliceSplat>(context);
|
||||
}
|
||||
|
||||
OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (getSourceVectorType() == getDestVectorType())
|
||||
return getSource();
|
||||
|
|
|
@ -1627,3 +1627,17 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
|
|||
%1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16>
|
||||
return %1 : vector<4x16xi16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_strided_slice_splat
|
||||
// CHECK-SAME: (%[[ARG:.*]]: f32)
|
||||
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
|
||||
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
|
||||
func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
|
||||
%splat0 = vector.splat %x : vector<4x4xf32>
|
||||
%splat1 = vector.splat %x : vector<8x16xf32>
|
||||
%0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
|
||||
: vector<4x4xf32> into vector<8x16xf32>
|
||||
return %0 : vector<8x16xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue