[mlir][linalg] Do not fuse shape-only producers.

This revision introduces a heuristic to stop fusion for shape-only tensors. A shape-only tensor only defines the shape of the consumer computation while the data is not used. Pure producer consumer fusion thus shall not fuse the producer of a shape-only tensor. In particular, since the shape-only tensor will have other uses that actually consume the data.

The revision enables fusion for consumers that have two uses of the same tensor. One as input operand and one as shape-only output operand. In these cases, we want to fuse only the input operand and avoid output fusion via iteration argument.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D120981
This commit is contained in:
gysit 2022-03-24 10:22:28 +00:00
parent ec93b28909
commit 53f7fb0a87
2 changed files with 35 additions and 8 deletions

View File

@ -349,6 +349,12 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
consumerOp->getBlock() != rootOp->getBlock()) consumerOp->getBlock() != rootOp->getBlock())
return failure(); return failure();
// Check `consumerOpOperand` is not shape-only to avoid fusion if the data is
// not used by the `consumerOp` computation.
BlockArgument bbArg = consumerOp.getTiedBlockArgument(consumerOpOperand);
if (bbArg.getUses().empty())
return failure();
// Check if the producer is a LinalgOp possibly passed by iteration argument. // Check if the producer is a LinalgOp possibly passed by iteration argument.
OpOperand *iterArg = nullptr; OpOperand *iterArg = nullptr;
auto producerResult = sliceOp.source().dyn_cast<OpResult>(); auto producerResult = sliceOp.source().dyn_cast<OpResult>();

View File

@ -1,19 +1,40 @@
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck %s // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=MATMUL %s
// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.elemwise_unary fuse tile-sizes=32,32,0 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=UNARY %s
func.func @no_fuse_gemm(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> { // MATMUL-LABEL: @tile_sizes_zero(
func.func @tile_sizes_zero(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32 %cst = arith.constant 0.0 : f32
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32> %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32> %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
// MATMUL-NOT: scf.for
// MATMUL: linalg.fill
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32> %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
// MATMUL-NOT: scf.for
// MATMUL: linalg.matmul
%result = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) %result = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32> outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32> func.return %result : tensor<?x?xf32>
}
// -----
// UNARY_LABEL: @shape_only(
func.func @shape_only(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.0 : f32
// UNARY: linalg.fill
%0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
// UNARY: scf.for
// UNARY: scf.for
// UNARY-NOT: linalg.fill
// UNARY: linalg.elemwise_unary
%1 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
func.return %1 : tensor<?x?xf32>
} }
// CHECK-LABEL: @no_fuse_gemm(
// CHECK-NOT: scf.for
// CHECK: linalg.fill
// CHECK-NOT: scf.for
// CHECK: linalg.matmul