[mlir][TilingInterface] Enable tile and fuse using TilingInterface.

This patch implements tile and fuse transformation for ops that
implement the tiling interface. To do so,
- `TilingInterface` needs a new method that generates a tiled
  implementation of the operation based on the tile of the result
  needed.
- A pattern is added that replaces a `tensor.extract_slice` whose
  source is defined by an operation that implements the
  `TilingInterface` with a tiled implementation that produces the
  extracted slice in-place (using the method added to
  `TilingInterface`).
- A pattern is added that takes a sequence of operations that
  implement the `TilingInterface` (for now `LinalgOp`s), tiles the
  consumer, and greedily fuses its producers iteratively.

Differential Revision: https://reviews.llvm.org/D127809
This commit is contained in:
Mahesh Ravishankar 2022-06-13 23:24:31 +00:00
parent aaf1630ac3
commit ea75511319
11 changed files with 648 additions and 29 deletions

View File

@ -10,9 +10,12 @@
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include <deque>
namespace mlir {
class Operation;
class PatternRewriter;
@ -55,7 +58,7 @@ struct SCFTilingResult {
SmallVector<scf::ForOp> loops;
};
/// Pattern to tile an op that implementas the `TilingInterface` using
/// Pattern to tile an op that implements the `TilingInterface` using
/// `scf.for` for iterating over the tiles.
struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
/// Construct a generic pattern applied to all TilingInterface ops.
@ -81,6 +84,56 @@ private:
SCFTilingOptions options;
};
/// Pattern to tile and fuse a sequence of operations, by tiling the consumer
/// and fusing its producers. Note that this assumes that it is valid to
/// tile+fuse the producer into the innermost tiled loop. Its up to the caller
/// to ensure that the tile sizes provided make this fusion valid.
///
/// For example, for the following sequence
///
/// ```mlir
/// %0 = linalg.fill ...
/// %1 = linalg.matmul ... outs(%0 : ...) ...
/// ```
///
/// it is legal to fuse the fill with the matmul only if the matmul is tiled
/// along the parallel dimensions and not the reduction dimension, i.e. the tile
/// size for the reduction dimension should be 0.
struct SCFTileAndFuseResult {
SmallVector<Operation *> tiledAndFusedOps;
SmallVector<scf::ForOp> loops;
};
struct TileConsumerAndFuseProducersUsingSCFForOp
: public OpInterfaceRewritePattern<TilingInterface> {
/// Construct a generic pattern applied to all TilingInterface ops.
TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
SCFTilingOptions options,
PatternBenefit benefit = 1);
/// Construct a generic pattern applied to `opName`.
TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
MLIRContext *context,
SCFTilingOptions options,
PatternBenefit benefit = 1);
/// `matchAndRewrite` implementation that returns the significant transformed
/// pieces of IR.
FailureOr<SCFTileAndFuseResult>
returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
LogicalResult matchAndRewrite(TilingInterface op,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(op, rewriter);
}
private:
/// This pattern uses the tiling pattern. Instead of using inheritance, use
/// the patterns as private object that is instantiated at the same time as
/// this pattern.
TileUsingSCFForOp tilingPattern;
};
} // namespace scf
} // namespace mlir

View File

@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@ -20,6 +21,14 @@ namespace tensor {
void populateSplitPaddingPatterns(RewritePatternSet &patterns,
PatternBenefit baseBenefit = 1);
/// Pattern to swap an `tensor.extract_slice` with its producer when the
/// producer implements the `TilingInterface`. The pattern itself does not
/// provide a mechanism to control where the application happens. With use of
/// transform dialect that control is done within the transform dialect. Other
/// use cases can inherit from this pattern and add necessary controls.
FailureOr<Value> replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
} // namespace tensor
} // namespace mlir

View File

