[mlir][linalg][NFC] Make reshape folding control more fine grain
This exposes a lambda control instead of just a boolean to control unit dimension folding. This gives the user more control to pick a good heuristic: folding reshapes helps fusion opportunities, but may generate sub-optimal generic ops.

Differential Revision: https://reviews.llvm.org/D101917
This commit is contained in:
parent b198b9b897
commit 52525cb20f
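In practice the new hook lets a client decide per reshape whether folding is worthwhile. Below is a minimal sketch of how a downstream helper might plug a custom heuristic into the hook; the helper name, the rank-difference threshold, and the include paths are assumptions for illustration, while ControlElementwiseOpsFusionFn, skipUnitDimReshape, and populateFoldReshapeOpsByExpansionPatterns come from this change.

// Sketch only: hypothetical downstream helper that supplies its own
// reshape-folding heuristic through the lambda control added here.
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;
using namespace mlir::linalg;

static void populateMyReshapeFoldingPatterns(RewritePatternSet &patterns) {
  ControlElementwiseOpsFusionFn controlFn =
      [](const OpResult &producer, const OpOperand &consumer) {
        auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
        if (!reshapeOp)
          return false;
        // Made-up heuristic: only fold reshapes that expand by more than one
        // dimension, and still skip unit-dimension-only expansions by
        // delegating to the default control function added in this patch.
        int64_t rankDiff = reshapeOp.getResultType().getRank() -
                           reshapeOp.getSrcType().getRank();
        return rankDiff > 1 && skipUnitDimReshape(producer, consumer);
      };
  populateFoldReshapeOpsByExpansionPatterns(patterns, std::move(controlFn));
}

The same callback can also be routed through LinalgElementwiseFusionOptions::setControlFoldingReshapes for the full elementwise-fusion entry point, as the pass change in the diff below does.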
@@ -28,6 +28,10 @@ struct LinalgElementwiseFusionOptions;
 struct LinalgFusionOptions;
 struct LinalgTilingOptions;
 
+/// Default function to control reshape folding. Skips folding unit dimension
+/// reshapes.
+bool skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer);
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as function calls.
 //===----------------------------------------------------------------------===//
@@ -42,11 +46,15 @@ void populateConvVectorizationPatterns(
 /// parallel loops.
 void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
 
+using ControlElementwiseOpsFusionFn =
+    std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
 void populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+    RewritePatternSet &patterns,
+    ControlElementwiseOpsFusionFn controlFoldingReshapes = skipUnitDimReshape);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
 /// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -71,17 +79,15 @@ void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 
-using ControlElementwiseOpsFusionFn =
-    std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
-
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
-  /// Enable fusion of reshapes that are introducing unit-dimensions into the
-  /// shape with elementwise operations. By default this is disabled.
-  bool allowFoldingUnitDimReshapes = false;
+  /// Enable fusion of reshapes into the shape with elementwise operations. By
+  /// default it is disabled for unit dimensions reshape.
+  ControlElementwiseOpsFusionFn controlFoldingReshapesFn = skipUnitDimReshape;
 
-  LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) {
-    allowFoldingUnitDimReshapes = val;
+  LinalgElementwiseFusionOptions &
+  setControlFoldingReshapes(ControlElementwiseOpsFusionFn fun) {
+    controlFoldingReshapesFn = std::move(fun);
     return *this;
   }
 
@@ -1164,11 +1164,11 @@ template <typename GenericOpTy>
 class FoldWithProducerReshapeOpByExpansion
     : public OpRewritePattern<GenericOpTy> {
 public:
-  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
-                                       bool foldUnitDimReshapes,
-                                       PatternBenefit benefit = 1)
+  FoldWithProducerReshapeOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
       : OpRewritePattern<GenericOpTy>(context, benefit),
-        allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
+        controlFoldingReshapes(foldReshapes) {}
 
   LogicalResult matchAndRewrite(GenericOpTy genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -1178,16 +1178,15 @@ public:
           operand.value().getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp)
         continue;
-
       // Fold only if
       // - The tensor reshape op is folding.
       // - All constraints of fusing with reshape by expansion are met.
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
           !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
-          (!allowFoldingUnitDimReshapes &&
-           isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
-                                  reshapeOp.getReassociationMaps())))
+          (!controlFoldingReshapes(
+              reshapeOp->getResult(0),
+              linalgOp.getInputOpOperands()[operand.index()])))
         continue;
 
       Optional<SmallVector<Value, 1>> replacementValues =
@@ -1202,7 +1201,7 @@ public:
   }
 
 private:
-  bool allowFoldingUnitDimReshapes;
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
 /// Pattern to fold tensor_reshape op with its producer. The corresponding index
@@ -1394,6 +1393,13 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
                     controlFn, rewriter);
 }
 
+bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
+                                      const OpOperand &consumer) {
+  auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
+  return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                 reshapeOp.getReassociationMaps());
+}
+
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename LinalgOpTy>
@@ -1431,10 +1437,14 @@ struct FusionOfTensorOpsPass
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
+    ControlElementwiseOpsFusionFn allowFoldingFn =
+        [](const OpResult &producer, const OpOperand &consumer) {
+          return true;
+        };
     populateElementwiseOpsFusionPatterns(
         patterns,
-        LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes(
-            allowFoldingUnitDimReshapes));
+        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
+            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1471,11 +1481,12 @@ void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
 }
 
 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+    RewritePatternSet &patterns,
+    ControlElementwiseOpsFusionFn controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
   patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
                FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
-      patterns.getContext(), allowFoldingUnitDimReshapes);
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
@@ -1485,8 +1496,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
       .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
           FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
          context, options.controlElementwiseOpsFusionFn);
-  populateFoldReshapeOpsByExpansionPatterns(
-      patterns, options.allowFoldingUnitDimReshapes);
+  populateFoldReshapeOpsByExpansionPatterns(patterns,
+                                            options.controlFoldingReshapesFn);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
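For completeness, a client that previously relied on setAllowFoldingUnitDimReshapes(true) can recover the allow-all behaviour through the new options API. The sketch below is a hypothetical standalone driver (the helper name and surrounding boilerplate are assumptions); it only wires together the entry points shown in the diff above.

// Sketch only: hypothetical driver restoring the old "fold everything"
// behaviour now that the boolean option has been replaced by a callback.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

static LogicalResult fuseWithAllReshapesFolded(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  // Never veto a reshape fold; this mirrors the allow-all lambda added to
  // FusionOfTensorOpsPass in the diff above.
  ControlElementwiseOpsFusionFn foldAll =
      [](const OpResult &producer, const OpOperand &consumer) { return true; };
  populateElementwiseOpsFusionPatterns(
      patterns,
      LinalgElementwiseFusionOptions().setControlFoldingReshapes(foldAll));
  return applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}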