[mlir][Linalg] Allow controlling fusion of linalg.generic -> linalg.tensor_expand_shape.
Differential Revision: https://reviews.llvm.org/D108565
commit b546f4347b
parent e42ce422a9
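For orientation before the diff: the change threads a `ControlElementwiseOpsFusionFn` through `FoldReshapeWithGenericOpByExpansion`, so callers of `populateFoldReshapeOpsByExpansionPatterns` can veto individual reshape foldings. A minimal sketch of the caller side, assuming this revision's headers and `FuncOp`-based passes (the helper name `applyControlledReshapeFusion` is invented for illustration):

// Sketch: wiring a custom control callback into the expansion-based
// reshape-folding patterns this commit parameterizes.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static void applyControlledReshapeFusion(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());

  // The callback sees each candidate fusion as a (producer result,
  // consumer operand) edge and returns true to allow the folding.
  // Here it simply defers to the default unit-dim heuristic; a real
  // client would add its own checks first (see the test pass below).
  linalg::ControlElementwiseOpsFusionFn control =
      [](const OpResult &producer, OpOperand &consumer) {
        return linalg::skipUnitDimReshape(producer, consumer);
      };

  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, control);
  (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}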
@@ -1133,7 +1133,13 @@ struct FoldConsumerReshapeOpByLinearization
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
     : public OpRewritePattern<TensorExpandShapeOp> {
   using OpRewritePattern<TensorExpandShapeOp>::OpRewritePattern;
+
+  FoldReshapeWithGenericOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<TensorExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(foldReshapes) {}
   LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     // Fold only if all constraints of fusing with reshape by expansion are met.
@@ -1141,7 +1147,8 @@ struct FoldReshapeWithGenericOpByExpansion
     if (!producer || producer.getNumOutputs() != 1 ||
         !isFusableWithReshapeByDimExpansion(producer,
                                             producer.getOutputOperand(0)) ||
-        isUnitDimExpansionOnly(reshapeOp))
+        !controlFoldingReshapes(producer->getResult(0),
+                                reshapeOp->getOpOperand(0)))
       return failure();
     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
@@ -1150,6 +1157,9 @@ struct FoldReshapeWithGenericOpByExpansion
     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
     return success();
   }
+
+private:
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };

 /// Pattern to fold a generic op with a splat constant.
@@ -1242,12 +1252,15 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,

 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                       OpOperand &consumer) {
-  auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
-  if (expandShapeOp)
-    return !isUnitDimExpansionOnly(expandShapeOp);
-  auto collapseShapeOp =
-      producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
-  return !isUnitDimExpansionOnly(collapseShapeOp);
+  if (auto producerCollapseOp =
+          dyn_cast<linalg::TensorCollapseShapeOp>(producer.getOwner())) {
+    return !isUnitDimExpansionOnly(producerCollapseOp);
+  }
+  if (auto consumerExpandOp =
+          dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
+    return !isUnitDimExpansionOnly(consumerExpandOp);
+  }
+  return true;
 }

 namespace {
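As an aside, the new default composes easily: a hypothetical control function that additionally refuses every consumer-side `tensor_expand_shape` folding, falling back to `skipUnitDimReshape` otherwise (the name `noConsumerExpansion` is invented for this sketch; only calls visible in the diff are used):

// Hypothetical predicate: veto folding whenever the consumer is a
// tensor_expand_shape; otherwise keep the default unit-dim behavior.
linalg::ControlElementwiseOpsFusionFn noConsumerExpansion =
    [](const OpResult &producer, OpOperand &consumer) {
      if (isa<linalg::TensorExpandShapeOp>(consumer.getOwner()))
        return false;
      return linalg::skipUnitDimReshape(producer, consumer);
    };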
@@ -1389,7 +1402,8 @@ void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     RewritePatternSet &patterns,
     ControlElementwiseOpsFusionFn controlFoldingReshapes) {
-  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
+  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
+                                                    controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -0,0 +1,62 @@
+// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
+
+func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+      %2 = addf %arg2, %arg3 : f32
+      linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: builtin.func @control_producer_reshape_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
+// CHECK-SAME: {{\[}}[0, 1], [2]{{\]}} : tensor<?x?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]]]
+// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<?x?xf32>, tensor<?xf32>)
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %cst = constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c1 : tensor<1x?x?xf32>
+  %d1 = tensor.dim %arg1, %c2 : tensor<1x?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %fill = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg2 : f32):
+      linalg.yield %cst : f32
+  } -> tensor<?x?xf32>
+  %0 = linalg.tensor_expand_shape %fill [[0, 1], [2]] : tensor<?x?xf32> into tensor<1x?x?xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+      outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+  return %1 : tensor<1x?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)
+// CHECK: builtin.func @control_consumer_reshape_fusion
+// CHECK: %[[FILL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%{{.+}} : tensor<1x?x?xf32>)
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x?x?xf32>)
@@ -73,6 +73,52 @@ struct TestLinalgElementwiseFusion
   }
 };

+struct TestLinalgControlFuseByExpansion
+    : public PassWrapper<TestLinalgControlFuseByExpansion, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-control-fusion-by-expansion";
+  }
+  StringRef getDescription() const final {
+    return "Test controlling of fusion of elementwise ops with reshape by "
+           "expansion";
+  }
+
+  void runOnFunction() override {
+    MLIRContext *context = &this->getContext();
+    FuncOp funcOp = this->getFunction();
+    RewritePatternSet fusionPatterns(context);
+
+    linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
+        [](const OpResult &producer, OpOperand &consumer) {
+          if (auto collapseOp =
+                  producer.getDefiningOp<linalg::TensorCollapseShapeOp>()) {
+            if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
+              return false;
+            }
+          }
+          if (auto expandOp =
+                  dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
+            if (expandOp->hasOneUse()) {
+              OpOperand &use = *expandOp->getUses().begin();
+              auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
+              if (linalgOp && linalgOp.isOutputTensor(&use))
+                return true;
+            }
+          }
+          return linalg::skipUnitDimReshape(producer, consumer);
+        };
+
+    linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
+                                                      controlReshapeFusionFn);
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                       std::move(fusionPatterns));
+  }
+};
+
 struct TestPushExpandingReshape
     : public PassWrapper<TestPushExpandingReshape, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -99,6 +145,10 @@ void registerTestLinalgElementwiseFusion() {
   PassRegistration<TestLinalgElementwiseFusion>();
 }

+void registerTestLinalgControlFuseByExpansion() {
+  PassRegistration<TestLinalgControlFuseByExpansion>();
+}
+
 void registerTestPushExpandingReshape() {
   PassRegistration<TestPushExpandingReshape>();
 }
@@ -78,6 +78,7 @@ void registerTestGpuParallelLoopMappingPass();
 void registerTestIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgControlFuseByExpansion();
 void registerTestLinalgDistribution();
 void registerTestLinalgElementwiseFusion();
 void registerTestPushExpandingReshape();
@@ -165,6 +166,7 @@ void registerTestPasses() {
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLinalgCodegenStrategy();
+  mlir::test::registerTestLinalgControlFuseByExpansion();
   mlir::test::registerTestLinalgDistribution();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestPushExpandingReshape();