[mlir][Linalg] NFC: Move populatePatterns* method into linalg namespace.

The moved `populate` methods are only relevant to Linalg
operations. So they are better of in `linalg` namespace.  Also rename
`populateLinalgTensorOpsFusionPatterns` to
`populateElementwiseOpsFusionPatterns`. This makes the scope of these
patterns explicit and disambiguates it with fusion on tensors using
tile + fuse.

Differential Revision: https://reviews.llvm.org/D99819
This commit is contained in:
MaheshRavishankar 2021-04-05 10:54:59 -07:00
parent dc1a08caef
commit ea069aebcc
6 changed files with 64 additions and 69 deletions

View File

@ -50,10 +50,6 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
/// buffers instead.
std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
/// Create a pass to conver named Linalg operations to Linalg generic
/// operations.
std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
@ -62,35 +58,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
/// work on primitive types, if possible.
std::unique_ptr<Pass> createLinalgDetensorizePass();
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
/// indexing map used to access the source (target) of the reshape operation in
/// the generic/indexed_generic operation.
void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
/// indexing map used to access the source (target) of the reshape operation in
/// the generic/indexed_generic operation. The patterns are applied only when
/// the tensor reshape involved is collapsing (introducing) unit-extent
/// dimensions.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
/// Patterns for fusing linalg operation on tensors.
void populateLinalgTensorOpsFusionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
void populateLinalgFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

View File

@ -36,10 +36,43 @@ void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes);
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
/// indexing map used to access the source (target) of the reshape operation in
/// the generic/indexed_generic operation.
void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
/// indexing map used to access the source (target) of the reshape operation in
/// the generic/indexed_generic operation. The patterns are applied only when
/// the tensor reshape involved is collapsing (introducing) unit-extent
/// dimensions.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
/// Populates the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
RewritePatternSet &patterns);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
/// Patterns for fusing linalg operation on tensors.
void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify

View File

@ -136,11 +136,6 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
OpResult producerOpResult,
OpOperand &consumerOpOperand);
/// Fuse linalg operation on tensors, with the producer of the operand at
/// position `consumerIdx` of the consumer.
Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
OpOperand &consumerOpOperand);
//===----------------------------------------------------------------------===//
// Distribution utilities
//===----------------------------------------------------------------------===//

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
@ -556,7 +557,7 @@ struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
@ -580,7 +581,7 @@ struct LinalgFoldUnitExtentDimsPass
.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
context);
else
populateLinalgFoldUnitExtentDimsPatterns(patterns);
populateFoldUnitExtentDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
};

View File

@ -10,6 +10,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
@ -115,7 +116,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
};
} // namespace
void mlir::populateElementwiseToLinalgConversionPatterns(
void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
patterns.getContext());
@ -131,7 +132,7 @@ class ConvertElementwiseToLinalgPass
ConversionTarget target(*context);
RewritePatternSet patterns(context);
populateElementwiseToLinalgConversionPatterns(patterns);
mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return !isElementwiseMappableOpOnRankedTensors(op);
});

View File

@ -26,7 +26,7 @@ using namespace mlir;
using namespace mlir::linalg;
/// Implementation of fusion of generic ops and indexed_generic ops.
static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx) {
// Producer and consumer must have tensor semantics.
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
@ -91,9 +91,9 @@ static void getIndexingMapOfProducerOperandsInFusedOp(
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
Operation *fusedOp, LinalgOp producer,
LinalgOp consumer,
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
LinalgOp producer, LinalgOp consumer,
AffineMap consumerToProducerLoopsMap,
unsigned consumerIdx, unsigned nloops) {
// Build the region of the fused op.
@ -208,11 +208,11 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
}
static Optional<SmallVector<Value, 1>>
fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
PatternRewriter &rewriter) {
LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
unsigned consumerIdx = consumerOpOperand.getOperandNumber();
if (!areTensorOpsFusable(producer, consumer, consumerIdx))
if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
return llvm::None;
unsigned numFusedOperands =
@ -291,9 +291,9 @@ fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
AffineMap consumerToProducerLoopsMap =
invProducerResultIndexMap.compose(consumerResultIndexMap);
generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
consumer, consumerToProducerLoopsMap, consumerIdx,
consumer.getNumLoops());
generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
consumer, consumerToProducerLoopsMap,
consumerIdx, consumer.getNumLoops());
return SmallVector<Value, 1>(fusedOp->getResults());
}
@ -1102,9 +1102,8 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
};
} // namespace
Optional<SmallVector<Value, 1>>
mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
OpOperand &consumerOpOperand) {
static Optional<SmallVector<Value, 1>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
Operation *producer = consumerOpOperand.get().getDefiningOp();
if (!producer || producer->getNumResults() != 1)
return llvm::None;
@ -1114,14 +1113,14 @@ mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
!isa<GenericOp, IndexedGenericOp>(producer))
return llvm::None;
return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
rewriter);
}
namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
template <typename LinalgOpTy>
struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(LinalgOpTy op,
@ -1133,7 +1132,7 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
if (!producerOp || !producerOp.hasTensorSemantics())
continue;
Optional<SmallVector<Value, 1>> fusedOpResults =
fuseTensorOps(rewriter, opOperand);
fuseElementwiseOps(rewriter, opOperand);
if (fusedOpResults) {
rewriter.replaceOp(op, *fusedOpResults);
return success();
@ -1149,8 +1148,7 @@ struct FusionOfTensorOpsPass
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateLinalgTensorOpsFusionPatterns(patterns,
allowFoldingUnitDimReshapes);
populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
@ -1170,7 +1168,7 @@ struct FoldReshapeOpsByLinearizationPass
} // namespace
void mlir::populateFoldReshapeOpsByLinearizationPatterns(
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
@ -1178,7 +1176,7 @@ void mlir::populateFoldReshapeOpsByLinearizationPatterns(
patterns.getContext());
}
void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
@ -1186,7 +1184,7 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
patterns.getContext());
}
void mlir::populateFoldReshapeOpsByExpansionPatterns(
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
@ -1194,11 +1192,11 @@ void mlir::populateFoldReshapeOpsByExpansionPatterns(
patterns.getContext(), allowFoldingUnitDimReshapes);
}
void mlir::populateLinalgTensorOpsFusionPatterns(
void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
auto *context = patterns.getContext();
patterns
.add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
.add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
context);
populateFoldReshapeOpsByExpansionPatterns(patterns,