@ -120,7 +120,48 @@ def TilingInterface : OpInterface<"TilingInterface"> {
/*defaultImplementation=*/[{
return failure();
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
Generates the IR that computes the tile of a result of the
operation. The `offsets` and `sizes` describe the tile of
the output required. This is different from
`getTiledImplementation` which generates the tiled
implementation of the operation given a tile of the
iteration space. This method generates a tiled
implementation of the operation based on the tile of the
result required. This method enables fusion by using tile
and fuse. The method returns failure if the operation can't be
tiled to generate the result tile. In practical terms this
implies it cannot be tiled and fused with its consumers.
- `dest` are the Value into which the result of the tiled
operation is to be inserted into. The type of the `dest`
Values is same as the types returned by
`getDestinationOperands` method.
- `offsets` provides the offset of the tile within the
iteration space
- `sizes` provides the size of the tile.
- `tileDestOperands` specifies whether to also tile `dest` operands
or not. Avoiding tiling `dest` operands can be useful for
composition with various looping container ops.
}],
/*retType=*/"FailureOr<Value>",
/*methodName=*/"generateResultTileValue",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$resultNumber,
"ValueRange":$dest,
"ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes,
"bool":$tileDestOperands),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
}]
>
];
];
}
#endif // MLIR_TILINGINTERFACE

View File

@ -30,7 +30,6 @@ template <typename LinalgOpTy>
struct LinalgOpTilingInterface
: public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
LinalgOpTy> {
/// Return the destination operands.
SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
return llvm::cast<LinalgOp>(op).getOutputOperands();
@ -47,6 +46,8 @@ struct LinalgOpTilingInterface
/// Return the iteration domain range.
SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
@ -129,16 +130,65 @@ struct LinalgOpTilingInterface
resultSizes = sliceOp.getMixedSizes();
return success();
}
FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
unsigned resultNumber,
ValueRange dest,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
bool tileDestOperands) const {
auto linalgOp = cast<LinalgOp>(op);
// Check that the indexing map used for the output is a projected
// permutation. This could be relaxed with a more general approach that can
// map the offsets and sizes from the result to iteration space tiles
// (filling in full extent for dimensions not used to access the result).
AffineMap indexingMap =
linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
if (!indexingMap.isProjectedPermutation()) {
return op->emitOpError(
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
auto numLoops = linalgOp.getNumLoops();
auto tilingInterfaceOp = cast<TilingInterface>(op);
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
iterationTileSizes(numLoops);
if (!indexingMap.isPermutation()) {
SmallVector<Range> iterationDomain =
tilingInterfaceOp.getIterationDomain(b);
for (auto range : llvm::enumerate(iterationDomain)) {
iterationTileOffsets[range.index()] = range.value().offset;
iterationTileSizes[range.index()] = range.value().size;
}
}
for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) {
unsigned dimPosition =
resultExpr.value().cast<AffineDimExpr>().getPosition();
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
}
SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
if (tiledOp.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
return tiledOp[0]->getResult(resultNumber);
}
};
} // namespace
template <typename OpType> static void registerOne(MLIRContext *ctx) {
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
}
/// Variadic helper function.
template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
}

View File

