llvm-project/mlir/lib/Transforms/PipelineDataTransfer.cpp

370 lines
14 KiB
C++

//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "affine-pipeline-data-transfer"
using namespace mlir;
namespace {
struct PipelineDataTransfer
: public AffinePipelineDataTransferBase<PipelineDataTransfer> {
void runOnFunction() override;
void runOnAffineForOp(AffineForOp forOp);
std::vector<AffineForOp> forOps;
};
} // namespace
/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<OperationPass<FuncOp>> mlir::createPipelineDataTransferPass() {
return std::make_unique<PipelineDataTransfer>();
}
// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO
static unsigned getTagMemRefPos(Operation &dmaOp) {
assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
return dmaStartOp.getTagMemRefOperandIndex();
}
// First operand for a dma finish operation.
return 0;
}
/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
auto *forBody = forOp.getBody();
OpBuilder bInner(forBody, forBody->begin());
// Doubles the shape with a leading dimension extent of 2.
auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
// Add the leading dimension in the shape for the double buffer.
ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
newShape[0] = 2;
std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
};
auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
auto newMemRefType = doubleShape(oldMemRefType);
// The double buffer is allocated right before 'forOp'.
OpBuilder bOuter(forOp);
// Put together alloc operands for any dynamic dimensions of the memref.
SmallVector<Value, 4> allocOperands;
for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
if (dim.value() == ShapedType::kDynamicSize)
allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
forOp.getLoc(), oldMemRef, dim.index()));
}
// Create and place the alloc right before the 'affine.for' operation.
Value newMemRef = bOuter.create<memref::AllocOp>(
forOp.getLoc(), newMemRefType, allocOperands);
// Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
int64_t step = forOp.getStep();
auto modTwoMap =
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
forOp.getInductionVar());
// replaceAllMemRefUsesWith will succeed unless the forOp body has
// non-dereferencing uses of the memref (dealloc's are fine though).
if (failed(replaceAllMemRefUsesWith(
oldMemRef, newMemRef,
/*extraIndices=*/{ivModTwoOp},
/*indexRemap=*/AffineMap(),
/*extraOperands=*/{},
/*symbolOperands=*/{},
/*domOpFilter=*/&*forOp.getBody()->begin()))) {
LLVM_DEBUG(
forOp.emitError("memref replacement for double buffering failed"));
ivModTwoOp.erase();
return false;
}
// Insert the dealloc op right after the for loop.
bOuter.setInsertionPointAfter(forOp);
bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);
return true;
}
/// Returns success if the IR is in a valid state.
void PipelineDataTransfer::runOnFunction() {
// Do a post order walk so that inner loop DMAs are processed first. This is
// necessary since 'affine.for' operations nested within would otherwise
// become invalid (erased) when the outer loop is pipelined (the pipelined one
// gets deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forOps.clear();
getFunction().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
for (auto forOp : forOps)
runOnAffineForOp(forOp);
}
// Check if tags of the dma start op and dma wait op match.
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
if (startOp.getTagMemRef() != waitOp.getTagMemRef())
return false;
auto startIndices = startOp.getTagIndices();
auto waitIndices = waitOp.getTagIndices();
// Both of these have the same number of indices since they correspond to the
// same tag memref.
for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
e = startIndices.end();
it != e; ++it, ++wIt) {
// Keep it simple for now, just checking if indices match.
// TODO: this would in general need to check if there is no
// intervening write writing to the same tag location, i.e., memory last
// write/data flow analysis. This is however sufficient/powerful enough for
// now since the DMA generation pass or the input for it will always have
// start/wait with matching tags (same SSA operand indices).
if (*it != *wIt)
return false;
}
return true;
}
// Identify matching DMA start/finish operations to overlap computation with.
static void findMatchingStartFinishInsts(
AffineForOp forOp,
SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
// Collect outgoing DMA operations - needed to check for dependences below.
SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
for (auto &op : *forOp.getBody()) {
auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
outgoingDmaOps.push_back(dmaStartOp);
}
SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
for (auto &op : *forOp.getBody()) {
// Collect DMA finish operations.
if (isa<AffineDmaWaitOp>(op)) {
dmaFinishInsts.push_back(&op);
continue;
}
auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
if (!dmaStartOp)
continue;
// Only DMAs incoming into higher memory spaces are pipelined for now.
// TODO: handle outgoing DMA pipelining.
if (!dmaStartOp.isDestMemorySpaceFaster())
continue;
// Check for dependence with outgoing DMAs. Doing this conservatively.
// TODO: use the dependence analysis to check for
// dependences between an incoming and outgoing DMA in the same iteration.
auto *it = outgoingDmaOps.begin();
for (; it != outgoingDmaOps.end(); ++it) {
if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
break;
}
if (it != outgoingDmaOps.end())
continue;
// We only double buffer if the buffer is not live out of loop.
auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
bool escapingUses = false;
for (auto *user : memref.getUsers()) {
// We can double buffer regardless of dealloc's outside the loop.
if (isa<memref::DeallocOp>(user))
continue;
if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
break;
}
}
if (!escapingUses)
dmaStartInsts.push_back(&op);
}
// For each start operation, we look for a matching finish operation.
for (auto *dmaStartOp : dmaStartInsts) {
for (auto *dmaFinishOp : dmaFinishInsts) {
if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
cast<AffineDmaWaitOp>(dmaFinishOp))) {
startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
break;
}
}
}
}
/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
auto mayBeConstTripCount = getConstantTripCount(forOp);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
return;
}
SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
findMatchingStartFinishInsts(forOp, startWaitPairs);
if (startWaitPairs.empty()) {
LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
return;
}
// Double the buffers for the higher memory space memref's.
// Identify memref's to replace by scanning through all DMA start
// operations. A DMA start operation has two memref's - the one from the
// higher level of memory hierarchy is the one to double buffer.
// TODO: check whether double-buffering is even necessary.
// TODO: make this work with different layouts: assuming here that
// the dimension we are adding here for the double buffering is the outermost
// dimension.
for (auto &pair : startWaitPairs) {
auto *dmaStartOp = pair.first;
Value oldMemRef = dmaStartOp->getOperand(
cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
if (!doubleBuffer(oldMemRef, forOp)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs()
<< "double buffering failed for" << dmaStartOp << "\n";);
// IR still valid and semantically correct.
return;
}
// If the old memref has no more uses, remove its 'dead' alloc if it was
// alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
// operation could have been used on it if it was dynamically shaped in
// order to create the double buffer above.)
// '-canonicalize' does this in a more general way, but we'll anyway do the
// simple/common case so that the output / test cases looks clear.
if (auto *allocOp = oldMemRef.getDefiningOp()) {
if (oldMemRef.use_empty()) {
allocOp->erase();
} else if (oldMemRef.hasOneUse()) {
if (auto dealloc =
dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
dealloc.erase();
allocOp->erase();
}
}
}
}
// Double the buffers for tag memrefs.
for (auto &pair : startWaitPairs) {
auto *dmaFinishOp = pair.second;
Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
if (!doubleBuffer(oldTagMemRef, forOp)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return;
}
// If the old tag has no uses or a single dealloc use, remove it.
// (canonicalization handles more complex cases).
if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
if (oldTagMemRef.use_empty()) {
tagAllocOp->erase();
} else if (oldTagMemRef.hasOneUse()) {
if (auto dealloc =
dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
dealloc.erase();
tagAllocOp->erase();
}
}
}
}
// Double buffering would have invalidated all the old DMA start/wait insts.
startWaitPairs.clear();
findMatchingStartFinishInsts(forOp, startWaitPairs);
// Store shift for operation for later lookup for AffineApplyOp's.
DenseMap<Operation *, unsigned> instShiftMap;
for (auto &pair : startWaitPairs) {
auto *dmaStartOp = pair.first;
assert(isa<AffineDmaStartOp>(dmaStartOp));
instShiftMap[dmaStartOp] = 0;
// Set shifts for DMA start op's affine operand computation slices to 0.
SmallVector<AffineApplyOp, 4> sliceOps;
mlir::createAffineComputationSlice(dmaStartOp, &sliceOps);
if (!sliceOps.empty()) {
for (auto sliceOp : sliceOps) {
instShiftMap[sliceOp.getOperation()] = 0;
}
} else {
// If a slice wasn't created, the reachable affine.apply op's from its
// operands are the ones that go with it.
SmallVector<Operation *, 4> affineApplyInsts;
SmallVector<Value, 4> operands(dmaStartOp->getOperands());
getReachableAffineApplyOps(operands, affineApplyInsts);
for (auto *op : affineApplyInsts) {
instShiftMap[op] = 0;
}
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
for (auto &op : forOp.getBody()->without_terminator())
if (instShiftMap.find(&op) == instShiftMap.end())
instShiftMap[&op] = 1;
// Get shifts stored in map.
SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
unsigned s = 0;
for (auto &op : forOp.getBody()->without_terminator()) {
assert(instShiftMap.find(&op) != instShiftMap.end());
shifts[s++] = instShiftMap[&op];
// Tagging operations with shifts for debugging purposes.
LLVM_DEBUG({
OpBuilder b(&op);
op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
});
}
if (!isOpwiseShiftValid(forOp, shifts)) {
// Violates dependences.
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
return;
}
if (failed(affineForOpBodySkew(forOp, shifts))) {
LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
return;
}
}