From b5ea288d13d099fc60f64932e8826d437e842348 Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 4 Feb 2022 19:13:28 +0000 Subject: [PATCH] [mlir][linalg] Let tile and fuse fail for tile sizes zero. Adapt `tileConsumerAndFuseProducers` to return failure if the generated tile loop nest is empty since all tile sizes are zero. Additionally, fix `LinalgTileAndFuseTensorOpsPattern` to return success if the pattern applied successfully. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D118878 --- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 2 +- .../lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp | 4 ++++ mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 11 ++++++----- mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir | 2 +- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index a2f08e79ac47..ae9b8a337f79 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -269,10 +269,10 @@ public: /// Returns the loop ops generated from tiling. ArrayRef getLoopOps() { return tileLoopOps; } -private: /// Returns true if the tile loop nest has no tile loops. bool isEmpty(); +private: /// Returns true if the tile loop nest invariants are satisfied: /// - The `rootOp` has been tiled at least once. /// - The number of tile loop operations and dimensions match. diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index eb1415dabde2..154b8f4f26e8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -458,5 +458,9 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, return failure(); fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); + // Exit if the tile loop nest is empty since all tile sizes are zero. + if (tileLoopNest.isEmpty()) + return failure(); + return tileLoopNest; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index f98666751eef..d8e00ca5d4ed 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -592,10 +592,6 @@ LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite( SmallVector rootTileSizes(options.tileSizes.begin(), options.tileSizes.begin() + rootOp.getNumLoops()); - if (llvm::all_of(rootTileSizes, [](int64_t ts) { return ts == 0; })) { - return rewriter.notifyMatchFailure( - op, "all tile sizes are zero, nothing to do"); - } SmallVector rootInterchange = options.tileInterchange.empty() ? llvm::to_vector<6>(llvm::seq(0, rootOp.getNumLoops())) @@ -603,6 +599,11 @@ LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite( options.tileInterchange.begin() + rootOp.getNumLoops()); + // Check `rootTileSizes` contains non-zero tile sizes. + if (llvm::count(rootTileSizes, 0) == static_cast(rootTileSizes.size())) + return rewriter.notifyMatchFailure( + op, "expect at least one non-zero tile size"); + // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions. // It has to be a permutation since the tiling cannot tile the same loop // dimension multiple times. @@ -623,7 +624,7 @@ LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite( // Apply the filter if specified. for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps()) filter.replaceLinalgTransformationFilter(rewriter, linalgOp); - return failure(); + return success(); } /// Linalg generic interchange pattern. diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir index 80fc9bb4c596..516b3701f38f 100644 --- a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic fuse tile-sizes=0,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck %s builtin.func @no_fuse_gemm(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index