forked from OSchip/llvm-project
[mlir][Vector] Fold InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
This patch folds InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X). Differential Revision: https://reviews.llvm.org/D129058
This commit is contained in:
parent
0880b9d526
commit
cf74b7ec80
|
@ -2031,11 +2031,32 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
|
||||
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
|
||||
public:
|
||||
using OpRewritePattern<InsertOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(InsertOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
|
||||
auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
|
||||
|
||||
if (!srcSplat || !dstSplat)
|
||||
return failure();
|
||||
|
||||
if (srcSplat.getInput() != dstSplat.getInput())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<InsertToBroadcast, BroadcastFolder>(context);
|
||||
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
|
||||
}
|
||||
|
||||
// Eliminates insert operations that produce values identical to their source
|
||||
|
|
|
@ -1669,3 +1669,16 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
|
|||
return %shuffle : vector<4xi32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @insert_splat
|
||||
// CHECK-SAME: (%[[ARG:.*]]: i32)
|
||||
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
|
||||
// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
|
||||
func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
|
||||
%v0 = vector.splat %x : vector<4x3xi32>
|
||||
%v1 = vector.splat %x : vector<2x4x3xi32>
|
||||
%insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
|
||||
return %insert : vector<2x4x3xi32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue