forked from OSchip/llvm-project
197 lines
7.9 KiB
C++
197 lines
7.9 KiB
C++
//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Interfaces/TilingInterface.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
namespace {
|
|
|
|
/// External model implementation of TilingInterface for LinalgOps. An external
|
|
/// model implementation is used for now till the use of `TilingInterface` is
|
|
/// on-par with the current Linalg tiling + fusion patterns. Once it is
|
|
/// maybe possible to move this into the op-definition (though there are
|
|
/// advantages to leaving it as an external model)
|
|
template <typename LinalgOpTy>
|
|
struct LinalgOpTilingInterface
|
|
: public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
|
|
LinalgOpTy> {
|
|
/// Return the destination operands.
|
|
SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
|
|
return cast<LinalgOp>(op).getOutputOperands();
|
|
}
|
|
|
|
/// Return the loop iterator type.
|
|
SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
|
|
LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
|
|
return llvm::to_vector(
|
|
llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
|
|
return strAttr.cast<StringAttr>().getValue();
|
|
}));
|
|
}
|
|
|
|
/// Return the iteration domain range.
|
|
SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
|
|
OpBuilder::InsertionGuard g(b);
|
|
b.setInsertionPoint(op);
|
|
Location loc = op->getLoc();
|
|
LinalgOp linalgOp = cast<LinalgOp>(op);
|
|
SmallVector<OpFoldResult> allShapesSizes =
|
|
linalgOp.createFlatListOfOperandDims(b, loc);
|
|
AffineMap map = linalgOp.getShapesToLoopsMap();
|
|
|
|
IRRewriter rewriter(b);
|
|
return llvm::to_vector(
|
|
llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
|
|
OpFoldResult ofr = makeComposedFoldedAffineApply(
|
|
rewriter, loc, loopExpr, allShapesSizes);
|
|
return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
|
|
}));
|
|
}
|
|
|
|
// Instantiate the tiled implementation of the operation.
|
|
SmallVector<Operation *>
|
|
getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
|
|
ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes,
|
|
bool tileDestOperands) const {
|
|
// Leave the `sizeBounds` value empty. That is only needed when the `sizes`
|
|
// specified could lead to out of bounds accesses.
|
|
Location loc = op->getLoc();
|
|
LinalgOp linalgOp = cast<LinalgOp>(op);
|
|
SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
|
|
SmallVector<Value, 4> tiledOperands = makeTiledShapes(
|
|
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
|
|
|
|
SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
|
|
linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
|
|
return tiledOperands[opOperand->getOperandNumber()].getType();
|
|
}));
|
|
|
|
Operation *tiledOp =
|
|
linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
|
|
offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
|
|
|
|
return {tiledOp};
|
|
}
|
|
|
|
// Return the details of the output tile generated by the tiled
|
|
// implementation.
|
|
LogicalResult
|
|
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
|
|
ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes,
|
|
SmallVector<OpFoldResult> &resultOffsets,
|
|
SmallVector<OpFoldResult> &resultSizes) const {
|
|
Location loc = op->getLoc();
|
|
LinalgOp linalgOp = cast<LinalgOp>(op);
|
|
|
|
AffineExpr d0;
|
|
bindDims(b.getContext(), d0);
|
|
IRRewriter rewriter(b);
|
|
SmallVector<OpFoldResult> subShapeSizes =
|
|
llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
|
|
return makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, ofr);
|
|
}));
|
|
|
|
OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
|
|
Value sliceOpResult =
|
|
makeTiledShape(b, loc, outOperand->get(), sizes,
|
|
linalgOp.getTiedIndexingMap(outOperand), offsets,
|
|
/*ubs*/ {}, subShapeSizes, true);
|
|
auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
|
|
if (!sliceOp)
|
|
return failure();
|
|
resultOffsets = sliceOp.getMixedOffsets();
|
|
resultSizes = sliceOp.getMixedSizes();
|
|
return success();
|
|
}
|
|
|
|
FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
|
|
unsigned resultNumber,
|
|
ValueRange dest,
|
|
ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes,
|
|
bool tileDestOperands) const {
|
|
auto linalgOp = cast<LinalgOp>(op);
|
|
|
|
// Check that the indexing map used for the output is a projected
|
|
// permutation. This could be relaxed with a more general approach that can
|
|
// map the offsets and sizes from the result to iteration space tiles
|
|
// (filling in full extent for dimensions not used to access the result).
|
|
AffineMap indexingMap =
|
|
linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
|
|
if (!indexingMap.isProjectedPermutation()) {
|
|
return op->emitOpError(
|
|
"unhandled tiled implementation generation when result is not "
|
|
"accessed using a permuted projection");
|
|
}
|
|
|
|
auto numLoops = linalgOp.getNumLoops();
|
|
auto tilingInterfaceOp = cast<TilingInterface>(op);
|
|
SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
|
|
iterationTileSizes(numLoops);
|
|
if (!indexingMap.isPermutation()) {
|
|
SmallVector<Range> iterationDomain =
|
|
tilingInterfaceOp.getIterationDomain(b);
|
|
for (const auto &range : llvm::enumerate(iterationDomain)) {
|
|
iterationTileOffsets[range.index()] = range.value().offset;
|
|
iterationTileSizes[range.index()] = range.value().size;
|
|
}
|
|
}
|
|
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
|
|
unsigned dimPosition =
|
|
resultExpr.value().cast<AffineDimExpr>().getPosition();
|
|
iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
|
|
iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
|
|
}
|
|
|
|
SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
|
|
b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
|
|
if (tiledOp.size() != 1)
|
|
return op->emitOpError("failed to generate tiled implementation");
|
|
|
|
return tiledOp[0]->getResult(resultNumber);
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
template <typename OpType>
|
|
static void registerOne(MLIRContext *ctx) {
|
|
OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
|
|
}
|
|
|
|
/// Variadic helper function.
|
|
template <typename... OpTypes>
|
|
static void registerAll(MLIRContext *ctx) {
|
|
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
|
(void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
|
|
}
|
|
|
|
/// Registers `LinalgOpTilingInterface` as the external `TilingInterface` model
/// of `linalg::GenericOp` and of every op listed in the generated
/// `LinalgStructuredOps.cpp.inc` op list.
///
/// The models are attached from a `DialectRegistry` extension callback that
/// receives the context and the Linalg dialect. Presumably it runs when the
/// Linalg dialect is loaded into a context; see
/// `DialectRegistry::addExtension`.
void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx,
                            linalg::LinalgDialect * /*dialect*/) {
    registerOne<linalg::GenericOp>(ctx);
    // The generated file is included with `GET_OP_LIST` defined so that it
    // expands in place to the op classes used as template arguments.
    registerAll<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}
|