@ -42,6 +42,10 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
return *this;
}
//===----------------------------------------------------------------------===//
// TileUsingSCFForOp pattern implementation.
//===----------------------------------------------------------------------===//
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
@ -247,3 +251,155 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
rewriter.replaceOp(op, tilingResult.loops.front().getResults());
return tilingResult;
}
//===----------------------------------------------------------------------===//
// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
//===----------------------------------------------------------------------===//
scf::TileConsumerAndFuseProducersUsingSCFForOp::
TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
scf::SCFTilingOptions options,
PatternBenefit benefit)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
tilingPattern(context, std::move(options)) {}
scf::TileConsumerAndFuseProducersUsingSCFForOp::
TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
MLIRContext *context,
scf::SCFTilingOptions options,
PatternBenefit benefit)
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
tilingPattern(context, std::move(options)) {}
/// Return the `Value` that is defined by an operation that implements
/// the `TilingInterface`. Looks through `iter_args` of scf.for nest
/// if required.
static Optional<OpResult> getFusableProducer(Value v) {
while (auto blockArg = v.dyn_cast<BlockArgument>()) {
auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
if (!loopOp)
return llvm::None;
v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
}
if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
return llvm::None;
return v.cast<OpResult>();
}
FailureOr<scf::SCFTileAndFuseResult>
scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
TilingInterface op, PatternRewriter &rewriter) const {
// This transformation is only valid for ops that return values (i.e. not
// valid to use with operations that have memref operands).
if (!op->getNumResults()) {
return rewriter.notifyMatchFailure(
op, "invalid pattern for op with no results");
}
// 1. First tile the consumer.
SCFTileAndFuseResult tileAndFuseResult;
{
FailureOr<SCFTilingResult> tilingResult =
tilingPattern.returningMatchAndRewrite(op, rewriter);
if (failed(tilingResult)) {
return failure();
}
tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
tileAndFuseResult.loops = std::move(tilingResult->loops);
}
// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
// `tensor.extract_slice` operations with source being the operands of the
// untiled operation. Create a worklist of these `tensor.extract_slice`
// operations. If the producers of the source of the `tensor.extract_slice`
// can be tiled such that the tiled value is generated in-place, that
// effectively tiles + fuses the operations.
auto addCandidateSlices = [](Operation *fusedOp,
std::deque<tensor::ExtractSliceOp> &candidates) {
for (Value operand : fusedOp->getOperands())
if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
candidates.push_back(sliceOp);
};
std::deque<tensor::ExtractSliceOp> candidates;
addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
OpBuilder::InsertionGuard g(rewriter);
while (!candidates.empty()) {
// 2a. Traverse the slices in BFS fashion.
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
candidates.pop_front();
// 2b. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
Optional<OpResult> fusableProducer =
getFusableProducer(candidateSliceOp.source());
if (!fusableProducer)
continue;
// 2c. Generate the tiled implementation of the producer of the source
rewriter.setInsertionPoint(candidateSliceOp);
FailureOr<Value> fusedProducerValue =
tensor::replaceExtractSliceWithTiledProducer(
rewriter, candidateSliceOp, fusableProducer.getValue());
if (failed(fusedProducerValue))
continue;
rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue());
// 2d. The operands of the fused producer might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
Operation *fusedProducer = fusedProducerValue->getDefiningOp();
tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
addCandidateSlices(fusedProducer, candidates);
// 2e. If the operation being fused creates a value that is used as `outs`
// in the tiled operation, the result of the unfused operation will be
// used in the `iter_args` of the tiled loop generated. When the
// operation is fused, this use in `iter_args` needs to be modified to
// use the destination of the fused operation. For example, starting
// with
//
// ```mlir
// %0 = linalg.init_tensor ...
// %1 = linalg.fill ... outs(%0:...)...
// %2 = linalg.matmul ... outs(%1:...)....
// ```
//
// First the `linalg.matmul` gets tiled
//
// ```mlir
// %0 = linalg.init_tensor
// %1 = linalg.fill
// %2 = scf.for .... iter_args(%arg0 = %1)...
// ...
// ... = linalg.matmul ...
//
// ```
//
// When the `linalg.fill` gets fused, the `iter_args` needs to be
// modified
//
// ```mlir
// %0 = linalg.init_tensor
// %1 = scf.for ... iter_args(%arg0 = %0)...
// ...
// %2 = linalg.fill ...
// %3 = linalg.matmul ... outs(%2: ...)...
// ```
TilingInterface unfusedProducerOp =
cast<TilingInterface>(fusableProducer->getOwner());
scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
SmallVector<Value> unfusedProducerOpDestValues =
unfusedProducerOp.getDestinationOperands(rewriter);
for (OpOperand &uses : unfusedProducerOp->getUses()) {
if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
unsigned operandNumber = uses.getOperandNumber();
outerMostTiledLoop->setOperand(
operandNumber, unfusedProducerOpDestValues[resultNumber]);
}
}
}
return tileAndFuseResult;
}

