[mlir][vector][NFC] split TransposeOp lowerning out of contractLowering

Move TransposeOp lowering in its own populate function as in some cases
it is better to keep it during ContractOp lowering to better
canonicalize it rather than emiting scalar insert/extract.

Differential Revision: https://reviews.llvm.org/D101647
This commit is contained in:
thomasraoux 2021-05-03 10:04:12 -07:00
parent 9779b664b6
commit be8e2801a4
5 changed files with 17 additions and 3 deletions

View File

@ -173,7 +173,6 @@ struct VectorTransformsOptions {
/// ShapeCastOp2DDownCastRewritePattern,
/// ShapeCastOp2DUpCastRewritePattern
/// BroadcastOpLowering,
/// TransposeOpLowering
/// OuterproductOpLowering
/// These transformation express higher level vector ops in terms of more
/// elementary extraction, insertion, reduction, product, and broadcast ops.
@ -181,6 +180,11 @@ void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
/// Insert TransposeLowering patterns into extraction/insertion.
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);

View File

@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorSlicesLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

View File

@ -3823,13 +3823,19 @@ void mlir::vector::populateVectorContractLoweringPatterns(
ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern,
ShapeCastOpRewritePattern>(patterns.getContext());
patterns.add<TransposeOpLowering,
ContractionOpLowering,
patterns.add<ContractionOpLowering,
ContractionOpToMatmulOpLowering,
ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
// clang-format on
}
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions vectorTransformOptions) {
patterns.add<TransposeOpLowering>(vectorTransformOptions,
patterns.getContext());
}
void mlir::vector::populateVectorTransferLoweringPatterns(
RewritePatternSet &patterns) {
patterns

View File

@ -109,6 +109,8 @@ void TestConvVectorization::runOnOperation() {
RewritePatternSet vectorContractLoweringPatterns(context);
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
vectorTransformsOptions);
populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(module,
std::move(vectorContractLoweringPatterns));

View File

@ -140,6 +140,7 @@ struct TestVectorContractionConversion
transposeLowering = VectorTransposeLowering::Flat;
VectorTransformsOptions options{contractLowering, transposeLowering};
populateVectorContractLoweringPatterns(patterns, options);
populateVectorTransposeLoweringPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};