diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 1d46657018b3..3ce6570f4bf5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -349,6 +349,12 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
       consumerOp->getBlock() != rootOp->getBlock())
     return failure();
 
+  // Check `consumerOpOperand` is not shape-only to avoid fusion if the data is
+  // not used by the `consumerOp` computation.
+  BlockArgument bbArg = consumerOp.getTiedBlockArgument(consumerOpOperand);
+  if (bbArg.getUses().empty())
+    return failure();
+
   // Check if the producer is a LinalgOp possibly passed by iteration argument.
   OpOperand *iterArg = nullptr;
   auto producerResult = sliceOp.source().dyn_cast<OpResult>();
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
index 85c6cca7e366..7509fdb866e4 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
@@ -1,19 +1,40 @@
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=MATMUL %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.elemwise_unary fuse tile-sizes=32,32,0 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=UNARY %s
 
-func.func @no_fuse_gemm(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+// MATMUL-LABEL: @tile_sizes_zero(
+func.func @tile_sizes_zero(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %cst = arith.constant 0.0 : f32
   %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
   %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
   %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+
+  // MATMUL-NOT: scf.for
+  // MATMUL: linalg.fill
   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // MATMUL-NOT: scf.for
+  // MATMUL: linalg.matmul
   %result = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
-  return %result : tensor<?x?xf32>
+  func.return %result : tensor<?x?xf32>
+}
+
+// -----
+
+// UNARY-LABEL: @shape_only(
+func.func @shape_only(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+
+  // UNARY: linalg.fill
+  %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // UNARY: scf.for
+  // UNARY: scf.for
+  // UNARY-NOT: linalg.fill
+  // UNARY: linalg.elemwise_unary
+  %1 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
+    ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  func.return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: @no_fuse_gemm(
-// CHECK-NOT: scf.for
-// CHECK: linalg.fill
-// CHECK-NOT: scf.for
-// CHECK: linalg.matmul