[mlir][Vector] Fold InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

This patch folds InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

Differential Revision: https://reviews.llvm.org/D129058
This commit is contained in:
jacquesguan 2022-07-01 16:02:28 +08:00
parent 0880b9d526
commit cf74b7ec80
2 changed files with 35 additions and 1 deletions

View File

@ -2031,11 +2031,32 @@ public:
}
};
/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
using OpRewritePattern<InsertOp>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
if (!srcSplat || !dstSplat)
return failure();
if (srcSplat.getInput() != dstSplat.getInput())
return failure();
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
return success();
}
};
} // namespace
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertToBroadcast, BroadcastFolder>(context);
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
}
// Eliminates insert operations that produce values identical to their source

View File

@ -1669,3 +1669,16 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
return %shuffle : vector<4xi32>
}
// -----
// CHECK-LABEL: func @insert_splat
// CHECK-SAME: (%[[ARG:.*]]: i32)
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
%v0 = vector.splat %x : vector<4x3xi32>
%v1 = vector.splat %x : vector<2x4x3xi32>
%insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
return %insert : vector<2x4x3xi32>
}