forked from OSchip/llvm-project
[mlir][Vector] Fold ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
This patch folds ShuffleOp(SplatOp(X), SplatOp(X)) to SplatOp(X). Differential Revision: https://reviews.llvm.org/D128969
This commit is contained in:
parent
2c3784cff8
commit
e98e13ac8f
|
@ -487,6 +487,7 @@ def Vector_ShuffleOp :
|
|||
}];
|
||||
let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Vector_ExtractElementOp :
|
||||
|
|
|
@ -1882,6 +1882,36 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
|
|||
return DenseElementsAttr::get(getVectorType(), results);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
|
||||
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
|
||||
public:
|
||||
using OpRewritePattern<ShuffleOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ShuffleOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
|
||||
auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
|
||||
|
||||
if (!v1Splat || !v2Splat)
|
||||
return failure();
|
||||
|
||||
if (v1Splat.getInput() != v2Splat.getInput())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<ShuffleSplat>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1655,3 +1655,17 @@ func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf3
|
|||
: vector<2x4xf32> into vector<8x16xf32>
|
||||
return %1 : vector<8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @shuffle_splat
|
||||
// CHECK-SAME: (%[[ARG:.*]]: i32)
|
||||
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
|
||||
// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
|
||||
func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
|
||||
%v0 = vector.splat %x : vector<4xi32>
|
||||
%v1 = vector.splat %x : vector<2xi32>
|
||||
%shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
|
||||
return %shuffle : vector<4xi32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue