From c584771f54cf94bb396c22f5cca895dd3f23c245 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Tue, 21 Jun 2022 16:56:10 +0000 Subject: [PATCH] Revert "[mlir][TilingInterface] Enable tile and fuse using TilingInterface." This reverts commit ea75511319d9dff8c38c8794c3949c40b63a38d7 due to build failures. --- .../SCF/Transforms/TileUsingInterface.h | 55 +----- .../Dialect/Tensor/Transforms/Transforms.h | 9 - .../mlir/Interfaces/TilingInterface.td | 43 +--- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 56 +----- .../SCF/Transforms/TileUsingInterface.cpp | 156 --------------- .../Dialect/Tensor/Transforms/CMakeLists.txt | 2 - .../SwapExtractSliceWithProducer.cpp | 43 ---- .../tile-and-fuse-using-interface.mlir | 185 ------------------ .../TilingInterface/tile-using-interface.mlir | 2 +- .../TilingInterface/TestTilingInterface.cpp | 124 +++--------- .../llvm-project-overlay/mlir/BUILD.bazel | 2 - 11 files changed, 29 insertions(+), 648 deletions(-) delete mode 100644 mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp delete mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 1f3ee8a5b27f..6e8af767ff8a 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -10,12 +10,9 @@ #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/TilingInterface.h" -#include - namespace mlir { class Operation; class PatternRewriter; @@ -58,7 +55,7 @@ struct SCFTilingResult { SmallVector loops; }; -/// Pattern to tile an op that implements the `TilingInterface` using +/// Pattern to tile an op that implementas the `TilingInterface` using /// `scf.for` for iterating over the tiles. struct TileUsingSCFForOp : public OpInterfaceRewritePattern { /// Construct a generic pattern applied to all TilingInterface ops. @@ -84,56 +81,6 @@ private: SCFTilingOptions options; }; -/// Pattern to tile and fuse a sequence of operations, by tiling the consumer -/// and fusing its producers. Note that this assumes that it is valid to -/// tile+fuse the producer into the innermost tiled loop. Its up to the caller -/// to ensure that the tile sizes provided make this fusion valid. -/// -/// For example, for the following sequence -/// -/// ```mlir -/// %0 = linalg.fill ... -/// %1 = linalg.matmul ... outs(%0 : ...) ... -/// ``` -/// -/// it is legal to fuse the fill with the matmul only if the matmul is tiled -/// along the parallel dimensions and not the reduction dimension, i.e. the tile -/// size for the reduction dimension should be 0. -struct SCFTileAndFuseResult { - SmallVector tiledAndFusedOps; - SmallVector loops; -}; -struct TileConsumerAndFuseProducersUsingSCFForOp - : public OpInterfaceRewritePattern { - - /// Construct a generic pattern applied to all TilingInterface ops. - TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, - SCFTilingOptions options, - PatternBenefit benefit = 1); - - /// Construct a generic pattern applied to `opName`. - TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, - MLIRContext *context, - SCFTilingOptions options, - PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr - returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } - -private: - /// This pattern uses the tiling pattern. Instead of using inheritance, use - /// the patterns as private object that is instantiated at the same time as - /// this pattern. - TileUsingSCFForOp tilingPattern; -}; - } // namespace scf } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h index 28c22aecdf31..e6267e9cf02e 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -9,7 +9,6 @@ #ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H #define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -21,14 +20,6 @@ namespace tensor { void populateSplitPaddingPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit = 1); -/// Pattern to swap an `tensor.extract_slice` with its producer when the -/// producer implements the `TilingInterface`. The pattern itself does not -/// provide a mechanism to control where the application happens. With use of -/// transform dialect that control is done within the transform dialect. Other -/// use cases can inherit from this pattern and add necessary controls. -FailureOr replaceExtractSliceWithTiledProducer( - OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp); - } // namespace tensor } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index f3fdc30168b2..606901375ede 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -120,48 +120,7 @@ def TilingInterface : OpInterface<"TilingInterface"> { /*defaultImplementation=*/[{ return failure(); }] - >, - InterfaceMethod< - /*desc=*/[{ - Method to generate the code that produces a tile of the result. - - Generates the IR that computes the tile of a result of the - operation. The `offsets` and `sizes` describe the tile of - the output required. This is different from - `getTiledImplementation` which generates the tiled - implementation of the operation given a tile of the - iteration space. This method generates a tiled - implementation of the operation based on the tile of the - result required. This method enables fusion by using tile - and fuse. The method returns failure if the operation can't be - tiled to generate the result tile. In practical terms this - implies it cannot be tiled and fused with its consumers. - - - `dest` are the Value into which the result of the tiled - operation is to be inserted into. The type of the `dest` - Values is same as the types returned by - `getDestinationOperands` method. - - `offsets` provides the offset of the tile within the - iteration space - - `sizes` provides the size of the tile. - - `tileDestOperands` specifies whether to also tile `dest` operands - or not. Avoiding tiling `dest` operands can be useful for - composition with various looping container ops. - }], - /*retType=*/"FailureOr", - /*methodName=*/"generateResultTileValue", - /*args=*/(ins - "OpBuilder &":$b, - "unsigned":$resultNumber, - "ValueRange":$dest, - "ArrayRef":$offsets, - "ArrayRef":$sizes, - "bool":$tileDestOperands), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return failure(); - }] > - ]; + ]; } #endif // MLIR_TILINGINTERFACE diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 88b21f15081f..c67097ab3d69 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -30,6 +30,7 @@ template struct LinalgOpTilingInterface : public TilingInterface::ExternalModel, LinalgOpTy> { + /// Return the destination operands. SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { return llvm::cast(op).getOutputOperands(); @@ -46,8 +47,6 @@ struct LinalgOpTilingInterface /// Return the iteration domain range. SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc); @@ -130,65 +129,16 @@ struct LinalgOpTilingInterface resultSizes = sliceOp.getMixedSizes(); return success(); } - - FailureOr generateResultTileValue(Operation *op, OpBuilder &b, - unsigned resultNumber, - ValueRange dest, - ArrayRef offsets, - ArrayRef sizes, - bool tileDestOperands) const { - auto linalgOp = cast(op); - - // Check that the indexing map used for the output is a projected - // permutation. This could be relaxed with a more general approach that can - // map the offsets and sizes from the result to iteration space tiles - // (filling in full extent for dimensions not used to access the result). - AffineMap indexingMap = - linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber)); - if (!indexingMap.isProjectedPermutation()) { - return op->emitOpError( - "unhandled tiled implementation generation when result is not " - "accessed using a permuted projection"); - } - - auto numLoops = linalgOp.getNumLoops(); - auto tilingInterfaceOp = cast(op); - SmallVector iterationTileOffsets(numLoops), - iterationTileSizes(numLoops); - if (!indexingMap.isPermutation()) { - SmallVector iterationDomain = - tilingInterfaceOp.getIterationDomain(b); - for (auto range : llvm::enumerate(iterationDomain)) { - iterationTileOffsets[range.index()] = range.value().offset; - iterationTileSizes[range.index()] = range.value().size; - } - } - for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) { - unsigned dimPosition = - resultExpr.value().cast().getPosition(); - iterationTileOffsets[dimPosition] = offsets[resultExpr.index()]; - iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; - } - - SmallVector tiledOp = tilingInterfaceOp.getTiledImplementation( - b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands); - if (tiledOp.size() != 1) - return op->emitOpError("failed to generate tiled implementation"); - - return tiledOp[0]->getResult(resultNumber); - } }; } // namespace -template -static void registerOne(MLIRContext *ctx) { +template static void registerOne(MLIRContext *ctx) { OpType::template attachInterface>(*ctx); } /// Variadic helper function. -template -static void registerAll(MLIRContext *ctx) { +template static void registerAll(MLIRContext *ctx) { // FIXME: In c++17 this can be simplified by using 'fold expressions'. (void)std::initializer_list{0, (registerOne(ctx), 0)...}; } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 1bad67f3d7f4..4646abcf3e8d 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -42,10 +42,6 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef ts) { return *this; } -//===----------------------------------------------------------------------===// -// TileUsingSCFForOp pattern implementation. -//===----------------------------------------------------------------------===// - /// Generate an empty loop nest that represents the tiled loop nest shell. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. @@ -251,155 +247,3 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite( rewriter.replaceOp(op, tilingResult.loops.front().getResults()); return tilingResult; } - -//===----------------------------------------------------------------------===// -// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. -//===----------------------------------------------------------------------===// - -scf::TileConsumerAndFuseProducersUsingSCFForOp:: - TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - tilingPattern(context, std::move(options)) {} - -scf::TileConsumerAndFuseProducersUsingSCFForOp:: - TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, - MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - tilingPattern(context, std::move(options)) {} - -/// Return the `Value` that is defined by an operation that implements -/// the `TilingInterface`. Looks through `iter_args` of scf.for nest -/// if required. -static Optional getFusableProducer(Value v) { - while (auto blockArg = v.dyn_cast()) { - auto loopOp = dyn_cast(blockArg.getOwner()->getParentOp()); - if (!loopOp) - return llvm::None; - v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); - } - if (!isa_and_nonnull(v.getDefiningOp())) - return llvm::None; - return v.cast(); -} - -FailureOr -scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { - // This transformation is only valid for ops that return values (i.e. not - // valid to use with operations that have memref operands). - if (!op->getNumResults()) { - return rewriter.notifyMatchFailure( - op, "invalid pattern for op with no results"); - } - - // 1. First tile the consumer. - SCFTileAndFuseResult tileAndFuseResult; - { - FailureOr tilingResult = - tilingPattern.returningMatchAndRewrite(op, rewriter); - if (failed(tilingResult)) { - return failure(); - } - tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); - tileAndFuseResult.loops = std::move(tilingResult->loops); - } - - // 2. Typically, the operands of the tiled operation are slices of the - // operands of the untiled operation. These are expressed in IR using - // `tensor.extract_slice` operations with source being the operands of the - // untiled operation. Create a worklist of these `tensor.extract_slice` - // operations. If the producers of the source of the `tensor.extract_slice` - // can be tiled such that the tiled value is generated in-place, that - // effectively tiles + fuses the operations. - auto addCandidateSlices = [](Operation *fusedOp, - std::deque &candidates) { - for (Value operand : fusedOp->getOperands()) - if (auto sliceOp = operand.getDefiningOp()) - candidates.push_back(sliceOp); - }; - - std::deque candidates; - addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); - OpBuilder::InsertionGuard g(rewriter); - while (!candidates.empty()) { - // 2a. Traverse the slices in BFS fashion. - tensor::ExtractSliceOp candidateSliceOp = candidates.front(); - candidates.pop_front(); - - // 2b. Get the producer of the source (potentially walking through - // `iter_args` of nested `scf.for`) - Optional fusableProducer = - getFusableProducer(candidateSliceOp.source()); - if (!fusableProducer) - continue; - - // 2c. Generate the tiled implementation of the producer of the source - rewriter.setInsertionPoint(candidateSliceOp); - FailureOr fusedProducerValue = - tensor::replaceExtractSliceWithTiledProducer( - rewriter, candidateSliceOp, fusableProducer.getValue()); - if (failed(fusedProducerValue)) - continue; - rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue()); - - // 2d. The operands of the fused producer might themselved be slices of - // values produced by operations that implement the `TilingInterface`. - // Add these operations to the worklist. - Operation *fusedProducer = fusedProducerValue->getDefiningOp(); - tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); - addCandidateSlices(fusedProducer, candidates); - - // 2e. If the operation being fused creates a value that is used as `outs` - // in the tiled operation, the result of the unfused operation will be - // used in the `iter_args` of the tiled loop generated. When the - // operation is fused, this use in `iter_args` needs to be modified to - // use the destination of the fused operation. For example, starting - // with - // - // ```mlir - // %0 = linalg.init_tensor ... - // %1 = linalg.fill ... outs(%0:...)... - // %2 = linalg.matmul ... outs(%1:...).... - // ``` - // - // First the `linalg.matmul` gets tiled - // - // ```mlir - // %0 = linalg.init_tensor - // %1 = linalg.fill - // %2 = scf.for .... iter_args(%arg0 = %1)... - // ... - // ... = linalg.matmul ... - // - // ``` - // - // When the `linalg.fill` gets fused, the `iter_args` needs to be - // modified - // - // ```mlir - // %0 = linalg.init_tensor - // %1 = scf.for ... iter_args(%arg0 = %0)... - // ... - // %2 = linalg.fill ... - // %3 = linalg.matmul ... outs(%2: ...)... - // ``` - TilingInterface unfusedProducerOp = - cast(fusableProducer->getOwner()); - scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); - SmallVector unfusedProducerOpDestValues = - unfusedProducerOp.getDestinationOperands(rewriter); - for (OpOperand &uses : unfusedProducerOp->getUses()) { - if (uses.getOwner() == outerMostTiledLoop.getOperation()) { - unsigned resultNumber = uses.get().cast().getResultNumber(); - unsigned operandNumber = uses.getOperandNumber(); - outerMostTiledLoop->setOperand( - operandNumber, unfusedProducerOpDestValues[resultNumber]); - } - } - } - return tileAndFuseResult; -} diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt index 8479c43211e8..f4983c4d5c88 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRTensorTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp SplitPadding.cpp - SwapExtractSliceWithProducer.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms @@ -19,6 +18,5 @@ add_mlir_dialect_library(MLIRTensorTransforms MLIRPass MLIRSCFDialect MLIRTensorDialect - MLIRTilingInterface MLIRTransforms ) diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp deleted file mode 100644 index 8d570cfdf759..000000000000 --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp +++ /dev/null @@ -1,43 +0,0 @@ -//===- SwapExtractSliceWithProducer.cpp - Swapping `tensor.extract_slice` ---=// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Swap a `tensor.extract_slice` with the producer of the source if the producer -// implements the `TilingInterface`. When used in conjunction with tiling this -// effectively tiles + fuses the producer with its consumer. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/Interfaces/TilingInterface.h" - -using namespace mlir; - -FailureOr tensor::replaceExtractSliceWithTiledProducer( - OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) { - auto producerOp = dyn_cast(producer.getOwner()); - if (!producerOp) - return failure(); - - // `TilingInterface` currently only supports strides being 1. - if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { - return !isConstantIntValue(ofr, 1); - })) - return failure(); - - FailureOr tiledResult = producerOp.generateResultTileValue( - builder, producer.getResultNumber(), - producerOp.getDestinationOperands(builder), sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes(), true); - if (failed(tiledResult)) - return failure(); - - return tiledResult.getValue(); -} diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir deleted file mode 100644 index dd77211d8ccc..000000000000 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ /dev/null @@ -1,185 +0,0 @@ -// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s - -func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.0 : f32 - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg1, %c1 : tensor - %init = linalg.init_tensor [%d0, %d1] : tensor - %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor - %gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"} - ins(%arg0, %arg1 : tensor, tensor) - outs(%fill : tensor) -> tensor - return %gemm : tensor -} -// CHECK: func.func @gemm_fill_fusion( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) -// CHECK: %[[INIT:.+]] = linalg.init_tensor -// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = -// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]]) -// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = -// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) -// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] -// CHECK: %[[FILL_TILE:.+]] = linalg.fill -// CHECK-SAME: outs(%[[INIT_TILE]] : -// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : -// CHECK-SAME: outs(%[[FILL_TILE]] : -// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]] -// CHECK scf.yield %[[INSERT]] - -// ----- - -func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : tensor, - %arg2 : tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.0 : f32 - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg1, %c1 : tensor - %init = linalg.init_tensor [%d0, %d1] : tensor - %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor - %gemm = linalg.matmul - ins(%arg0, %arg1 : tensor, tensor) - outs(%fill : tensor) -> tensor - %generic = linalg.generic { - __internal_linalg_transform__ = "fusion", - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%gemm, %arg2 : tensor, tensor) outs(%init : tensor) { - ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): - %add = arith.addf %b0, %b1 : f32 - linalg.yield %add : f32 - } -> tensor - return %generic : tensor -} -// CHECK: func.func @gemm_generic_fusion( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor, -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor) -// CHECK: %[[INIT:.+]] = linalg.init_tensor -// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = -// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]]) -// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = -// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) -// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] -// CHECK: %[[FILL_TILE:.+]] = linalg.fill -// CHECK-SAME: outs(%[[INIT_TILE]] : -// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : -// CHECK-SAME: outs(%[[FILL_TILE]] : -// CHECK-DAG: %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]] -// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]] -// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic -// CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] : -// CHECK-SAME: outs(%[[OUTS_TILE]] : -// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]] -// CHECK scf.yield %[[INSERT]] - -// ----- - -func.func @gemm_gemm_fusion(%lhs0 : tensor, %rhs0 : tensor, %rhs1 : tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.0 : f32 - %d0 = tensor.dim %lhs0, %c0 : tensor - %d1 = tensor.dim %rhs0, %c1 : tensor - %init0 = linalg.init_tensor [%d0, %d1] : tensor - %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> tensor - %gemm0 = linalg.matmul - ins(%lhs0, %rhs0 : tensor, tensor) outs(%fill0 : tensor) -> tensor - %d2 = tensor.dim %rhs1, %c1 : tensor - %init1 = linalg.init_tensor [%d0, %d2] : tensor - %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor) -> tensor - %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"} - ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor - return %gemm1 : tensor -} -// CHECK: func.func @gemm_gemm_fusion( -// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor, -// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]] -// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] -// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]] -// CHECK: %[[INIT1:.+]] = linalg.init_tensor [%[[D0]], %[[D2]]] -// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] = -// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]]) -// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0] -// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0] -// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]], 0] -// CHECK: %[[FILL0_TILE:.+]] = linalg.fill -// CHECK-SAME: outs(%[[INIT0_TILE]] : -// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] : -// CHECK-SAME: outs(%[[FILL0_TILE]] : -// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0] -// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0] -// CHECK: %[[FILL1_TILE:.+]] = linalg.fill -// CHECK-SAME: outs(%[[INIT1_TILE]] : -// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] : -// CHECK-SAME: outs(%[[FILL1_TILE]] : -// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0] -// CHECK scf.yield %[[INSERT]] - -// ----- - -func.func @gemm_transpose_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.0 : f32 - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg1, %c1 : tensor - %init0 = linalg.init_tensor [%d0, %d1] : tensor - %fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> tensor - %gemm = linalg.matmul - ins(%arg0, %arg1 : tensor, tensor) - outs(%fill : tensor) -> tensor - %init1 = linalg.init_tensor [%d1, %d0] : tensor - %transpose = linalg.generic { - __internal_linalg_transform__ = "fusion", - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], - iterator_types = ["parallel", "parallel"]} - ins(%gemm : tensor) outs(%init1 : tensor) { - ^bb0(%b0 : f32, %b1 : f32): - linalg.yield %b0 : f32 - } -> tensor - return %transpose : tensor -} -// CHECK: func.func @gemm_transpose_fusion( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] -// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]] -// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = -// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT1]]) -// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = -// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) -// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]]] -// CHECK: %[[FILL_TILE:.+]] = linalg.fill -// CHECK-SAME: outs(%[[INIT0_TILE]] : -// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : -// CHECK-SAME: outs(%[[FILL_TILE]] : -// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]] -// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic -// CHECK-SAME: ins(%[[GEMM_TILE]] : -// CHECK-SAME: outs(%[[OUTS_TILE]] : -// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]] -// CHECK scf.yield %[[INSERT]] diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir index a7367a713ff4..1e094329db66 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -split-input-file %s | FileCheck %s +// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 6241603d6a67..f3ba7a1c5f52 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -29,9 +29,8 @@ using namespace mlir; namespace { -/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using -/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while -/// using a `filter` to avoid recursive application. +/// Construct a generic pattern applied to all TilingInterface ops that verify +/// `filter`. struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { TestTileUsingSCFForOpWithFilter(MLIRContext *context, scf::SCFTilingOptions options, @@ -53,7 +52,8 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { if (failed(filter.checkAndNotify(rewriter, op))) return failure(); - auto tilingResult = returningMatchAndRewrite(op, rewriter); + FailureOr tilingResult = + returningMatchAndRewrite(op, rewriter); if (failed(tilingResult)) { return failure(); } @@ -65,50 +65,6 @@ private: linalg::LinalgTransformationFilter filter; }; -/// Pattern for testing `TileConsumerAndFUseProducersUsingSCFForOp` pattern -/// (that tiles and fuses operations using the `TilingInterface` with `scf.for` -/// ops for iterating over the tiles) while using a `filter` to avoid recursive -/// application. -struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter - : public scf::TileConsumerAndFuseProducersUsingSCFForOp { - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( - MLIRContext *context, scf::SCFTilingOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options, - benefit), - filter(filter) {} - - /// Construct a generic pattern applied to `opName`. - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( - StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options, - benefit), - filter(filter) {} - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - - auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter); - if (failed(tileAndFuseResult)) { - return failure(); - } - filter.replaceLinalgTransformationFilter( - rewriter, tileAndFuseResult->tiledAndFusedOps.front()); - return success(); - } - -private: - linalg::LinalgTransformationFilter filter; -}; - -/// Test pass for testing the use of `TilingInterface`. struct TestTilingInterfacePass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) @@ -126,63 +82,29 @@ struct TestTilingInterfacePass return "Test tiling using TilingInterface"; } - Option testTiling{ - *this, "tile-using-scf-for", - llvm::cl::desc( - "Test tiling using TilingInterface with scf.for operations"), - llvm::cl::init(false)}; - - Option testTileConsumerAndFuseProducer{ - *this, "tile-consumer-and-fuse-producer-using-scf-for", - llvm::cl::desc("Test tile and fuse transformation using TilingInterface " - "with scf.for operations"), - llvm::cl::init(false)}; - void runOnOperation() override; - -private: - void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns); }; } // namespace -template -static void -addPatternForTiling(MLIRContext *context, ArrayRef tileSizes, - StringRef filterName, RewritePatternSet &patterns) { - scf::SCFTilingOptions tilingOptions; - tilingOptions.setTileSizes(tileSizes); - linalg::LinalgTransformationFilter filter( - StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); - patterns.add(context, tilingOptions, filter); -} - -void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, - RewritePatternSet &patterns) { - if (testTiling) { - // 1. Tiling M and N dims of `linalg.matmul` on tensors. - addPatternForTiling( - context, {10, 20}, "simple_gemm", patterns); - // 2. Tiling M, N and K of `linalg.matmul` on buffers. - addPatternForTiling( - context, {10, 20, 30}, "simple_gemm_memref", patterns); - // 3. Tiling 3D parallel generic op which implements a transpose - addPatternForTiling( - context, {10, 0, 20}, "parallel_generic_transpose", patterns); - // 4. Tiling 2D conv op. - addPatternForTiling( - context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns); - return; - } - if (testTileConsumerAndFuseProducer) { - // 1. Tile and fuse of gemm with bias-add operation. - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, {10, 20}, "fusion", patterns); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, {10}, "gemm_fusion", patterns); - return; - } +static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { + auto addPatternForTiling = [&](ArrayRef tileSizes, + StringRef filterName) { + scf::SCFTilingOptions tilingOptions; + tilingOptions.setTileSizes(tileSizes); + linalg::LinalgTransformationFilter filter( + StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); + patterns.add(context, tilingOptions, + filter); + }; + // 1. Tiling M and N dims of `linalg.matmul` on tensors. + addPatternForTiling({10, 20}, "simple_gemm"); + // 2. Tiling M, N and K of `linalg.matmul` on buffers. + addPatternForTiling({10, 20, 30}, "simple_gemm_memref"); + // 3. Tiling 3D parallel generic op which implements a transpose + addPatternForTiling({10, 0, 20}, "parallel_generic_transpose"); + // 4. Tiling 2D conv op. + addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv"); } void TestTilingInterfacePass::runOnOperation() { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index f0813db443a5..8ef778730691 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1881,7 +1881,6 @@ cc_library( ":SCFUtils", ":Support", ":TensorDialect", - ":TensorTransforms", ":TilingInterface", ":Transforms", "//llvm:Support", @@ -5029,7 +5028,6 @@ cc_library( ":SCFDialect", ":TensorDialect", ":TensorPassIncGen", - ":TilingInterface", ":Transforms", "//llvm:Support", ],