llvm-project/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

523 lines
22 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements linalg fusion on tensors
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
using namespace linalg;
//===----------------------------------------------------------------------===//
// StructuredOp specific helpers.
//===----------------------------------------------------------------------===//
/// Returns the indices of the consumer operand slice dimensions that are tiled
/// by at least one of the consumer loop dimensions in `tiledLoopDims`.
///
/// A slice dimension counts as tiled if the matching result expression of the
/// operand indexing map is a function of one of the tiled loop dimensions. The
/// slice defines a hyper rectangular iteration space and fusing the producer
/// is always possible. However, depending on the consumer indexing map, not
/// all slice elements may be consumed and the tiles may overlap. In these
/// cases, fusion introduces redundant computation.
///
/// The indices are returned in ascending order and each index appears at most
/// once, so the result never depends on the iteration order of a hash
/// container.
static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
                                              ArrayRef<int64_t> tiledLoopDims) {
  // Get the consumer operand indexing map.
  LinalgOp consumerOp = consumerOperand->getOwner();
  AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
  // Visit the map results in order and record every slice dimension whose
  // expression depends on at least one tiled loop dimension.
  SmallVector<int64_t> tiledSliceDimIndices;
  for (auto en : enumerate(indexingMap.getResults())) {
    AffineExpr resultExpr = en.value();
    bool isTiled = llvm::any_of(tiledLoopDims, [&](int64_t tiledLoopDim) {
      return resultExpr.isFunctionOfDim(tiledLoopDim);
    });
    if (isTiled)
      tiledSliceDimIndices.push_back(en.index());
  }
  return tiledSliceDimIndices;
}
/// Translates the tiled dimensions of a producer result slice, given by
/// `tiledSliceDimIndices`, into the producer loop dimensions that iterate
/// them. The returned loop indices are listed in the same order as
/// `tiledSliceDimIndices`.
/// Example:
/// ```
/// %res = linalg.fill(%cst, %input)
/// scf.for %i
/// scf.for %j
/// %slice = tensor.extract_slice %res[%i, %j]
/// ```
/// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1].
static SmallVector<int64_t>
getTiledProducerLoops(OpResult producerResult,
                      ArrayRef<int64_t> tiledSliceDimIndices) {
  LinalgOp producerOp = producerResult.getOwner();
  // Look up the `producerOp` output operand that matches `producerResult` and
  // fetch its indexing map.
  OpOperand *tiedOutput =
      producerOp.getOutputOperand(producerResult.getResultNumber());
  AffineMap outputMap = producerOp.getTiedIndexingMap(tiedOutput);
  // Restrict the map to the results that belong to tiled slice dimensions.
  SmallVector<unsigned> keptResults(tiledSliceDimIndices.begin(),
                                    tiledSliceDimIndices.end());
  AffineMap tiledOutputMap = outputMap.getSubMap(keptResults);
  // Output indexing maps of structured operations are projected permutations,
  // hence the restricted map has to be one as well. Every result is then a
  // plain loop dimension and its position is the producer loop index.
  // Example:
  // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2].
  assert(tiledOutputMap.isProjectedPermutation() &&
         "expect slice and producer loop dimensions map one-to-one");
  const unsigned numTiledResults = tiledOutputMap.getNumResults();
  SmallVector<int64_t> loopIndices;
  loopIndices.reserve(numTiledResults);
  for (unsigned resultIdx = 0; resultIdx < numTiledResults; ++resultIdx)
    loopIndices.push_back(tiledOutputMap.getDimPosition(resultIdx));
  return loopIndices;
}
/// Returns a clone of the producer of `producerResult` that is created right
/// after `sliceOp` and operates on tiled operands. The producer operands are
/// tiled along the producer loops in `tiledProducerLoopIndices`, using the
/// offsets and sizes of the `sliceOp` dimensions in `tiledSliceDimIndices`
/// (the two lists are walked in lockstep). The insertion point of `b` is
/// restored when the function returns.
///
/// `iterArg` is null for the fusion of a regular operand. If it is non-null,
/// it is redirected to the original producer output and the clone uses the
/// `sliceOp` result as its output operand. Consider the case of fusion of an
/// output tensor:
/// ```
/// %1 = producer ins(...) outs(%0)
/// %2 = consumer ins(...) outs(%1)
/// ```
/// When consumer is tiled, %1 appears in the loop iter_args:
/// ```
/// %1 = producer ins(...) outs(%0)
/// %2 = scf.for ... iter_args(%1) .. (%bbarg) {
/// %t1 = tensor.extract_slice %bbarg[..]
/// %t2 = consumer ins(...) outs(%t1)
/// %r = tensor.insert_slice %t2, %bbarg[...]
/// }
/// ```
/// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0):
/// ```
/// %2 = scf.for ... iter_args(%0) .. (%bbarg) {
/// %t0 = tensor.extract_slice %bbarg[..]
/// %t1 = producer ins(...) outs(%t0)
/// %t2 = consumer ins(...) outs(%t1)
/// %r = tensor.insert_slice %t2, %bbarg[...]
/// }
/// ```
/// This transformation is only valid if %bbarg is exclusively used by the
/// output ExtractSliceOp / InsertSliceOp pair, which is checked by the
/// `fuseProducer` method.
/// TODO: instead of check and failure, insert new iter_args each time a
/// producer is fused into a consumer and fold away unused iter_args.
static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
tensor::ExtractSliceOp sliceOp,
ArrayRef<int64_t> tiledSliceDimIndices,
ArrayRef<int64_t> tiledProducerLoopIndices,
OpOperand *iterArg) {
// Clone the producer after `sliceOp` since the slice may be reused to pass in
// the producer result. The guard restores the previous insertion point on
// exit.
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointAfter(sliceOp);
// Get the producer; all new operations reuse its location.
LinalgOp producerOp = producerResult.getOwner();
Location loc = producerOp.getLoc();
// Obtain the `producerOp` loop bounds (only the range sizes are kept) and the
// `sliceOp` ranges.
SmallVector<Value> producerLoopBounds;
transform(producerOp.createLoopRanges(b, loc),
std::back_inserter(producerLoopBounds),
[](Range range) { return range.size; });
SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
// Tile the producer operands given the `sliceOp` ranges. Iterate the
// `tiledSliceDimIndices` and store the tile offset and size for the tiled
// slice dimension. All three vectors have one entry per producer loop: the
// offsets start out null, the tile sizes start out as the constant zero, and
// only the entries of the tiled producer loops are overwritten.
auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) {
int64_t tiledSliceDim = std::get<0>(it);
int64_t tiledProducerLoop = std::get<1>(it);
tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
}
// Drop the null entries of the untiled loops so that `tileIvs` only holds the
// offsets of the tiled loops. `allIvs` keeps one (possibly null) entry per
// producer loop.
erase_value(tileIvs, nullptr);
SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
tileSizes, producerLoopBounds);
// Output fusion has to update the iteration arguments of the tile loop nest.
// In particular, the iteration argument of the outermost tile loop needs to
// be set to the producer output instead of the producer result and `clonedOp`
// shall use the existing `sliceOp` result instead of the tiled producer
// output operand.
if (iterArg) {
OpOperand *outputOperand =
producerOp.getOutputOperand(producerResult.getResultNumber());
iterArg->set(outputOperand->get());
tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
}
// Clone the producer using the tiled producer operands. The result types of
// the clone are the types of the trailing (output) tiled operands.
TypeRange resultTypes = ValueRange(tiledOperands)
.take_back(producerOp.getNumOutputs())
.getTypes();
LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
// Shift all IndexOp results by the tile offset.
addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
return clonedOp;
}
//===----------------------------------------------------------------------===//
// TileLoopNest specific helpers.
//===----------------------------------------------------------------------===//
/// Returns true if the tile loop nest does not contain any tile loop.
bool TileLoopNest::isEmpty() {
  const bool hasNoTileLoops = tileLoopOps.empty();
  return hasNoTileLoops;
}
/// Returns true if the tile loop nest satisfies all of its invariants: the
/// root operation has been tiled, there is exactly one tile loop per stored
/// root loop dimension, the root operation lives directly in the innermost
/// tile loop, and the tile loops are directly nested into each other.
bool TileLoopNest::isValid() {
  // The root operation has to be tiled at least once.
  if (isEmpty())
    return false;
  if (tiledRootAndFusedOpsLoops.count(rootOp) == 0)
    return false;
  // There has to be one tile loop per tiled root loop dimension.
  if (tiledRootAndFusedOpsLoops[rootOp].size() != tileLoopOps.size())
    return false;
  // The innermost tile loop has to be the parent of the root operation.
  if (rootOp->getParentOp() != tileLoopOps.back())
    return false;
  // Every tile loop has to be the direct parent of the next inner tile loop.
  for (size_t pos = 1, numLoops = tileLoopOps.size(); pos < numLoops; ++pos) {
    if (tileLoopOps[pos]->getParentOp() != tileLoopOps[pos - 1])
      return false;
  }
  return true;
}
/// Returns the block arguments of all tile loops that are tied to `bbArg`,
/// ordered from the outermost to the innermost tile loop. `bbArg` has to be
/// non-null and is expected to be a block argument of the innermost tile
/// loop.
///
/// The search walks the tile loops from inner to outer and follows the
/// iteration argument of every loop to the block argument of the next outer
/// loop. An empty vector is returned if the chain breaks, i.e., if a block
/// argument is not owned by the expected tile loop or if an iteration argument
/// of an inner tile loop is not a block argument at all.
SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
  assert(bbArg && "expect the block argument to be non-zero");
  SmallVector<BlockArgument> bbArgs;
  // Search all tile loop block arguments from inner to outer.
  for (auto tileLoop : reverse(tileLoopOps)) {
    // `bbArg` is null if the iteration argument of the previously visited
    // (inner) tile loop is not a block argument. Bail out instead of
    // dereferencing the null block argument.
    if (!bbArg || bbArg.getOwner()->getParentOp() != tileLoop)
      return {};
    bbArgs.push_back(bbArg);
    // Continue with the value passed in by the enclosing loop. The cast
    // yields a null block argument if that value is an operation result.
    OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
    bbArg = iterArg->get().dyn_cast<BlockArgument>();
  }
  // Reverse the block arguments to order them from outer to inner.
  return {bbArgs.rbegin(), bbArgs.rend()};
}
/// Returns the iteration argument operand of the outermost tile loop that is
/// tied to `bbArg`, or nullptr if `bbArg` is not threaded through all tile
/// loops or if the tile loop nest has no loops.
OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
  // Search all block arguments and return the matching iteration argument.
  SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
  // An empty result either signals a broken chain or an empty tile loop nest.
  // Checking it explicitly avoids accessing the front of an empty loop list
  // when both sizes are zero.
  if (bbArgs.empty() || bbArgs.size() != tileLoopOps.size())
    return nullptr;
  return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
}
/// Inspects the users of the innermost block argument `bbArg` and the use
/// counts of the block arguments tied to it in the enclosing tile loops. The
/// caller (`fuseProducer`) gives up on output fusion if this returns true.
///
/// As implemented, the function returns:
/// - false as soon as a user of `bbArg` is not a DimOp, InsertSliceOp, or
///   ExtractSliceOp, is an ExtractSliceOp different from `sliceOp`, or is an
///   InsertSliceOp whose filtered backward slice of the source is empty or
///   does not start with `sliceOp`;
/// - otherwise true if any tied block argument other than the innermost one
///   does not have exactly one use, and false if all of them do.
///
/// NOTE(review): the early exits return false, which reports "no other uses"
/// and lets the caller continue with fusion. The function name and the
/// comment about handling other uses conservatively suggest that true may
/// have been intended. Confirm the polarity before changing it, since that
/// would alter which fusions are accepted.
bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
tensor::ExtractSliceOp sliceOp) {
// Check the innermost block argument is either used by the ExtractSliceOp
// `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses
// conservatively.
for (Operation *op : bbArg.getUsers()) {
if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op))
return false;
// Any ExtractSliceOp user has to be `sliceOp` itself.
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
if (extractSliceOp != sliceOp)
return false;
}
// For an InsertSliceOp user, compute the backward slice of the inserted
// value, filtered to LinalgOp and InsertSliceOp operations, and compare its
// first entry against `sliceOp`.
if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
SetVector<Operation *> backwardSlice;
getBackwardSlice(insertSliceOp.source(), &backwardSlice,
[](Operation *op) {
return isa<LinalgOp, tensor::InsertSliceOp>(op);
});
if (backwardSlice.empty() || backwardSlice.front() != sliceOp)
return false;
}
}
// Check the block arguments, except for the innermost one, have one use.
SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
return !all_of(bbArgs, [&](BlockArgument bbArg) {
return bbArg.hasOneUse() || bbArg == bbArgs.back();
});
}
/// Tiles the current root operation with `tileSizes` and `tileInterchange`
/// and appends the generated loops to the tile loop nest.
///
/// A tile size of zero leaves the matching loop dimension untiled. If all
/// tile sizes are zero, nothing is tiled and the function succeeds without
/// modifying the nest. Returns failure if tiling the root operation fails.
/// On success, the tiled operation becomes the new root operation and the
/// interchanged loop dimensions with non-zero tile size are recorded for it.
/// `tileInterchange` is indexed with the positions of `tileSizes`, so it needs
/// at least as many entries.
LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
                                       ArrayRef<int64_t> tileSizes,
                                       ArrayRef<int64_t> tileInterchange) {
  // Exit if all tile sizes are zero.
  if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
    return success();
  // Tile the root operation.
  LinalgTilingOptions tilingOptions;
  tilingOptions = tilingOptions
                      .setInterchange(SmallVector<unsigned>(
                          tileInterchange.begin(), tileInterchange.end()))
                      .setTileSizes(tileSizes)
                      .setLoopType(LinalgTilingLoopType::Loops);
  Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
  // Exit if tiling the root operation fails.
  if (!tiledRootOp.hasValue())
    return failure();
  // Replace all uses of the root operation if it has been tiled before. All
  // uses of the original untiled root operation are updated by the calling pass
  // or pattern.
  if (!isEmpty())
    rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
  // Transfer the stored `rootOp` loop dimensions if it has been tiled before.
  if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) {
    // Copy the stored dimensions before creating the entry of the new root
    // operation. Inserting a key may reallocate the storage of a hash map
    // (e.g. a DenseMap), which would leave a reference to the old entry
    // dangling if both lookups were part of the same assignment expression.
    auto rootLoopDims = tiledRootAndFusedOpsLoops[rootOp];
    tiledRootAndFusedOpsLoops[tiledRootOp->op] = rootLoopDims;
  }
  // Update the root operation and append the loops and tile loop dimensions.
  rootOp = tiledRootOp->op;
  tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
  for (auto en : enumerate(tileSizes)) {
    // Copy only the tiled loop dimensions with non-zero tile size.
    if (en.value() == 0)
      continue;
    tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]);
  }
  assert(isValid() && "expect tile loop nest to be valid after tiling");
  return success();
}
/// Fuses the producer of `consumerOpOperand` into the tile loop nest and
/// returns the fused (cloned and tiled) producer.
///
/// Fusion fails, without modifying the IR, if
/// - the tile loop nest is empty,
/// - the operand is not defined by an ExtractSliceOp,
/// - the slice or the consumer are not in the block of the root operation,
/// - the slice source is a block argument that cannot be traced to an
///   iteration argument of the outermost tile loop or that `hasOtherUses`
///   rejects,
/// - the producer is not a LinalgOp result, or
/// - none of the slice dimensions is tiled by the consumer tile loops.
///
/// On success, the uses of the slice, except for the ones by the fused
/// producer itself, are replaced by the fused producer result. A
/// tensor.cast is inserted if the result type differs from the operand type.
FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
                                               OpOperand *consumerOpOperand) {
  // Check the tile loop nest is non-empty. This check has to precede the
  // invariant asserts below: an empty nest is never valid, so asserting first
  // would abort in assert-enabled builds instead of failing gracefully as
  // release builds do.
  if (isEmpty())
    return failure();
  assert(tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) != 0 &&
         "expect the operand owner is the root operation or a fused producer");
  assert(this->isValid() &&
         "expect the tile loop nest to satisfy all invariants");
  // Check `consumerOpOperand` is defined by an ExtractSliceOp.
  auto sliceOp =
      consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return failure();
  // Check `sliceOp` and `consumerOp` are in the same block.
  LinalgOp consumerOp = consumerOpOperand->getOwner();
  if (sliceOp->getBlock() != rootOp->getBlock() ||
      consumerOp->getBlock() != rootOp->getBlock())
    return failure();
  // Check if the producer is a LinalgOp possibly passed by iteration argument.
  // `iterArg` stays null unless the slice source is a block argument.
  OpOperand *iterArg = nullptr;
  auto producerResult = sliceOp.source().dyn_cast<OpResult>();
  if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) {
    iterArg = getTiedIterArg(bbArg);
    // Check the iteration argument may be used to pass in the producer output.
    if (!iterArg || hasOtherUses(bbArg, sliceOp))
      return failure();
    producerResult = iterArg->get().dyn_cast<OpResult>();
  }
  if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
    return failure();
  // Compute the tiled producer slice dimensions given the tiled consumer loops.
  SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims(
      consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]);
  if (tiledSliceDimIndices.empty())
    return failure();
  // Compute the tiled producer loop indices.
  SmallVector<int64_t> tiledProducerLoopIndices =
      getTiledProducerLoops(producerResult, tiledSliceDimIndices);
  // Tile the producer operands and clone the producer in place of `sliceOp`.
  LinalgOp clonedOp =
      getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices,
                       tiledProducerLoopIndices, iterArg);
  // Remember the tiled loops of the fused producer so that its own producers
  // can be fused later on.
  tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;
  // Cast the `clonedOp` result to bridge type mismatches before
  // canonicalization.
  Type consumerOperandType = consumerOpOperand->get().getType();
  Value newResult = clonedOp->getResult(producerResult.getResultNumber());
  if (newResult.getType() != consumerOperandType) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(clonedOp);
    newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
                                         consumerOperandType, newResult);
  }
  // Replace the `sliceOp` uses except for the `clonedOp` output uses.
  sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
  return clonedOp;
}
/// Returns the results of the outermost tile loop, which stand in for the
/// results of the original untiled root operation. The tile loop nest has to
/// be non-empty.
ValueRange TileLoopNest::getRootOpReplacementResults() {
  assert(!isEmpty() && "expect tile loop nest to be non-empty");
  auto outermostTileLoop = tileLoopOps.front();
  return outermostTileLoop->getOpResults();
}
//===----------------------------------------------------------------------===//
// Tile and fuse entry-points.
//===----------------------------------------------------------------------===//
/// Tiles `consumerOp` with `tileSizes` and `tileInterchange` and greedily
/// fuses its producers into the generated tile loop nest.
///
/// Tiling happens in two steps. First, the leading parallel loops (after
/// interchange) are tiled and the producers of the output operands are fused.
/// Second, the remaining loops are tiled and the producers of the input
/// operands are fused. The operands of every fused producer become fusion
/// candidates themselves. Returns the tile loop nest, or failure if one of
/// the tiling steps fails. `tileSizes` and `tileInterchange` need the same
/// number of entries and `tileInterchange` has to be a permutation.
FailureOr<TileLoopNest>
mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
                                           ArrayRef<int64_t> tileSizes,
                                           ArrayRef<int64_t> tileInterchange) {
  assert(tileSizes.size() == tileInterchange.size() &&
         "expect the number of tile sizes and interchange dims to match");
  assert(isPermutation(tileInterchange) &&
         "expect tile interchange is a permutation");
  // Start with a tile loop nest that has no loops yet.
  TileLoopNest tileLoopNest(consumerOp);
  // Determine how many leading loops are parallel once the interchange is
  // applied; the first non-parallel iterator marks the split point.
  SmallVector<StringAttr> iterTypes =
      llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
  applyPermutationToVector(iterTypes, tileInterchange);
  auto *firstNonParallel = find_if(iterTypes, [&](StringAttr iterType) {
    return !isParallelIterator(iterType);
  });
  int64_t split = std::distance(iterTypes.begin(), firstNonParallel);
  // Worklist-driven greedy fusion: pop a candidate operand, try to fuse its
  // producer, and enqueue the operands of every producer that got fused.
  auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
    SmallVector<OpOperand *> worklist(operands.begin(), operands.end());
    while (!worklist.empty()) {
      OpOperand *candidate = worklist.pop_back_val();
      FailureOr<LinalgOp> fusedProducer =
          tileLoopNest.fuseProducer(b, candidate);
      if (!failed(fusedProducer))
        worklist.append(fusedProducer->getInputAndOutputOperands());
    }
  };
  // Step 1: keep the tile sizes of the outer parallel loops, zero out the
  // rest, tile, and fuse the producers of the output operands.
  SmallVector<int64_t> outerTileSizes;
  outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
  outerTileSizes.append(tileSizes.size() - split, 0);
  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
    return failure();
  fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
  // Step 2: zero out the already tiled outer loops, keep the remaining tile
  // sizes, tile, and fuse the producers of the input operands.
  SmallVector<int64_t> innerTileSizes;
  innerTileSizes.append(split, 0);
  innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
    return failure();
  fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
  return tileLoopNest;
}
namespace {
/// Function pass that picks one root LinalgOp, tiles it with the `tileSizes`
/// and `tileInterchange` pass options, and fuses its producers into the tile
/// loop nest.
struct LinalgTileAndFuseTensorOps
    : public LinalgTileAndFuseTensorOpsBase<LinalgTileAndFuseTensorOps> {
  /// Prints `message` to the error stream and signals a pass failure.
  void notifyFailure(StringRef message) {
    llvm::errs() << " - LinalgTileAndFuseTensorOps: " << message << "\n";
    signalPassFailure();
  }

  /// Selects the root operation, validates the pass options, runs tile and
  /// fuse, and replaces the uses of the root operation with the tile loop
  /// nest results.
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    OpBuilder b(funcOp.getContext());
    // Root selection heuristic: walk all LinalgOps and keep the first one
    // whose backward slice contains the largest number of LinalgOps, i.e.,
    // the largest number of fusion candidates.
    LinalgOp rootOp = nullptr;
    int64_t maxNumCandidates = -1;
    funcOp.walk([&](LinalgOp linalgOp) {
      SetVector<Operation *> backwardSlice;
      getBackwardSlice(linalgOp, &backwardSlice);
      int64_t numCandidates = count_if(
          backwardSlice, [](Operation *op) { return isa<LinalgOp>(op); });
      if (numCandidates <= maxNumCandidates)
        return;
      rootOp = linalgOp;
      maxNumCandidates = numCandidates;
    });
    if (!rootOp) {
      notifyFailure("expect to find a root operation");
      return;
    }
    // `tileSizes` needs a tile size for every `rootOp` loop dimension.
    if (tileSizes.size() < rootOp.getNumLoops()) {
      notifyFailure("expect #tile sizes >= #loops");
      return;
    }
    // `tileInterchange` is either empty or as long as `tileSizes`.
    if (!tileInterchange.empty() &&
        tileInterchange.size() != tileSizes.size()) {
      notifyFailure(
          "expect the number of tile sizes and interchange dims to match");
      return;
    }
    // Keep only the `tileSizes` prefix that is needed to tile `rootOp`.
    SmallVector<int64_t> rootTileSizes(
        tileSizes.begin(), tileSizes.begin() + rootOp.getNumLoops());
    // Use the matching `tileInterchange` prefix or fall back to the identity
    // interchange if no interchange has been specified.
    SmallVector<int64_t> rootInterchange;
    if (tileInterchange.empty()) {
      rootInterchange =
          llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()));
    } else {
      rootInterchange.assign(tileInterchange.begin(),
                             tileInterchange.begin() + rootOp.getNumLoops());
    }
    // `rootInterchange` has to be a permutation of the `rootOp` loop
    // dimensions since the tiling cannot tile the same loop dimension
    // multiple times.
    if (!isPermutation(rootInterchange)) {
      notifyFailure("expect the tile interchange permutes the root loops");
      return;
    }
    // Tile `rootOp` and fuse its producers.
    FailureOr<TileLoopNest> tileLoopNest =
        tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange);
    if (failed(tileLoopNest)) {
      notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
      return;
    }
    // Replace all uses of the untiled root operation by the tile loop nest
    // results.
    rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
  }
};
} // namespace
/// Creates a new instance of the linalg tile and fuse on tensors pass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTileAndFuseTensorOpsPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LinalgTileAndFuseTensorOps>();
  return pass;
}