View File

@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
SplitPadding.cpp
SwapExtractSliceWithProducer.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
@ -18,5 +19,6 @@ add_mlir_dialect_library(MLIRTensorTransforms
MLIRPass
MLIRSCFDialect
MLIRTensorDialect
MLIRTilingInterface
MLIRTransforms
)

View File

@ -0,0 +1,43 @@
//===- SwapExtractSliceWithProducer.cpp - Swapping `tensor.extract_slice` ---=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Swap a `tensor.extract_slice` with the producer of the source if the producer
// implements the `TilingInterface`. When used in conjunction with tiling this
// effectively tiles + fuses the producer with its consumer.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
using namespace mlir;
FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
if (!producerOp)
return failure();
// `TilingInterface` currently only supports strides being 1.
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 1);
}))
return failure();
FailureOr<Value> tiledResult = producerOp.generateResultTileValue(
builder, producer.getResultNumber(),
producerOp.getDestinationOperands(builder), sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes(), true);
if (failed(tiledResult))
return failure();
return tiledResult.getValue();
}

View File

@ -0,0 +1,185 @@
// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s
func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
return %gemm : tensor<?x?xf32>
}
// CHECK: func.func @gemm_fill_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK: %[[INIT:.+]] = linalg.init_tensor
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[FILL_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
// CHECK scf.yield %[[INSERT]]
// -----
func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm = linalg.matmul
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
%generic = linalg.generic {
__internal_linalg_transform__ = "fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%add = arith.addf %b0, %b1 : f32
linalg.yield %add : f32
} -> tensor<?x?xf32>
return %generic : tensor<?x?xf32>
}
// CHECK: func.func @gemm_generic_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
// CHECK: %[[INIT:.+]] = linalg.init_tensor
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[FILL_TILE]] :
// CHECK-DAG: %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
// CHECK-SAME: outs(%[[OUTS_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
// CHECK scf.yield %[[INSERT]]
// -----
func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
%d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
%init0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm0 = linalg.matmul
ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
%init1 = linalg.init_tensor [%d0, %d2] : tensor<?x?xf32>
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"}
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %gemm1 : tensor<?x?xf32>
}
// CHECK: func.func @gemm_gemm_fusion(
// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]]
// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
// CHECK: %[[INIT1:.+]] = linalg.init_tensor [%[[D0]], %[[D2]]]
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]], 0]
// CHECK: %[[FILL0_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT0_TILE]] :
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[FILL0_TILE]] :
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT1_TILE]] :
// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
// CHECK-SAME: outs(%[[FILL1_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
// CHECK scf.yield %[[INSERT]]
// -----
func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%init0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
%gemm = linalg.matmul
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
%init1 = linalg.init_tensor [%d1, %d0] : tensor<?x?xf32>
%transpose = linalg.generic {
__internal_linalg_transform__ = "fusion",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
iterator_types = ["parallel", "parallel"]}
ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
linalg.yield %b0 : f32
} -> tensor<?x?xf32>
return %transpose : tensor<?x?xf32>
}
// CHECK: func.func @gemm_transpose_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT1]])
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT0_TILE]] :
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[FILL_TILE]] :
// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GEMM_TILE]] :
// CHECK-SAME: outs(%[[OUTS_TILE]] :
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
// CHECK scf.yield %[[INSERT]]

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s
// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -split-input-file %s | FileCheck %s
func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {

View File

@ -29,8 +29,9 @@ using namespace mlir;
namespace {
/// Construct a generic pattern applied to all TilingInterface ops that verify
/// `filter`.
/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
/// using a `filter` to avoid recursive application.
struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
TestTileUsingSCFForOpWithFilter(MLIRContext *context,
scf::SCFTilingOptions options,
@ -52,8 +53,7 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
FailureOr<scf::SCFTilingResult> tilingResult =
returningMatchAndRewrite(op, rewriter);
auto tilingResult = returningMatchAndRewrite(op, rewriter);
if (failed(tilingResult)) {
return failure();
}
@ -65,6 +65,50 @@ private:
linalg::LinalgTransformationFilter filter;
};
/// Pattern for testing `TileConsumerAndFUseProducersUsingSCFForOp` pattern
/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
/// ops for iterating over the tiles) while using a `filter` to avoid recursive
/// application.
struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter
: public scf::TileConsumerAndFuseProducersUsingSCFForOp {
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
MLIRContext *context, scf::SCFTilingOptions options,
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
benefit),
filter(filter) {}
/// Construct a generic pattern applied to `opName`.
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
benefit),
filter(filter) {}
LogicalResult matchAndRewrite(TilingInterface op,
PatternRewriter &rewriter) const override {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter);
if (failed(tileAndFuseResult)) {
return failure();
}
filter.replaceLinalgTransformationFilter(
rewriter, tileAndFuseResult->tiledAndFusedOps.front());
return success();
}
private:
linalg::LinalgTransformationFilter filter;
};
/// Test pass for testing the use of `TilingInterface`.
struct TestTilingInterfacePass
: public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
@ -82,29 +126,63 @@ struct TestTilingInterfacePass
return "Test tiling using TilingInterface";
}
Option<bool> testTiling{
*this, "tile-using-scf-for",
llvm::cl::desc(
"Test tiling using TilingInterface with scf.for operations"),
llvm::cl::init(false)};
Option<bool> testTileConsumerAndFuseProducer{
*this, "tile-consumer-and-fuse-producer-using-scf-for",
llvm::cl::desc("Test tile and fuse transformation using TilingInterface "
"with scf.for operations"),
llvm::cl::init(false)};
void runOnOperation() override;
private:
void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns);
};
} // namespace
static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
StringRef filterName) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes);
linalg::LinalgTransformationFilter filter(
StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
filter);
};
// 1. Tiling M and N dims of `linalg.matmul` on tensors.
addPatternForTiling({10, 20}, "simple_gemm");
// 2. Tiling M, N and K of `linalg.matmul` on buffers.
addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
// 3. Tiling 3D parallel generic op which implements a transpose
addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
// 4. Tiling 2D conv op.
addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
template <class Pattern>
static void
addPatternForTiling(MLIRContext *context, ArrayRef<int64_t> tileSizes,
StringRef filterName, RewritePatternSet &patterns) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes);
linalg::LinalgTransformationFilter filter(
StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
patterns.add<Pattern>(context, tilingOptions, filter);
}
void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
RewritePatternSet &patterns) {
if (testTiling) {
// 1. Tiling M and N dims of `linalg.matmul` on tensors.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
context, {10, 20}, "simple_gemm", patterns);
// 2. Tiling M, N and K of `linalg.matmul` on buffers.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
context, {10, 20, 30}, "simple_gemm_memref", patterns);
// 3. Tiling 3D parallel generic op which implements a transpose
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
context, {10, 0, 20}, "parallel_generic_transpose", patterns);
// 4. Tiling 2D conv op.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns);
return;
}
if (testTileConsumerAndFuseProducer) {
// 1. Tile and fuse of gemm with bias-add operation.
addPatternForTiling<
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
context, {10, 20}, "fusion", patterns);
addPatternForTiling<
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
context, {10}, "gemm_fusion", patterns);
return;
}
}
void TestTilingInterfacePass::runOnOperation() {

View File

@ -1881,6 +1881,7 @@ cc_library(
":SCFUtils",
":Support",
":TensorDialect",
":TensorTransforms",
":TilingInterface",
":Transforms",
"//llvm:Support",
@ -5028,6 +5029,7 @@ cc_library(
":SCFDialect",
":TensorDialect",
":TensorPassIncGen",
":TilingInterface",
":Transforms",
"//llvm:Support",
],