[mlir][linalg][NFC] Make reshape folding control more fine grain

This expose a lambda control instead of just a boolean to control unit
dimension folding.
This however gives more control to user to pick a good heuristic.
Folding reshapes helps fusion opportunities but may generate sub-optimal
generic ops.

Differential Revision: https://reviews.llvm.org/D101917
This commit is contained in:
thomasraoux 2021-05-06 07:28:09 -07:00
parent b198b9b897
commit 52525cb20f
2 changed files with 41 additions and 24 deletions

View File

@ -28,6 +28,10 @@ struct LinalgElementwiseFusionOptions;
struct LinalgFusionOptions;
struct LinalgTilingOptions;
/// Default function to control reshape folding. Skips folding unit dimension
/// reshapes.
bool skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer);
//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
@ -42,11 +46,15 @@ void populateConvVectorizationPatterns(
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
using ControlElementwiseOpsFusionFn =
std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
RewritePatternSet &patterns,
ControlElementwiseOpsFusionFn controlFoldingReshapes = skipUnitDimReshape);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
@ -71,17 +79,15 @@ void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
/// tensors.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
using ControlElementwiseOpsFusionFn =
std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
/// Options that control fusion of elementwise operations.
struct LinalgElementwiseFusionOptions {
/// Enable fusion of reshapes that are introducing unit-dimensions into the
/// shape with elementwise operations. By default this is disabled.
bool allowFoldingUnitDimReshapes = false;
/// Enable fusion of reshapes into the shape with elementwise operations. By
/// default it is disabled for unit dimensions reshape.
ControlElementwiseOpsFusionFn controlFoldingReshapesFn = skipUnitDimReshape;
LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) {
allowFoldingUnitDimReshapes = val;
LinalgElementwiseFusionOptions &
setControlFoldingReshapes(ControlElementwiseOpsFusionFn fun) {
controlFoldingReshapesFn = std::move(fun);
return *this;
}

View File

@ -1164,11 +1164,11 @@ template <typename GenericOpTy>
class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOpTy> {
public:
FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
bool foldUnitDimReshapes,
PatternBenefit benefit = 1)
FoldWithProducerReshapeOpByExpansion(
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOpTy>(context, benefit),
allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
controlFoldingReshapes(foldReshapes) {}
LogicalResult matchAndRewrite(GenericOpTy genericOp,
PatternRewriter &rewriter) const override {
@ -1178,16 +1178,15 @@ public:
operand.value().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp)
continue;
// Fold only if
// - The tensor reshape op is folding.
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
!isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
(!allowFoldingUnitDimReshapes &&
isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
reshapeOp.getReassociationMaps())))
(!controlFoldingReshapes(
reshapeOp->getResult(0),
linalgOp.getInputOpOperands()[operand.index()])))
continue;
Optional<SmallVector<Value, 1>> replacementValues =
@ -1202,7 +1201,7 @@ public:
}
private:
bool allowFoldingUnitDimReshapes;
ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
/// Pattern to fold tensor_reshape op with its producer. The corresponding index
@ -1394,6 +1393,13 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
controlFn, rewriter);
}
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
const OpOperand &consumer) {
auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
reshapeOp.getReassociationMaps());
}
namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
template <typename LinalgOpTy>
@ -1431,10 +1437,14 @@ struct FusionOfTensorOpsPass
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
ControlElementwiseOpsFusionFn allowFoldingFn =
[](const OpResult &producer, const OpOperand &consumer) {
return true;
};
populateElementwiseOpsFusionPatterns(
patterns,
LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes(
allowFoldingUnitDimReshapes));
LinalgElementwiseFusionOptions().setControlFoldingReshapes(
allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
@ -1471,11 +1481,12 @@ void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
}
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
RewritePatternSet &patterns,
ControlElementwiseOpsFusionFn controlFoldingReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
patterns.getContext(), allowFoldingUnitDimReshapes);
patterns.getContext(), controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
@ -1485,8 +1496,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
.add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
context, options.controlElementwiseOpsFusionFn);
populateFoldReshapeOpsByExpansionPatterns(
patterns, options.allowFoldingUnitDimReshapes);
populateFoldReshapeOpsByExpansionPatterns(patterns,
options.controlFoldingReshapesFn);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);