[mlir][Linalg] Allow controlling fusion of linalg.generic -> linalg.tensor_expand_shape.

Differential Revision: https://reviews.llvm.org/D108565
MaheshRavishankar 2021-08-23 16:27:15 -07:00
parent e42ce422a9
commit b546f4347b
4 changed files with 137 additions and 9 deletions
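
The change threads the existing ControlElementwiseOpsFusionFn hook through the pattern that folds a linalg.tensor_expand_shape into the linalg.generic producing its source, so callers can accept or reject each fusion opportunity. A minimal sketch of such a callback, assuming the Linalg fusion declarations of this revision (the single-use policy is a made-up example, not part of the commit):

// Hypothetical control callback; only the signatures are taken from this diff.
linalg::ControlElementwiseOpsFusionFn controlFn =
    [](const OpResult &producer, OpOperand &consumer) {
      // Example policy: fold a tensor_expand_shape consumer only when the
      // reshape has a single use; otherwise defer to the default heuristic
      // that skips unit-dimension-only reshapes.
      if (auto expandOp =
              dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner()))
        return expandOp->hasOneUse();
      return linalg::skipUnitDimReshape(producer, consumer);
    };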


@@ -1133,7 +1133,13 @@ struct FoldConsumerReshapeOpByLinearization
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
: public OpRewritePattern<TensorExpandShapeOp> {
-using OpRewritePattern<TensorExpandShapeOp>::OpRewritePattern;
+FoldReshapeWithGenericOpByExpansion(
+MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+PatternBenefit benefit = 1)
+: OpRewritePattern<TensorExpandShapeOp>(context, benefit),
+controlFoldingReshapes(foldReshapes) {}
LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if all constraints of fusing with reshape by expansion are met.
@@ -1141,7 +1147,8 @@ struct FoldReshapeWithGenericOpByExpansion
if (!producer || producer.getNumOutputs() != 1 ||
!isFusableWithReshapeByDimExpansion(producer,
producer.getOutputOperand(0)) ||
-isUnitDimExpansionOnly(reshapeOp))
+!controlFoldingReshapes(producer->getResult(0),
+reshapeOp->getOpOperand(0)))
return failure();
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
@@ -1150,6 +1157,9 @@ struct FoldReshapeWithGenericOpByExpansion
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
return success();
}
private:
ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
/// Pattern to fold a generic op with a splat constant.
@@ -1242,12 +1252,15 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
OpOperand &consumer) {
-auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
-if (expandShapeOp)
-return !isUnitDimExpansionOnly(expandShapeOp);
-auto collapseShapeOp =
-producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
-return !isUnitDimExpansionOnly(collapseShapeOp);
+if (auto producerCollapseOp =
+dyn_cast<linalg::TensorCollapseShapeOp>(producer.getOwner())) {
+return !isUnitDimExpansionOnly(producerCollapseOp);
+}
+if (auto consumerExpandOp =
+dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
+return !isUnitDimExpansionOnly(consumerExpandOp);
+}
+return true;
}
namespace {
@@ -1389,7 +1402,8 @@ void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
ControlElementwiseOpsFusionFn controlFoldingReshapes) {
-patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
+patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
+controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
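
With populateFoldReshapeOpsByExpansionPatterns now forwarding the callback to FoldReshapeWithGenericOpByExpansion as well, a downstream pass would wire the control function in roughly as the test pass added further down does (a sketch, assuming a FunctionPass-style pass and the controlFn callback sketched above):

void runOnFunction() override {
  RewritePatternSet fusionPatterns(&getContext());
  // Register the expansion-based reshape folding patterns with the custom
  // control callback; passing linalg::skipUnitDimReshape instead keeps the
  // previous default behavior.
  linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, controlFn);
  (void)applyPatternsAndFoldGreedily(getFunction().getBody(),
                                     std::move(fusionPatterns));
}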


@@ -0,0 +1,62 @@
// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
%d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
outs(%init : tensor<?x?xf32>) {
^bb0(%arg2 : f32, %arg3:f32, %arg4 : f32):
%2 = addf %arg2, %arg3 : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
// CHECK: builtin.func @control_producer_reshape_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
// CHECK-SAME: {{\[}}[0, 1], [2]{{\]}} : tensor<?x?x?xf32> into tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]]]
// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<?x?xf32>, tensor<?xf32>)
// CHECK: return %[[RESULT]]
// -----
func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
%c1 = constant 1 : index
%c2 = constant 2 : index
%cst = constant 0.0 : f32
%d0 = tensor.dim %arg0, %c1 : tensor<1x?x?xf32>
%d1 = tensor.dim %arg1, %c2 : tensor<1x?x?xf32>
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%fill = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
outs(%init : tensor<?x?xf32>) {
^bb0(%arg2: f32):
linalg.yield %cst : f32
} -> tensor<?x?xf32>
%0 = linalg.tensor_expand_shape %fill [[0, 1], [2]] : tensor<?x?xf32> into tensor<1x?x?xf32>
%1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
return %1 : tensor<1x?x?xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)
// CHECK: builtin.func @control_consumer_reshape_fusion
// CHECK: %[[FILL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]]]
// CHECK-SAME: outs(%{{.+}} : tensor<1x?x?xf32>)
// CHECK: linalg.batch_matmul
// CHECK-SAME: outs(%[[FILL]] : tensor<1x?x?xf32>)


@@ -73,6 +73,52 @@ struct TestLinalgElementwiseFusion
}
};
struct TestLinalgControlFuseByExpansion
: public PassWrapper<TestLinalgControlFuseByExpansion, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
}
StringRef getArgument() const final {
return "test-linalg-control-fusion-by-expansion";
}
StringRef getDescription() const final {
return "Test controlling of fusion of elementwise ops with reshape by "
"expansion";
}
void runOnFunction() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getFunction();
RewritePatternSet fusionPatterns(context);
linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
[](const OpResult &producer, OpOperand &consumer) {
if (auto collapseOp =
producer.getDefiningOp<linalg::TensorCollapseShapeOp>()) {
if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
return false;
}
}
if (auto expandOp =
dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
if (expandOp->hasOneUse()) {
OpOperand &use = *expandOp->getUses().begin();
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
if (linalgOp && linalgOp.isOutputTensor(&use))
return true;
}
}
return linalg::skipUnitDimReshape(producer, consumer);
};
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
controlReshapeFusionFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
}
};
struct TestPushExpandingReshape
: public PassWrapper<TestPushExpandingReshape, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@@ -99,6 +145,10 @@ void registerTestLinalgElementwiseFusion() {
PassRegistration<TestLinalgElementwiseFusion>();
}
void registerTestLinalgControlFuseByExpansion() {
PassRegistration<TestLinalgControlFuseByExpansion>();
}
void registerTestPushExpandingReshape() {
PassRegistration<TestPushExpandingReshape>();
}


@@ -78,6 +78,7 @@ void registerTestGpuParallelLoopMappingPass();
void registerTestIRVisitorsPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgControlFuseByExpansion();
void registerTestLinalgDistribution();
void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape();
@@ -165,6 +166,7 @@ void registerTestPasses() {
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestLinalgCodegenStrategy();
mlir::test::registerTestLinalgControlFuseByExpansion();
mlir::test::registerTestLinalgDistribution();
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestPushExpandingReshape();