forked from OSchip/llvm-project
[mlir][Linalg] NFC: Move populatePatterns* method into linalg namespace.
The moved `populate` methods are only relevant to Linalg operations. So they are better of in `linalg` namespace. Also rename `populateLinalgTensorOpsFusionPatterns` to `populateElementwiseOpsFusionPatterns`. This makes the scope of these patterns explicit and disambiguates it with fusion on tensors using tile + fuse. Differential Revision: https://reviews.llvm.org/D99819
This commit is contained in:
parent
dc1a08caef
commit
ea069aebcc
|
@ -50,10 +50,6 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
|
||||||
/// buffers instead.
|
/// buffers instead.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
|
std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
|
||||||
|
|
||||||
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
|
|
||||||
/// parallel loops.
|
|
||||||
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
|
|
||||||
|
|
||||||
/// Create a pass to conver named Linalg operations to Linalg generic
|
/// Create a pass to conver named Linalg operations to Linalg generic
|
||||||
/// operations.
|
/// operations.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
|
||||||
|
@ -62,35 +58,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
|
||||||
/// work on primitive types, if possible.
|
/// work on primitive types, if possible.
|
||||||
std::unique_ptr<Pass> createLinalgDetensorizePass();
|
std::unique_ptr<Pass> createLinalgDetensorizePass();
|
||||||
|
|
||||||
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
|
|
||||||
/// producer (consumer) generic operation by expanding the dimensionality of the
|
|
||||||
/// loop in the generic op.
|
|
||||||
void populateFoldReshapeOpsByExpansionPatterns(
|
|
||||||
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
|
|
||||||
|
|
||||||
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
|
|
||||||
/// producer (consumer) generic/indexed_generic operation by linearizing the
|
|
||||||
/// indexing map used to access the source (target) of the reshape operation in
|
|
||||||
/// the generic/indexed_generic operation.
|
|
||||||
void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
|
|
||||||
|
|
||||||
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
|
|
||||||
/// producer (consumer) generic/indexed_generic operation by linearizing the
|
|
||||||
/// indexing map used to access the source (target) of the reshape operation in
|
|
||||||
/// the generic/indexed_generic operation. The patterns are applied only when
|
|
||||||
/// the tensor reshape involved is collapsing (introducing) unit-extent
|
|
||||||
/// dimensions.
|
|
||||||
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
|
|
||||||
RewritePatternSet &patterns);
|
|
||||||
|
|
||||||
/// Patterns for fusing linalg operation on tensors.
|
|
||||||
void populateLinalgTensorOpsFusionPatterns(
|
|
||||||
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
|
|
||||||
|
|
||||||
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
|
|
||||||
/// tensors.
|
|
||||||
void populateLinalgFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Registration
|
// Registration
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -36,10 +36,43 @@ void populateConvVectorizationPatterns(
|
||||||
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
|
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
|
||||||
ArrayRef<int64_t> tileSizes);
|
ArrayRef<int64_t> tileSizes);
|
||||||
|
|
||||||
|
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
|
||||||
|
/// parallel loops.
|
||||||
|
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
|
||||||
|
/// producer (consumer) generic operation by expanding the dimensionality of the
|
||||||
|
/// loop in the generic op.
|
||||||
|
void populateFoldReshapeOpsByExpansionPatterns(
|
||||||
|
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
|
||||||
|
|
||||||
|
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
|
||||||
|
/// producer (consumer) generic/indexed_generic operation by linearizing the
|
||||||
|
/// indexing map used to access the source (target) of the reshape operation in
|
||||||
|
/// the generic/indexed_generic operation.
|
||||||
|
void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
|
||||||
|
/// producer (consumer) generic/indexed_generic operation by linearizing the
|
||||||
|
/// indexing map used to access the source (target) of the reshape operation in
|
||||||
|
/// the generic/indexed_generic operation. The patterns are applied only when
|
||||||
|
/// the tensor reshape involved is collapsing (introducing) unit-extent
|
||||||
|
/// dimensions.
|
||||||
|
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
|
||||||
|
RewritePatternSet &patterns);
|
||||||
|
|
||||||
/// Populates the given list with patterns to bufferize linalg ops.
|
/// Populates the given list with patterns to bufferize linalg ops.
|
||||||
void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
|
void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
|
||||||
RewritePatternSet &patterns);
|
RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
|
||||||
|
/// tensors.
|
||||||
|
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
/// Patterns for fusing linalg operation on tensors.
|
||||||
|
void populateElementwiseOpsFusionPatterns(
|
||||||
|
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
|
||||||
|
|
||||||
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
|
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
|
||||||
/// and permute the loop nest according to `interchangeVector`
|
/// and permute the loop nest according to `interchangeVector`
|
||||||
/// The permutation is expressed as a list of integers that specify
|
/// The permutation is expressed as a list of integers that specify
|
||||||
|
|
|
@ -136,11 +136,6 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
|
||||||
OpResult producerOpResult,
|
OpResult producerOpResult,
|
||||||
OpOperand &consumerOpOperand);
|
OpOperand &consumerOpOperand);
|
||||||
|
|
||||||
/// Fuse linalg operation on tensors, with the producer of the operand at
|
|
||||||
/// position `consumerIdx` of the consumer.
|
|
||||||
Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
|
|
||||||
OpOperand &consumerOpOperand);
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Distribution utilities
|
// Distribution utilities
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
#include "mlir/Dialect/Linalg/Passes.h"
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
|
@ -556,7 +557,7 @@ struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
|
||||||
|
|
||||||
/// Patterns that are used to canonicalize the use of unit-extent dims for
|
/// Patterns that are used to canonicalize the use of unit-extent dims for
|
||||||
/// broadcasting.
|
/// broadcasting.
|
||||||
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
|
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
auto *context = patterns.getContext();
|
auto *context = patterns.getContext();
|
||||||
patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
|
patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
|
||||||
|
@ -580,7 +581,7 @@ struct LinalgFoldUnitExtentDimsPass
|
||||||
.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
|
.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
|
||||||
context);
|
context);
|
||||||
else
|
else
|
||||||
populateLinalgFoldUnitExtentDimsPatterns(patterns);
|
populateFoldUnitExtentDimsPatterns(patterns);
|
||||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
|
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
|
||||||
|
@ -115,7 +116,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void mlir::populateElementwiseToLinalgConversionPatterns(
|
void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
|
patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
|
||||||
patterns.getContext());
|
patterns.getContext());
|
||||||
|
@ -131,7 +132,7 @@ class ConvertElementwiseToLinalgPass
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
|
|
||||||
populateElementwiseToLinalgConversionPatterns(patterns);
|
mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
|
||||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||||
return !isElementwiseMappableOpOnRankedTensors(op);
|
return !isElementwiseMappableOpOnRankedTensors(op);
|
||||||
});
|
});
|
||||||
|
|
|
@ -26,7 +26,7 @@ using namespace mlir;
|
||||||
using namespace mlir::linalg;
|
using namespace mlir::linalg;
|
||||||
|
|
||||||
/// Implementation of fusion of generic ops and indexed_generic ops.
|
/// Implementation of fusion of generic ops and indexed_generic ops.
|
||||||
static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
|
static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
|
||||||
unsigned consumerIdx) {
|
unsigned consumerIdx) {
|
||||||
// Producer and consumer must have tensor semantics.
|
// Producer and consumer must have tensor semantics.
|
||||||
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
|
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
|
||||||
|
@ -91,9 +91,9 @@ static void getIndexingMapOfProducerOperandsInFusedOp(
|
||||||
|
|
||||||
/// Generate the region of the fused tensor operation. The region of the fused
|
/// Generate the region of the fused tensor operation. The region of the fused
|
||||||
/// op must be empty.
|
/// op must be empty.
|
||||||
static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
|
static void
|
||||||
Operation *fusedOp, LinalgOp producer,
|
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
|
||||||
LinalgOp consumer,
|
LinalgOp producer, LinalgOp consumer,
|
||||||
AffineMap consumerToProducerLoopsMap,
|
AffineMap consumerToProducerLoopsMap,
|
||||||
unsigned consumerIdx, unsigned nloops) {
|
unsigned consumerIdx, unsigned nloops) {
|
||||||
// Build the region of the fused op.
|
// Build the region of the fused op.
|
||||||
|
@ -208,11 +208,11 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
static Optional<SmallVector<Value, 1>>
|
static Optional<SmallVector<Value, 1>>
|
||||||
fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
|
fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
|
LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
|
||||||
unsigned consumerIdx = consumerOpOperand.getOperandNumber();
|
unsigned consumerIdx = consumerOpOperand.getOperandNumber();
|
||||||
if (!areTensorOpsFusable(producer, consumer, consumerIdx))
|
if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
|
|
||||||
unsigned numFusedOperands =
|
unsigned numFusedOperands =
|
||||||
|
@ -291,9 +291,9 @@ fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
|
||||||
AffineMap consumerToProducerLoopsMap =
|
AffineMap consumerToProducerLoopsMap =
|
||||||
invProducerResultIndexMap.compose(consumerResultIndexMap);
|
invProducerResultIndexMap.compose(consumerResultIndexMap);
|
||||||
|
|
||||||
generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
|
generateFusedElementwiseOpRegion(rewriter, fusedOp.getOperation(), producer,
|
||||||
consumer, consumerToProducerLoopsMap, consumerIdx,
|
consumer, consumerToProducerLoopsMap,
|
||||||
consumer.getNumLoops());
|
consumerIdx, consumer.getNumLoops());
|
||||||
return SmallVector<Value, 1>(fusedOp->getResults());
|
return SmallVector<Value, 1>(fusedOp->getResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1102,9 +1102,8 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Optional<SmallVector<Value, 1>>
|
static Optional<SmallVector<Value, 1>>
|
||||||
mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
|
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
|
||||||
OpOperand &consumerOpOperand) {
|
|
||||||
Operation *producer = consumerOpOperand.get().getDefiningOp();
|
Operation *producer = consumerOpOperand.get().getDefiningOp();
|
||||||
if (!producer || producer->getNumResults() != 1)
|
if (!producer || producer->getNumResults() != 1)
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
|
@ -1114,14 +1113,14 @@ mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
|
||||||
!isa<GenericOp, IndexedGenericOp>(producer))
|
!isa<GenericOp, IndexedGenericOp>(producer))
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
|
|
||||||
return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
|
return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
|
||||||
rewriter);
|
rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/// Patterns to fuse a generic op, with the producer of its operands.
|
/// Patterns to fuse a generic op, with the producer of its operands.
|
||||||
template <typename LinalgOpTy>
|
template <typename LinalgOpTy>
|
||||||
struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
|
struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
|
||||||
using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
|
using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(LinalgOpTy op,
|
LogicalResult matchAndRewrite(LinalgOpTy op,
|
||||||
|
@ -1133,7 +1132,7 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
|
||||||
if (!producerOp || !producerOp.hasTensorSemantics())
|
if (!producerOp || !producerOp.hasTensorSemantics())
|
||||||
continue;
|
continue;
|
||||||
Optional<SmallVector<Value, 1>> fusedOpResults =
|
Optional<SmallVector<Value, 1>> fusedOpResults =
|
||||||
fuseTensorOps(rewriter, opOperand);
|
fuseElementwiseOps(rewriter, opOperand);
|
||||||
if (fusedOpResults) {
|
if (fusedOpResults) {
|
||||||
rewriter.replaceOp(op, *fusedOpResults);
|
rewriter.replaceOp(op, *fusedOpResults);
|
||||||
return success();
|
return success();
|
||||||
|
@ -1149,8 +1148,7 @@ struct FusionOfTensorOpsPass
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
Operation *op = getOperation();
|
Operation *op = getOperation();
|
||||||
RewritePatternSet patterns(op->getContext());
|
RewritePatternSet patterns(op->getContext());
|
||||||
populateLinalgTensorOpsFusionPatterns(patterns,
|
populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
|
||||||
allowFoldingUnitDimReshapes);
|
|
||||||
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1170,7 +1168,7 @@ struct FoldReshapeOpsByLinearizationPass
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void mlir::populateFoldReshapeOpsByLinearizationPatterns(
|
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
|
patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
|
||||||
FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
|
FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
|
||||||
|
@ -1178,7 +1176,7 @@ void mlir::populateFoldReshapeOpsByLinearizationPatterns(
|
||||||
patterns.getContext());
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
|
void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
|
patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
|
||||||
FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
|
FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
|
||||||
|
@ -1186,7 +1184,7 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
|
||||||
patterns.getContext());
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::populateFoldReshapeOpsByExpansionPatterns(
|
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
|
||||||
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
|
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
|
||||||
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
|
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
|
||||||
patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
|
patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
|
||||||
|
@ -1194,11 +1192,11 @@ void mlir::populateFoldReshapeOpsByExpansionPatterns(
|
||||||
patterns.getContext(), allowFoldingUnitDimReshapes);
|
patterns.getContext(), allowFoldingUnitDimReshapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::populateLinalgTensorOpsFusionPatterns(
|
void mlir::linalg::populateElementwiseOpsFusionPatterns(
|
||||||
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
|
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
|
||||||
auto *context = patterns.getContext();
|
auto *context = patterns.getContext();
|
||||||
patterns
|
patterns
|
||||||
.add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
|
.add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
|
||||||
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
|
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
|
||||||
context);
|
context);
|
||||||
populateFoldReshapeOpsByExpansionPatterns(patterns,
|
populateFoldReshapeOpsByExpansionPatterns(patterns,
|
||||||
|
|
Loading…
Reference in New Issue