[mlir][Linalg] NFC: Move populatePatterns* method into linalg namespace.

The moved `populate` methods are only relevant to Linalg
operations, so they are better off in the `linalg` namespace. Also rename
`populateLinalgTensorOpsFusionPatterns` to
`populateElementwiseOpsFusionPatterns`. This makes the scope of these
patterns explicit and disambiguates them from fusion on tensors using
tile + fuse.

Differential Revision: https://reviews.llvm.org/D99819
MaheshRavishankar 2021-04-05 10:54:59 -07:00
parent dc1a08caef
commit ea069aebcc
6 changed files with 64 additions and 69 deletions
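For downstream users the update is mechanical: the declarations now live in Transforms/Transforms.h of the Linalg dialect and the calls gain a `linalg::` qualification, plus the rename noted above. Below is a minimal sketch of an updated call site, assuming the MLIR APIs of this revision; the pass `ExampleFusionPass` and its inline option value are hypothetical and only illustrate the new spellings, they are not part of this commit.

// Hypothetical example pass; only the populate* calls reflect this commit.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct ExampleFusionPass
    : public mlir::PassWrapper<ExampleFusionPass,
                               mlir::OperationPass<mlir::FuncOp>> {
  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Old spelling: mlir::populateLinalgTensorOpsFusionPatterns(patterns, ...)
    mlir::linalg::populateElementwiseOpsFusionPatterns(
        patterns, /*allowFoldingUnitDimReshapes=*/false);
    // Old spelling: mlir::populateLinalgFoldUnitExtentDimsPatterns(patterns)
    mlir::linalg::populateFoldUnitExtentDimsPatterns(patterns);
    // Apply the collected patterns to the function body, as the pass in this
    // commit does.
    (void)mlir::applyPatternsAndFoldGreedily(getOperation().getBody(),
                                             std::move(patterns));
  }
};
} // namespace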

View File

@@ -50,10 +50,6 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 /// buffers instead.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
-/// Populate patterns that convert `ElementwiseMappable` ops to linalg
-/// parallel loops.
-void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
 /// Create a pass to conver named Linalg operations to Linalg generic
 /// operations.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
@@ -62,35 +58,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();
-/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
-/// producer (consumer) generic operation by expanding the dimensionality of the
-/// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation.
-void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation. The patterns are applied only when
-/// the tensor reshape involved is collapsing (introducing) unit-extent
-/// dimensions.
-void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
-    RewritePatternSet &patterns);
-/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors.
-void populateLinalgFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

View File

@@ -36,10 +36,43 @@ void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
+/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
+/// producer (consumer) generic operation by expanding the dimensionality of the
+/// loop in the generic op.
+void populateFoldReshapeOpsByExpansionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation.
+void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation. The patterns are applied only when
+/// the tensor reshape involved is collapsing (introducing) unit-extent
+/// dimensions.
+void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    RewritePatternSet &patterns);
 /// Populates the given list with patterns to bufferize linalg ops.
 void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
                                      RewritePatternSet &patterns);
+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
+/// tensors.
+void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
+/// Patterns for fusing linalg operation on tensors.
+void populateElementwiseOpsFusionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify

View File

@@ -136,11 +136,6 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
                                           OpResult producerOpResult,
                                           OpOperand &consumerOpOperand);
-/// Fuse linalg operation on tensors, with the producer of the operand at
-/// position `consumerIdx` of the consumer.
-Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
-                                              OpOperand &consumerOpOperand);
 //===----------------------------------------------------------------------===//
 // Distribution utilities
 //===----------------------------------------------------------------------===//

View File

@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
@@ -556,7 +557,7 @@ struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
 /// Patterns that are used to canonicalize the use of unit-extent dims for
 /// broadcasting.
-void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
   patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
@@ -580,7 +581,7 @@ struct LinalgFoldUnitExtentDimsPass
          .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
              context);
     else
-      populateLinalgFoldUnitExtentDimsPatterns(patterns);
+      populateFoldUnitExtentDimsPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
   }
 };

View File

@@ -10,6 +10,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
@@ -115,7 +116,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
 };
 } // namespace
-void mlir::populateElementwiseToLinalgConversionPatterns(
+void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
       patterns.getContext());
@@ -131,7 +132,7 @@ class ConvertElementwiseToLinalgPass
     ConversionTarget target(*context);
     RewritePatternSet patterns(context);
-    populateElementwiseToLinalgConversionPatterns(patterns);
+    mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
     target.markUnknownOpDynamicallyLegal([](Operation *op) {
       return !isElementwiseMappableOpOnRankedTensors(op);
     });

View File

@@ -26,7 +26,7 @@ using namespace mlir;
 using namespace mlir::linalg;
 /// Implementation of fusion of generic ops and indexed_generic ops.
-static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
-                                unsigned consumerIdx) {
+static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
+                                     unsigned consumerIdx) {
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
@@ -91,9 +91,9 @@ static void getIndexingMapOfProducerOperandsInFusedOp(
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
-static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
-                                        Operation *fusedOp, LinalgOp producer,
-                                        LinalgOp consumer,
-                                        AffineMap consumerToProducerLoopsMap,
-                                        unsigned consumerIdx, unsigned nloops) {
+static void
+generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
+                                 LinalgOp producer, LinalgOp consumer,
+                                 AffineMap consumerToProducerLoopsMap,
+                                 unsigned consumerIdx, unsigned nloops) {
   // Build the region of the fused op.
@@ -208,11 +208,11 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
 }
 static Optional<SmallVector<Value, 1>>
-fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
-                  PatternRewriter &rewriter) {
+fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+                       PatternRewriter &rewriter) {
   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-  if (!areTensorOpsFusable(producer, consumer, consumerIdx))
+  if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
     return llvm::None;
   unsigned numFusedOperands =
@@ -291,9 +291,9 @@ fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
   AffineMap consumerToProducerLoopsMap =
       invProducerResultIndexMap.compose(consumerResultIndexMap);
-  generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
-                              consumer, consumerToProducerLoopsMap, consumerIdx,
-                              consumer.getNumLoops());
+  generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
+                                   consumer, consumerToProducerLoopsMap,
+                                   consumerIdx, consumer.getNumLoops());
   return SmallVector<Value, 1>(fusedOp->getResults());
 }
@@ -1102,9 +1102,8 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
 };
 } // namespace
-Optional<SmallVector<Value, 1>>
-mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
-                            OpOperand &consumerOpOperand) {
+static Optional<SmallVector<Value, 1>>
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
   Operation *producer = consumerOpOperand.get().getDefiningOp();
   if (!producer || producer->getNumResults() != 1)
     return llvm::None;
@@ -1114,14 +1113,14 @@ mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
       !isa<GenericOp, IndexedGenericOp>(producer))
     return llvm::None;
-  return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
-                           rewriter);
+  return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
+                                rewriter);
 }
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename LinalgOpTy>
-struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
+struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(LinalgOpTy op,
@@ -1133,7 +1132,7 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
       if (!producerOp || !producerOp.hasTensorSemantics())
         continue;
       Optional<SmallVector<Value, 1>> fusedOpResults =
-          fuseTensorOps(rewriter, opOperand);
+          fuseElementwiseOps(rewriter, opOperand);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
         return success();
@@ -1149,8 +1148,7 @@ struct FusionOfTensorOpsPass
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
-    populateLinalgTensorOpsFusionPatterns(patterns,
-                                          allowFoldingUnitDimReshapes);
+    populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1170,7 +1168,7 @@ struct FoldReshapeOpsByLinearizationPass
 } // namespace
-void mlir::populateFoldReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
                FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
@@ -1178,7 +1176,7 @@ void mlir::populateFoldReshapeOpsByLinearizationPatterns(
       patterns.getContext());
 }
-void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
                FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
@@ -1186,7 +1184,7 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
       patterns.getContext());
 }
-void mlir::populateFoldReshapeOpsByExpansionPatterns(
+void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
   patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
@@ -1194,11 +1192,11 @@ void mlir::populateFoldReshapeOpsByExpansionPatterns(
       patterns.getContext(), allowFoldingUnitDimReshapes);
 }
-void mlir::populateLinalgTensorOpsFusionPatterns(
+void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   auto *context = patterns.getContext();
   patterns
-      .add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
+      .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
           FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
          context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,