[mlir][Linalg] NFC: Move populatePatterns* methods into linalg namespace.
The moved `populate` methods are only relevant to Linalg operations, so they are better off in the `linalg` namespace. Also rename `populateLinalgTensorOpsFusionPatterns` to `populateElementwiseOpsFusionPatterns`; this makes the scope of these patterns explicit and disambiguates them from fusion on tensors using tile + fuse.

Differential Revision: https://reviews.llvm.org/D99819
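For downstream code the change amounts to a one-line update at each call site. A minimal sketch of a caller, assuming a hypothetical helper named collectFusionPatterns (only the two populate entry points below come from this patch):

// Hypothetical caller, for illustration only; `collectFusionPatterns` is not
// part of this patch.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

static void collectFusionPatterns(mlir::RewritePatternSet &patterns) {
  // Before this patch the populator lived directly in the `mlir` namespace:
  //   mlir::populateLinalgTensorOpsFusionPatterns(patterns);
  // After it, the populate methods live in `mlir::linalg` and the fusion
  // entry point is named after the elementwise-op fusion it performs.
  mlir::linalg::populateElementwiseOpsFusionPatterns(
      patterns, /*allowFoldingUnitDimReshapes=*/false);
}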
parent dc1a08caef
commit ea069aebcc
@@ -50,10 +50,6 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 /// buffers instead.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
 
-/// Populate patterns that convert `ElementwiseMappable` ops to linalg
-/// parallel loops.
-void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
-
 /// Create a pass to conver named Linalg operations to Linalg generic
 /// operations.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
@@ -62,35 +58,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();
 
-/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
-/// producer (consumer) generic operation by expanding the dimensionality of the
-/// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation.
-void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
-
-/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation. The patterns are applied only when
-/// the tensor reshape involved is collapsing (introducing) unit-extent
-/// dimensions.
-void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
-    RewritePatternSet &patterns);
-
-/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
-
-/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
-/// tensors.
-void populateLinalgFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
@@ -36,10 +36,43 @@ void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
+/// producer (consumer) generic operation by expanding the dimensionality of the
+/// loop in the generic op.
+void populateFoldReshapeOpsByExpansionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation.
+void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
+
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation. The patterns are applied only when
+/// the tensor reshape involved is collapsing (introducing) unit-extent
+/// dimensions.
+void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    RewritePatternSet &patterns);
+
 /// Populates the given list with patterns to bufferize linalg ops.
 void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
                                      RewritePatternSet &patterns);
 
+/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
+/// tensors.
+void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
+
+/// Patterns for fusing linalg operation on tensors.
+void populateElementwiseOpsFusionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify
@@ -136,11 +136,6 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
                                           OpResult producerOpResult,
                                           OpOperand &consumerOpOperand);
 
-/// Fuse linalg operation on tensors, with the producer of the operand at
-/// position `consumerIdx` of the consumer.
-Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
-                                              OpOperand &consumerOpOperand);
-
 //===----------------------------------------------------------------------===//
 // Distribution utilities
 //===----------------------------------------------------------------------===//
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
@@ -556,7 +557,7 @@ struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
 
 /// Patterns that are used to canonicalize the use of unit-extent dims for
 /// broadcasting.
-void mlir::populateLinalgFoldUnitExtentDimsPatterns(
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
   patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
@@ -580,7 +581,7 @@ struct LinalgFoldUnitExtentDimsPass
           .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
               context);
     else
-      populateLinalgFoldUnitExtentDimsPatterns(patterns);
+      populateFoldUnitExtentDimsPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
   }
 };
@@ -10,6 +10,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
@@ -115,7 +116,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
 };
 } // namespace
 
-void mlir::populateElementwiseToLinalgConversionPatterns(
+void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
       patterns.getContext());
@@ -131,7 +132,7 @@ class ConvertElementwiseToLinalgPass
     ConversionTarget target(*context);
     RewritePatternSet patterns(context);
 
-    populateElementwiseToLinalgConversionPatterns(patterns);
+    mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
     target.markUnknownOpDynamicallyLegal([](Operation *op) {
       return !isElementwiseMappableOpOnRankedTensors(op);
     });
@@ -26,8 +26,8 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 /// Implementation of fusion of generic ops and indexed_generic ops.
-static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
-                                unsigned consumerIdx) {
+static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
+                                     unsigned consumerIdx) {
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
@@ -91,11 +91,11 @@ static void getIndexingMapOfProducerOperandsInFusedOp(
 
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
-static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
-                                        Operation *fusedOp, LinalgOp producer,
-                                        LinalgOp consumer,
-                                        AffineMap consumerToProducerLoopsMap,
-                                        unsigned consumerIdx, unsigned nloops) {
+static void
+generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
+                                 LinalgOp producer, LinalgOp consumer,
+                                 AffineMap consumerToProducerLoopsMap,
+                                 unsigned consumerIdx, unsigned nloops) {
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
@@ -208,11 +208,11 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
 }
 
 static Optional<SmallVector<Value, 1>>
-fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
-                  PatternRewriter &rewriter) {
+fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+                       PatternRewriter &rewriter) {
   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-  if (!areTensorOpsFusable(producer, consumer, consumerIdx))
+  if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
     return llvm::None;
 
   unsigned numFusedOperands =
@@ -291,9 +291,9 @@ fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
   AffineMap consumerToProducerLoopsMap =
       invProducerResultIndexMap.compose(consumerResultIndexMap);
 
-  generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
-                              consumer, consumerToProducerLoopsMap, consumerIdx,
-                              consumer.getNumLoops());
+  generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
+                                   consumer, consumerToProducerLoopsMap,
+                                   consumerIdx, consumer.getNumLoops());
   return SmallVector<Value, 1>(fusedOp->getResults());
 }
 
@@ -1102,9 +1102,8 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
 };
 } // namespace
 
-Optional<SmallVector<Value, 1>>
-mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
-                            OpOperand &consumerOpOperand) {
+static Optional<SmallVector<Value, 1>>
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
   Operation *producer = consumerOpOperand.get().getDefiningOp();
   if (!producer || producer->getNumResults() != 1)
     return llvm::None;
@@ -1114,14 +1113,14 @@ mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
       !isa<GenericOp, IndexedGenericOp>(producer))
     return llvm::None;
 
-  return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
-                           rewriter);
+  return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
+                                rewriter);
 }
 
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename LinalgOpTy>
-struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
+struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(LinalgOpTy op,
@@ -1133,7 +1132,7 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
       if (!producerOp || !producerOp.hasTensorSemantics())
         continue;
       Optional<SmallVector<Value, 1>> fusedOpResults =
-          fuseTensorOps(rewriter, opOperand);
+          fuseElementwiseOps(rewriter, opOperand);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
         return success();
@@ -1149,8 +1148,7 @@ struct FusionOfTensorOpsPass
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
-    populateLinalgTensorOpsFusionPatterns(patterns,
-                                          allowFoldingUnitDimReshapes);
+    populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1170,7 +1168,7 @@ struct FoldReshapeOpsByLinearizationPass
 
 } // namespace
 
-void mlir::populateFoldReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
                FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
@@ -1178,7 +1176,7 @@ void mlir::populateFoldReshapeOpsByLinearizationPatterns(
       patterns.getContext());
 }
 
-void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
                FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
@@ -1186,7 +1184,7 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
       patterns.getContext());
 }
 
-void mlir::populateFoldReshapeOpsByExpansionPatterns(
+void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
   patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
@@ -1194,11 +1192,11 @@ void mlir::populateFoldReshapeOpsByExpansionPatterns(
       patterns.getContext(), allowFoldingUnitDimReshapes);
 }
 
-void mlir::populateLinalgTensorOpsFusionPatterns(
+void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   auto *context = patterns.getContext();
   patterns
-      .add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
+      .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
           FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
           context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,