[mlir][Vector] Fold ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

This patch folds ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

Differential Revision: https://reviews.llvm.org/D128969
This commit is contained in:
jacquesguan 2022-07-01 15:08:58 +08:00
parent 2c3784cff8
commit e98e13ac8f
3 changed files with 45 additions and 0 deletions

View File

@ -487,6 +487,7 @@ def Vector_ShuffleOp :
}];
let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def Vector_ExtractElementOp :

View File

@ -1882,6 +1882,36 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
return DenseElementsAttr::get(getVectorType(), results);
}
namespace {
/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
using OpRewritePattern<ShuffleOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
if (!v1Splat || !v2Splat)
return failure();
if (v1Splat.getInput() != v2Splat.getInput())
return failure();
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
return success();
}
};
} // namespace
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ShuffleSplat>(context);
}
//===----------------------------------------------------------------------===//
// InsertElementOp
//===----------------------------------------------------------------------===//

View File

@ -1655,3 +1655,17 @@ func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf3
: vector<2x4xf32> into vector<8x16xf32>
return %1 : vector<8x16xf32>
}
// -----
// CHECK-LABEL: func @shuffle_splat
// CHECK-SAME: (%[[ARG:.*]]: i32)
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
%v0 = vector.splat %x : vector<4xi32>
%v1 = vector.splat %x : vector<2xi32>
%shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
return %shuffle : vector<4xi32>
}