forked from OSchip/llvm-project
[mlir][TilingInterface] Enable tile and fuse using TilingInterface.
This patch implements tile and fuse transformation for ops that implement the tiling interface. To do so, - `TilingInterface` needs a new method that generates a tiled implementation of the operation based on the tile of the result needed. - A pattern is added that replaces a `tensor.extract_slice` whose source is defined by an operation that implements the `TilingInterface` with a tiled implementation that produces the extracted slice in-place (using the method added to `TilingInterface`). - A pattern is added that takes a sequence of operations that implement the `TilingInterface` (for now `LinalgOp`s), tiles the consumer, and greedily fuses its producers iteratively. Differential Revision: https://reviews.llvm.org/D127809
This commit is contained in:
parent
d4ee43153d
commit
2f637fe730
|
@ -10,9 +10,12 @@
|
|||
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
|
||||
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/TilingInterface.h"
|
||||
|
||||
#include <deque>
|
||||
|
||||
namespace mlir {
|
||||
class Operation;
|
||||
class PatternRewriter;
|
||||
|
@ -55,7 +58,7 @@ struct SCFTilingResult {
|
|||
SmallVector<scf::ForOp> loops;
|
||||
};
|
||||
|
||||
/// Pattern to tile an op that implementas the `TilingInterface` using
|
||||
/// Pattern to tile an op that implements the `TilingInterface` using
|
||||
/// `scf.for` for iterating over the tiles.
|
||||
struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
|
||||
/// Construct a generic pattern applied to all TilingInterface ops.
|
||||
|
@ -81,6 +84,56 @@ private:
|
|||
SCFTilingOptions options;
|
||||
};
|
||||
|
||||
/// Pattern to tile and fuse a sequence of operations, by tiling the consumer
|
||||
/// and fusing its producers. Note that this assumes that it is valid to
|
||||
/// tile+fuse the producer into the innermost tiled loop. Its up to the caller
|
||||
/// to ensure that the tile sizes provided make this fusion valid.
|
||||
///
|
||||
/// For example, for the following sequence
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = linalg.fill ...
|
||||
/// %1 = linalg.matmul ... outs(%0 : ...) ...
|
||||
/// ```
|
||||
///
|
||||
/// it is legal to fuse the fill with the matmul only if the matmul is tiled
|
||||
/// along the parallel dimensions and not the reduction dimension, i.e. the tile
|
||||
/// size for the reduction dimension should be 0.
|
||||
struct SCFTileAndFuseResult {
|
||||
SmallVector<Operation *> tiledAndFusedOps;
|
||||
SmallVector<scf::ForOp> loops;
|
||||
};
|
||||
struct TileConsumerAndFuseProducersUsingSCFForOp
|
||||
: public OpInterfaceRewritePattern<TilingInterface> {
|
||||
|
||||
/// Construct a generic pattern applied to all TilingInterface ops.
|
||||
TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
|
||||
SCFTilingOptions options,
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
/// Construct a generic pattern applied to `opName`.
|
||||
TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
|
||||
MLIRContext *context,
|
||||
SCFTilingOptions options,
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
/// `matchAndRewrite` implementation that returns the significant transformed
|
||||
/// pieces of IR.
|
||||
FailureOr<SCFTileAndFuseResult>
|
||||
returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
|
||||
|
||||
LogicalResult matchAndRewrite(TilingInterface op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
return returningMatchAndRewrite(op, rewriter);
|
||||
}
|
||||
|
||||
private:
|
||||
/// This pattern uses the tiling pattern. Instead of using inheritance, use
|
||||
/// the patterns as private object that is instantiated at the same time as
|
||||
/// this pattern.
|
||||
TileUsingSCFForOp tilingPattern;
|
||||
};
|
||||
|
||||
} // namespace scf
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
|
||||
#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -20,6 +21,14 @@ namespace tensor {
|
|||
void populateSplitPaddingPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit baseBenefit = 1);
|
||||
|
||||
/// Pattern to swap an `tensor.extract_slice` with its producer when the
|
||||
/// producer implements the `TilingInterface`. The pattern itself does not
|
||||
/// provide a mechanism to control where the application happens. With use of
|
||||
/// transform dialect that control is done within the transform dialect. Other
|
||||
/// use cases can inherit from this pattern and add necessary controls.
|
||||
FailureOr<Value> replaceExtractSliceWithTiledProducer(
|
||||
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
|
||||
|
||||
} // namespace tensor
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -120,7 +120,48 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
|||
/*defaultImplementation=*/[{
|
||||
return failure();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Method to generate the code that produces a tile of the result.
|
||||
|
||||
Generates the IR that computes the tile of a result of the
|
||||
operation. The `offsets` and `sizes` describe the tile of
|
||||
the output required. This is different from
|
||||
`getTiledImplementation` which generates the tiled
|
||||
implementation of the operation given a tile of the
|
||||
iteration space. This method generates a tiled
|
||||
implementation of the operation based on the tile of the
|
||||
result required. This method enables fusion by using tile
|
||||
and fuse. The method returns failure if the operation can't be
|
||||
tiled to generate the result tile. In practical terms this
|
||||
implies it cannot be tiled and fused with its consumers.
|
||||
|
||||
- `dest` are the Value into which the result of the tiled
|
||||
operation is to be inserted into. The type of the `dest`
|
||||
Values is same as the types returned by
|
||||
`getDestinationOperands` method.
|
||||
- `offsets` provides the offset of the tile within the
|
||||
iteration space
|
||||
- `sizes` provides the size of the tile.
|
||||
- `tileDestOperands` specifies whether to also tile `dest` operands
|
||||
or not. Avoiding tiling `dest` operands can be useful for
|
||||
composition with various looping container ops.
|
||||
}],
|
||||
/*retType=*/"FailureOr<Value>",
|
||||
/*methodName=*/"generateResultTileValue",
|
||||
/*args=*/(ins
|
||||
"OpBuilder &":$b,
|
||||
"unsigned":$resultNumber,
|
||||
"ValueRange":$dest,
|
||||
"ArrayRef<OpFoldResult>":$offsets,
|
||||
"ArrayRef<OpFoldResult>":$sizes,
|
||||
"bool":$tileDestOperands),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return failure();
|
||||
}]
|
||||
>
|
||||
];
|
||||
];
|
||||
}
|
||||
#endif // MLIR_TILINGINTERFACE
|
||||
|
|
|
@ -30,7 +30,6 @@ template <typename LinalgOpTy>
|
|||
struct LinalgOpTilingInterface
|
||||
: public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
|
||||
LinalgOpTy> {
|
||||
|
||||
/// Return the destination operands.
|
||||
SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
|
||||
return llvm::cast<LinalgOp>(op).getOutputOperands();
|
||||
|
@ -47,6 +46,8 @@ struct LinalgOpTilingInterface
|
|||
|
||||
/// Return the iteration domain range.
|
||||
SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
b.setInsertionPoint(op);
|
||||
Location loc = op->getLoc();
|
||||
LinalgOp linalgOp = cast<LinalgOp>(op);
|
||||
auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
|
||||
|
@ -129,16 +130,65 @@ struct LinalgOpTilingInterface
|
|||
resultSizes = sliceOp.getMixedSizes();
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
|
||||
unsigned resultNumber,
|
||||
ValueRange dest,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes,
|
||||
bool tileDestOperands) const {
|
||||
auto linalgOp = cast<LinalgOp>(op);
|
||||
|
||||
// Check that the indexing map used for the output is a projected
|
||||
// permutation. This could be relaxed with a more general approach that can
|
||||
// map the offsets and sizes from the result to iteration space tiles
|
||||
// (filling in full extent for dimensions not used to access the result).
|
||||
AffineMap indexingMap =
|
||||
linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
|
||||
if (!indexingMap.isProjectedPermutation()) {
|
||||
return op->emitOpError(
|
||||
"unhandled tiled implementation generation when result is not "
|
||||
"accessed using a permuted projection");
|
||||
}
|
||||
|
||||
auto numLoops = linalgOp.getNumLoops();
|
||||
auto tilingInterfaceOp = cast<TilingInterface>(op);
|
||||
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
|
||||
iterationTileSizes(numLoops);
|
||||
if (!indexingMap.isPermutation()) {
|
||||
SmallVector<Range> iterationDomain =
|
||||
tilingInterfaceOp.getIterationDomain(b);
|
||||
for (auto range : llvm::enumerate(iterationDomain)) {
|
||||
iterationTileOffsets[range.index()] = range.value().offset;
|
||||
iterationTileSizes[range.index()] = range.value().size;
|
||||
}
|
||||
}
|
||||
for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) {
|
||||
unsigned dimPosition =
|
||||
resultExpr.value().cast<AffineDimExpr>().getPosition();
|
||||
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
|
||||
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
|
||||
}
|
||||
|
||||
SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
|
||||
b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
|
||||
if (tiledOp.size() != 1)
|
||||
return op->emitOpError("failed to generate tiled implementation");
|
||||
|
||||
return tiledOp[0]->getResult(resultNumber);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename OpType> static void registerOne(MLIRContext *ctx) {
|
||||
template <typename OpType>
|
||||
static void registerOne(MLIRContext *ctx) {
|
||||
OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
|
||||
}
|
||||
|
||||
/// Variadic helper function.
|
||||
template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
|
||||
template <typename... OpTypes>
|
||||
static void registerAll(MLIRContext *ctx) {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
|
|||
MLIRSCFUtils
|
||||
MLIRSupport
|
||||
MLIRTensorDialect
|
||||
MLIRTensorTransforms
|
||||
MLIRTilingInterface
|
||||
MLIRTransforms
|
||||
MLIRTransformUtils
|
||||
|
|
|
@ -42,6 +42,10 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TileUsingSCFForOp pattern implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Generate an empty loop nest that represents the tiled loop nest shell.
|
||||
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
|
||||
/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
|
||||
|
@ -247,3 +251,155 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
|
|||
rewriter.replaceOp(op, tilingResult.loops.front().getResults());
|
||||
return tilingResult;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
scf::TileConsumerAndFuseProducersUsingSCFForOp::
|
||||
TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
|
||||
scf::SCFTilingOptions options,
|
||||
PatternBenefit benefit)
|
||||
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
|
||||
tilingPattern(context, std::move(options)) {}
|
||||
|
||||
scf::TileConsumerAndFuseProducersUsingSCFForOp::
|
||||
TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
|
||||
MLIRContext *context,
|
||||
scf::SCFTilingOptions options,
|
||||
PatternBenefit benefit)
|
||||
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
|
||||
tilingPattern(context, std::move(options)) {}
|
||||
|
||||
/// Return the `Value` that is defined by an operation that implements
|
||||
/// the `TilingInterface`. Looks through `iter_args` of scf.for nest
|
||||
/// if required.
|
||||
static Optional<OpResult> getFusableProducer(Value v) {
|
||||
while (auto blockArg = v.dyn_cast<BlockArgument>()) {
|
||||
auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
|
||||
if (!loopOp)
|
||||
return llvm::None;
|
||||
v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
|
||||
}
|
||||
if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
|
||||
return llvm::None;
|
||||
return v.cast<OpResult>();
|
||||
}
|
||||
|
||||
FailureOr<scf::SCFTileAndFuseResult>
|
||||
scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
|
||||
TilingInterface op, PatternRewriter &rewriter) const {
|
||||
// This transformation is only valid for ops that return values (i.e. not
|
||||
// valid to use with operations that have memref operands).
|
||||
if (!op->getNumResults()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid pattern for op with no results");
|
||||
}
|
||||
|
||||
// 1. First tile the consumer.
|
||||
SCFTileAndFuseResult tileAndFuseResult;
|
||||
{
|
||||
FailureOr<SCFTilingResult> tilingResult =
|
||||
tilingPattern.returningMatchAndRewrite(op, rewriter);
|
||||
if (failed(tilingResult)) {
|
||||
return failure();
|
||||
}
|
||||
tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
|
||||
tileAndFuseResult.loops = std::move(tilingResult->loops);
|
||||
}
|
||||
|
||||
// 2. Typically, the operands of the tiled operation are slices of the
|
||||
// operands of the untiled operation. These are expressed in IR using
|
||||
// `tensor.extract_slice` operations with source being the operands of the
|
||||
// untiled operation. Create a worklist of these `tensor.extract_slice`
|
||||
// operations. If the producers of the source of the `tensor.extract_slice`
|
||||
// can be tiled such that the tiled value is generated in-place, that
|
||||
// effectively tiles + fuses the operations.
|
||||
auto addCandidateSlices = [](Operation *fusedOp,
|
||||
std::deque<tensor::ExtractSliceOp> &candidates) {
|
||||
for (Value operand : fusedOp->getOperands())
|
||||
if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
|
||||
candidates.push_back(sliceOp);
|
||||
};
|
||||
|
||||
std::deque<tensor::ExtractSliceOp> candidates;
|
||||
addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
while (!candidates.empty()) {
|
||||
// 2a. Traverse the slices in BFS fashion.
|
||||
tensor::ExtractSliceOp candidateSliceOp = candidates.front();
|
||||
candidates.pop_front();
|
||||
|
||||
// 2b. Get the producer of the source (potentially walking through
|
||||
// `iter_args` of nested `scf.for`)
|
||||
Optional<OpResult> fusableProducer =
|
||||
getFusableProducer(candidateSliceOp.source());
|
||||
if (!fusableProducer)
|
||||
continue;
|
||||
|
||||
// 2c. Generate the tiled implementation of the producer of the source
|
||||
rewriter.setInsertionPoint(candidateSliceOp);
|
||||
FailureOr<Value> fusedProducerValue =
|
||||
tensor::replaceExtractSliceWithTiledProducer(
|
||||
rewriter, candidateSliceOp, fusableProducer.getValue());
|
||||
if (failed(fusedProducerValue))
|
||||
continue;
|
||||
rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue());
|
||||
|
||||
// 2d. The operands of the fused producer might themselved be slices of
|
||||
// values produced by operations that implement the `TilingInterface`.
|
||||
// Add these operations to the worklist.
|
||||
Operation *fusedProducer = fusedProducerValue->getDefiningOp();
|
||||
tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
|
||||
addCandidateSlices(fusedProducer, candidates);
|
||||
|
||||
// 2e. If the operation being fused creates a value that is used as `outs`
|
||||
// in the tiled operation, the result of the unfused operation will be
|
||||
// used in the `iter_args` of the tiled loop generated. When the
|
||||
// operation is fused, this use in `iter_args` needs to be modified to
|
||||
// use the destination of the fused operation. For example, starting
|
||||
// with
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = linalg.init_tensor ...
|
||||
// %1 = linalg.fill ... outs(%0:...)...
|
||||
// %2 = linalg.matmul ... outs(%1:...)....
|
||||
// ```
|
||||
//
|
||||
// First the `linalg.matmul` gets tiled
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = linalg.init_tensor
|
||||
// %1 = linalg.fill
|
||||
// %2 = scf.for .... iter_args(%arg0 = %1)...
|
||||
// ...
|
||||
// ... = linalg.matmul ...
|
||||
//
|
||||
// ```
|
||||
//
|
||||
// When the `linalg.fill` gets fused, the `iter_args` needs to be
|
||||
// modified
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = linalg.init_tensor
|
||||
// %1 = scf.for ... iter_args(%arg0 = %0)...
|
||||
// ...
|
||||
// %2 = linalg.fill ...
|
||||
// %3 = linalg.matmul ... outs(%2: ...)...
|
||||
// ```
|
||||
TilingInterface unfusedProducerOp =
|
||||
cast<TilingInterface>(fusableProducer->getOwner());
|
||||
scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
|
||||
SmallVector<Value> unfusedProducerOpDestValues =
|
||||
unfusedProducerOp.getDestinationOperands(rewriter);
|
||||
for (OpOperand &uses : unfusedProducerOp->getUses()) {
|
||||
if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
|
||||
unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
|
||||
unsigned operandNumber = uses.getOperandNumber();
|
||||
outerMostTiledLoop->setOperand(
|
||||
operandNumber, unfusedProducerOpDestValues[resultNumber]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return tileAndFuseResult;
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
|
|||
BufferizableOpInterfaceImpl.cpp
|
||||
Bufferize.cpp
|
||||
SplitPadding.cpp
|
||||
SwapExtractSliceWithProducer.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
|
||||
|
@ -18,5 +19,6 @@ add_mlir_dialect_library(MLIRTensorTransforms
|
|||
MLIRPass
|
||||
MLIRSCFDialect
|
||||
MLIRTensorDialect
|
||||
MLIRTilingInterface
|
||||
MLIRTransforms
|
||||
)
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
//===- SwapExtractSliceWithProducer.cpp - Swapping `tensor.extract_slice` ---=//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Swap a `tensor.extract_slice` with the producer of the source if the producer
|
||||
// implements the `TilingInterface`. When used in conjunction with tiling this
|
||||
// effectively tiles + fuses the producer with its consumer.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/Interfaces/TilingInterface.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
|
||||
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
|
||||
auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
|
||||
if (!producerOp)
|
||||
return failure();
|
||||
|
||||
// `TilingInterface` currently only supports strides being 1.
|
||||
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
|
||||
return !isConstantIntValue(ofr, 1);
|
||||
}))
|
||||
return failure();
|
||||
|
||||
FailureOr<Value> tiledResult = producerOp.generateResultTileValue(
|
||||
builder, producer.getResultNumber(),
|
||||
producerOp.getDestinationOperands(builder), sliceOp.getMixedOffsets(),
|
||||
sliceOp.getMixedSizes(), true);
|
||||
if (failed(tiledResult))
|
||||
return failure();
|
||||
|
||||
return tiledResult.getValue();
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s
|
||||
|
||||
func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
|
||||
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
|
||||
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"}
|
||||
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %gemm : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK: func.func @gemm_fill_fusion(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor
|
||||
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
|
||||
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
|
||||
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
|
||||
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
|
||||
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
|
||||
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
|
||||
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
|
||||
// CHECK-SAME: outs(%[[INIT_TILE]] :
|
||||
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
|
||||
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
|
||||
// CHECK-SAME: outs(%[[FILL_TILE]] :
|
||||
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK scf.yield %[[INSERT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
|
||||
%arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
|
||||
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
|
||||
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%gemm = linalg.matmul
|
||||
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%generic = linalg.generic {
|
||||
__internal_linalg_transform__ = "fusion",
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
|
||||
%add = arith.addf %b0, %b1 : f32
|
||||
linalg.yield %add : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %generic : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK: func.func @gemm_generic_fusion(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor
|
||||
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
|
||||
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
|
||||
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
|
||||
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
|
||||
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
|
||||
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
|
||||
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
|
||||
// CHECK-SAME: outs(%[[INIT_TILE]] :
|
||||
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
|
||||
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
|
||||
// CHECK-SAME: outs(%[[FILL_TILE]] :
|
||||
// CHECK-DAG: %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
|
||||
// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
|
||||
// CHECK-SAME: outs(%[[OUTS_TILE]] :
|
||||
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK scf.yield %[[INSERT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
|
||||
%d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
|
||||
%init0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
|
||||
%fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%gemm0 = linalg.matmul
|
||||
ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
|
||||
%init1 = linalg.init_tensor [%d0, %d2] : tensor<?x?xf32>
|
||||
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"}
|
||||
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %gemm1 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK: func.func @gemm_gemm_fusion(
|
||||
// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
|
||||
// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]]
|
||||
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]]
|
||||
// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
|
||||
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
|
||||
// CHECK: %[[INIT1:.+]] = linalg.init_tensor [%[[D0]], %[[D2]]]
|
||||
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
|
||||
// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]])
|
||||
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
|
||||
// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
|
||||
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]], 0]
|
||||
// CHECK: %[[FILL0_TILE:.+]] = linalg.fill
|
||||
// CHECK-SAME: outs(%[[INIT0_TILE]] :
|
||||
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
|
||||
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
|
||||
// CHECK-SAME: outs(%[[FILL0_TILE]] :
|
||||
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
|
||||
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
|
||||
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
|
||||
// CHECK-SAME: outs(%[[INIT1_TILE]] :
|
||||
// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul
|
||||
// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
|
||||
// CHECK-SAME: outs(%[[FILL1_TILE]] :
|
||||
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
|
||||
// CHECK scf.yield %[[INSERT]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
|
||||
%init0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
|
||||
%fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%gemm = linalg.matmul
|
||||
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%init1 = linalg.init_tensor [%d1, %d0] : tensor<?x?xf32>
|
||||
%transpose = linalg.generic {
|
||||
__internal_linalg_transform__ = "fusion",
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
|
||||
^bb0(%b0 : f32, %b1 : f32):
|
||||
linalg.yield %b0 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %transpose : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK: func.func @gemm_transpose_fusion(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
|
||||
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
|
||||
// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
|
||||
// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
|
||||
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
|
||||
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT1]])
|
||||
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
|
||||
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
|
||||
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
|
||||
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
|
||||
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
|
||||
// CHECK-SAME: outs(%[[INIT0_TILE]] :
|
||||
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
|
||||
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
|
||||
// CHECK-SAME: outs(%[[FILL_TILE]] :
|
||||
// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
|
||||
// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[GEMM_TILE]] :
|
||||
// CHECK-SAME: outs(%[[OUTS_TILE]] :
|
||||
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
|
||||
// CHECK scf.yield %[[INSERT]]
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s
|
||||
// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -split-input-file %s | FileCheck %s
|
||||
|
||||
func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
|
||||
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
|
|
|
@ -29,8 +29,9 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
|
||||
/// Construct a generic pattern applied to all TilingInterface ops that verify
|
||||
/// `filter`.
|
||||
/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
|
||||
/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
|
||||
/// using a `filter` to avoid recursive application.
|
||||
struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
|
||||
TestTileUsingSCFForOpWithFilter(MLIRContext *context,
|
||||
scf::SCFTilingOptions options,
|
||||
|
@ -52,8 +53,7 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
|
|||
if (failed(filter.checkAndNotify(rewriter, op)))
|
||||
return failure();
|
||||
|
||||
FailureOr<scf::SCFTilingResult> tilingResult =
|
||||
returningMatchAndRewrite(op, rewriter);
|
||||
auto tilingResult = returningMatchAndRewrite(op, rewriter);
|
||||
if (failed(tilingResult)) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -65,6 +65,50 @@ private:
|
|||
linalg::LinalgTransformationFilter filter;
|
||||
};
|
||||
|
||||
/// Pattern for testing `TileConsumerAndFUseProducersUsingSCFForOp` pattern
|
||||
/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
|
||||
/// ops for iterating over the tiles) while using a `filter` to avoid recursive
|
||||
/// application.
|
||||
struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter
|
||||
: public scf::TileConsumerAndFuseProducersUsingSCFForOp {
|
||||
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
|
||||
MLIRContext *context, scf::SCFTilingOptions options,
|
||||
linalg::LinalgTransformationFilter filter =
|
||||
linalg::LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
|
||||
benefit),
|
||||
filter(filter) {}
|
||||
|
||||
/// Construct a generic pattern applied to `opName`.
|
||||
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
|
||||
StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
|
||||
linalg::LinalgTransformationFilter filter =
|
||||
linalg::LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
|
||||
benefit),
|
||||
filter(filter) {}
|
||||
|
||||
LogicalResult matchAndRewrite(TilingInterface op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (failed(filter.checkAndNotify(rewriter, op)))
|
||||
return failure();
|
||||
|
||||
auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter);
|
||||
if (failed(tileAndFuseResult)) {
|
||||
return failure();
|
||||
}
|
||||
filter.replaceLinalgTransformationFilter(
|
||||
rewriter, tileAndFuseResult->tiledAndFusedOps.front());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
linalg::LinalgTransformationFilter filter;
|
||||
};
|
||||
|
||||
/// Test pass for testing the use of `TilingInterface`.
|
||||
struct TestTilingInterfacePass
|
||||
: public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
|
||||
|
@ -82,29 +126,63 @@ struct TestTilingInterfacePass
|
|||
return "Test tiling using TilingInterface";
|
||||
}
|
||||
|
||||
Option<bool> testTiling{
|
||||
*this, "tile-using-scf-for",
|
||||
llvm::cl::desc(
|
||||
"Test tiling using TilingInterface with scf.for operations"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
Option<bool> testTileConsumerAndFuseProducer{
|
||||
*this, "tile-consumer-and-fuse-producer-using-scf-for",
|
||||
llvm::cl::desc("Test tile and fuse transformation using TilingInterface "
|
||||
"with scf.for operations"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
|
||||
auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
|
||||
StringRef filterName) {
|
||||
scf::SCFTilingOptions tilingOptions;
|
||||
tilingOptions.setTileSizes(tileSizes);
|
||||
linalg::LinalgTransformationFilter filter(
|
||||
StringAttr::get(context, filterName),
|
||||
StringAttr::get(context, "tiled"));
|
||||
patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
|
||||
filter);
|
||||
};
|
||||
// 1. Tiling M and N dims of `linalg.matmul` on tensors.
|
||||
addPatternForTiling({10, 20}, "simple_gemm");
|
||||
// 2. Tiling M, N and K of `linalg.matmul` on buffers.
|
||||
addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
|
||||
// 3. Tiling 3D parallel generic op which implements a transpose
|
||||
addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
|
||||
// 4. Tiling 2D conv op.
|
||||
addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
|
||||
template <class Pattern>
|
||||
static void
|
||||
addPatternForTiling(MLIRContext *context, ArrayRef<int64_t> tileSizes,
|
||||
StringRef filterName, RewritePatternSet &patterns) {
|
||||
scf::SCFTilingOptions tilingOptions;
|
||||
tilingOptions.setTileSizes(tileSizes);
|
||||
linalg::LinalgTransformationFilter filter(
|
||||
StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
|
||||
patterns.add<Pattern>(context, tilingOptions, filter);
|
||||
}
|
||||
|
||||
void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
|
||||
RewritePatternSet &patterns) {
|
||||
if (testTiling) {
|
||||
// 1. Tiling M and N dims of `linalg.matmul` on tensors.
|
||||
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
|
||||
context, {10, 20}, "simple_gemm", patterns);
|
||||
// 2. Tiling M, N and K of `linalg.matmul` on buffers.
|
||||
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
|
||||
context, {10, 20, 30}, "simple_gemm_memref", patterns);
|
||||
// 3. Tiling 3D parallel generic op which implements a transpose
|
||||
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
|
||||
context, {10, 0, 20}, "parallel_generic_transpose", patterns);
|
||||
// 4. Tiling 2D conv op.
|
||||
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
|
||||
context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns);
|
||||
return;
|
||||
}
|
||||
if (testTileConsumerAndFuseProducer) {
|
||||
// 1. Tile and fuse of gemm with bias-add operation.
|
||||
addPatternForTiling<
|
||||
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
|
||||
context, {10, 20}, "fusion", patterns);
|
||||
addPatternForTiling<
|
||||
TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
|
||||
context, {10}, "gemm_fusion", patterns);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void TestTilingInterfacePass::runOnOperation() {
|
||||
|
|
|
@ -1881,6 +1881,7 @@ cc_library(
|
|||
":SCFUtils",
|
||||
":Support",
|
||||
":TensorDialect",
|
||||
":TensorTransforms",
|
||||
":TilingInterface",
|
||||
":Transforms",
|
||||
"//llvm:Support",
|
||||
|
@ -5028,6 +5029,7 @@ cc_library(
|
|||
":SCFDialect",
|
||||
":TensorDialect",
|
||||
":TensorPassIncGen",
|
||||
":TilingInterface",
|
||||
":Transforms",
|
||||
"//llvm:Support",
|
||||
],
|
||||
|
|
Loading…
Reference in New Issue