forked from OSchip/llvm-project
[mlir][vector][NFC] split TransposeOp lowerning out of contractLowering
Move TransposeOp lowering in its own populate function as in some cases it is better to keep it during ContractOp lowering to better canonicalize it rather than emiting scalar insert/extract. Differential Revision: https://reviews.llvm.org/D101647
This commit is contained in:
parent
9779b664b6
commit
be8e2801a4
|
@ -173,7 +173,6 @@ struct VectorTransformsOptions {
|
|||
/// ShapeCastOp2DDownCastRewritePattern,
|
||||
/// ShapeCastOp2DUpCastRewritePattern
|
||||
/// BroadcastOpLowering,
|
||||
/// TransposeOpLowering
|
||||
/// OuterproductOpLowering
|
||||
/// These transformation express higher level vector ops in terms of more
|
||||
/// elementary extraction, insertion, reduction, product, and broadcast ops.
|
||||
|
@ -181,6 +180,11 @@ void populateVectorContractLoweringPatterns(
|
|||
RewritePatternSet &patterns,
|
||||
VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
|
||||
|
||||
/// Insert TransposeLowering patterns into extraction/insertion.
|
||||
void populateVectorTransposeLoweringPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
|
||||
|
||||
/// Returns the integer type required for subscripts in the vector dialect.
|
||||
IntegerType getVectorSubscriptType(Builder &builder);
|
||||
|
||||
|
|
|
@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
|
|||
populateVectorToVectorCanonicalizationPatterns(patterns);
|
||||
populateVectorSlicesLoweringPatterns(patterns);
|
||||
populateVectorContractLoweringPatterns(patterns);
|
||||
populateVectorTransposeLoweringPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
|
|
|
@ -3823,13 +3823,19 @@ void mlir::vector::populateVectorContractLoweringPatterns(
|
|||
ShapeCastOp2DDownCastRewritePattern,
|
||||
ShapeCastOp2DUpCastRewritePattern,
|
||||
ShapeCastOpRewritePattern>(patterns.getContext());
|
||||
patterns.add<TransposeOpLowering,
|
||||
ContractionOpLowering,
|
||||
patterns.add<ContractionOpLowering,
|
||||
ContractionOpToMatmulOpLowering,
|
||||
ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorTransposeLoweringPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
VectorTransformsOptions vectorTransformOptions) {
|
||||
patterns.add<TransposeOpLowering>(vectorTransformOptions,
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorTransferLoweringPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns
|
||||
|
|
|
@ -109,6 +109,8 @@ void TestConvVectorization::runOnOperation() {
|
|||
RewritePatternSet vectorContractLoweringPatterns(context);
|
||||
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
|
||||
vectorTransformsOptions);
|
||||
populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
|
||||
vectorTransformsOptions);
|
||||
(void)applyPatternsAndFoldGreedily(module,
|
||||
std::move(vectorContractLoweringPatterns));
|
||||
|
||||
|
|
|
@ -140,6 +140,7 @@ struct TestVectorContractionConversion
|
|||
transposeLowering = VectorTransposeLowering::Flat;
|
||||
VectorTransformsOptions options{contractLowering, transposeLowering};
|
||||
populateVectorContractLoweringPatterns(patterns, options);
|
||||
populateVectorTransposeLoweringPatterns(patterns, options);
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue