2018-09-29 03:17:26 +08:00
|
|
|
//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
//
|
|
|
|
// This file implements a pass to pipeline data transfers.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
|
2018-10-13 05:54:54 +08:00
|
|
|
#include "mlir/Analysis/AffineAnalysis.h"
|
2018-10-19 02:14:26 +08:00
|
|
|
#include "mlir/Analysis/LoopAnalysis.h"
|
|
|
|
#include "mlir/Analysis/Utils.h"
|
2018-10-05 08:15:30 +08:00
|
|
|
#include "mlir/IR/Builders.h"
|
2018-10-19 02:14:26 +08:00
|
|
|
#include "mlir/IR/StmtVisitor.h"
|
2018-10-11 05:23:30 +08:00
|
|
|
#include "mlir/StandardOps/StandardOps.h"
|
2018-09-29 03:17:26 +08:00
|
|
|
#include "mlir/Transforms/LoopUtils.h"
|
|
|
|
#include "mlir/Transforms/Pass.h"
|
2018-10-05 08:15:30 +08:00
|
|
|
#include "mlir/Transforms/Utils.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
2018-10-19 02:14:26 +08:00
|
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
|
|
|
|
#define DEBUG_TYPE "pipeline-data-transfer"
|
2018-09-29 03:17:26 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
2018-10-26 07:58:08 +08:00
|
|
|
struct PipelineDataTransfer : public FunctionPass,
|
2018-10-19 02:14:26 +08:00
|
|
|
StmtWalker<PipelineDataTransfer> {
|
2018-09-29 03:17:26 +08:00
|
|
|
PassResult runOnMLFunction(MLFunction *f) override;
|
2018-10-19 02:14:26 +08:00
|
|
|
PassResult runOnForStmt(ForStmt *forStmt);
|
|
|
|
|
|
|
|
// Collect all 'for' statements.
|
|
|
|
void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
|
|
|
|
std::vector<ForStmt *> forStmts;
|
2018-09-29 03:17:26 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
|
|
|
/// Creates a pass to pipeline explicit movement of data across levels of the
|
|
|
|
/// memory hierarchy.
|
2018-10-26 07:58:08 +08:00
|
|
|
FunctionPass *mlir::createPipelineDataTransferPass() {
|
2018-09-29 03:17:26 +08:00
|
|
|
return new PipelineDataTransfer();
|
|
|
|
}
|
|
|
|
|
2018-10-05 08:15:30 +08:00
|
|
|
// Returns the position of the tag memref operand given a DMA statement.
|
|
|
|
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
|
|
|
|
// added. TODO(b/117228571)
|
2018-10-10 06:04:27 +08:00
|
|
|
static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
|
2018-10-20 00:07:58 +08:00
|
|
|
assert(dmaStmt.isa<DmaStartOp>() || dmaStmt.isa<DmaWaitOp>());
|
|
|
|
if (dmaStmt.isa<DmaStartOp>()) {
|
2018-10-05 08:15:30 +08:00
|
|
|
// Second to last operand.
|
|
|
|
return dmaStmt.getNumOperands() - 2;
|
|
|
|
}
|
|
|
|
// First operand for a dma finish statement.
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
2018-10-10 06:04:27 +08:00
|
|
|
/// Doubles the buffer of the supplied memref while replacing all uses of the
|
|
|
|
/// old memref. Returns false if such a replacement cannot be performed.
|
2018-10-19 02:14:26 +08:00
|
|
|
static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
|
2018-10-05 08:15:30 +08:00
|
|
|
MLFuncBuilder bInner(forStmt, forStmt->begin());
|
|
|
|
bInner.setInsertionPoint(forStmt, forStmt->begin());
|
|
|
|
|
|
|
|
// Doubles the shape with a leading dimension extent of 2.
|
2018-10-19 02:14:26 +08:00
|
|
|
auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
|
2018-10-05 08:15:30 +08:00
|
|
|
// Add the leading dimension in the shape for the double buffer.
|
2018-10-19 02:14:26 +08:00
|
|
|
ArrayRef<int> shape = oldMemRefType->getShape();
|
2018-10-05 08:15:30 +08:00
|
|
|
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
|
|
|
|
shapeSizes.insert(shapeSizes.begin(), 2);
|
|
|
|
|
2018-10-19 02:14:26 +08:00
|
|
|
auto *newMemRefType =
|
|
|
|
bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
|
|
|
|
oldMemRefType->getMemorySpace());
|
2018-10-05 08:15:30 +08:00
|
|
|
return newMemRefType;
|
|
|
|
};
|
|
|
|
|
|
|
|
auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
|
|
|
|
|
|
|
|
// Create and place the alloc at the top level.
|
2018-10-13 05:54:54 +08:00
|
|
|
MLFuncBuilder topBuilder(forStmt->getFunction());
|
2018-10-05 08:15:30 +08:00
|
|
|
auto *newMemRef = cast<MLValue>(
|
|
|
|
topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
|
|
|
|
->getResult());
|
|
|
|
|
2018-10-09 01:20:25 +08:00
|
|
|
auto d0 = bInner.getAffineDimExpr(0);
|
2018-10-10 07:39:24 +08:00
|
|
|
auto modTwoMap =
|
|
|
|
bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {});
|
2018-10-05 08:15:30 +08:00
|
|
|
auto ivModTwoOp =
|
|
|
|
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
|
2018-10-13 05:54:54 +08:00
|
|
|
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
|
2018-10-19 02:14:26 +08:00
|
|
|
cast<MLValue>(ivModTwoOp->getResult(0)))) {
|
|
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
|
|
<< "memref replacement for double buffering failed\n";);
|
2018-10-22 10:53:10 +08:00
|
|
|
ivModTwoOp->getOperation()->erase();
|
2018-10-05 08:15:30 +08:00
|
|
|
return false;
|
2018-10-19 02:14:26 +08:00
|
|
|
}
|
2018-10-05 08:15:30 +08:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2018-10-19 02:14:26 +08:00
|
|
|
/// Returns false if this succeeds on at least one 'for' stmt.
|
2018-09-29 03:17:26 +08:00
|
|
|
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
|
2018-10-19 02:14:26 +08:00
|
|
|
// Do a post order walk so that inner loop DMAs are processed first. This is
|
|
|
|
// necessary since 'for' statements nested within would otherwise become
|
|
|
|
// invalid (erased) when the outer loop is pipelined (the pipelined one gets
|
|
|
|
// deleted and replaced by a prologue, a new steady-state loop and an
|
|
|
|
// epilogue).
|
|
|
|
forStmts.clear();
|
|
|
|
walkPostOrder(f);
|
|
|
|
bool ret = true;
|
|
|
|
for (auto *forStmt : forStmts) {
|
|
|
|
ret = ret & runOnForStmt(forStmt);
|
2018-10-05 08:15:30 +08:00
|
|
|
}
|
2018-10-19 02:14:26 +08:00
|
|
|
return ret ? failure() : success();
|
|
|
|
}
|
2018-10-05 08:15:30 +08:00
|
|
|
|
2018-10-19 02:14:26 +08:00
|
|
|
// Check if tags of the dma start op and dma wait op match.
|
|
|
|
static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
|
|
|
|
OpPointer<DmaWaitOp> waitOp) {
|
|
|
|
if (startOp->getTagMemRef() != waitOp->getTagMemRef())
|
|
|
|
return false;
|
|
|
|
auto startIndices = startOp->getTagIndices();
|
|
|
|
auto waitIndices = waitOp->getTagIndices();
|
|
|
|
// Both of these have the same number of indices since they correspond to the
|
|
|
|
// same tag memref.
|
|
|
|
for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
|
|
|
|
e = startIndices.end();
|
|
|
|
it != e; ++it, ++wIt) {
|
|
|
|
// Keep it simple for now, just checking if indices match.
|
|
|
|
// TODO(mlir-team): this would in general need to check if there is no
|
|
|
|
// intervening write writing to the same tag location, i.e., memory last
|
|
|
|
// write/data flow analysis. This is however sufficient/powerful enough for
|
|
|
|
// now since the DMA generation pass or the input for it will always have
|
|
|
|
// start/wait with matching tags (same SSA operand indices).
|
|
|
|
if (*it != *wIt)
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
2018-09-29 03:17:26 +08:00
|
|
|
|
2018-10-19 02:14:26 +08:00
|
|
|
// Identify matching DMA start/finish statements to overlap computation with.
|
|
|
|
static void findMatchingStartFinishStmts(
|
|
|
|
ForStmt *forStmt,
|
|
|
|
SmallVectorImpl<std::pair<OperationStmt *, OperationStmt *>>
|
|
|
|
&startWaitPairs) {
|
|
|
|
SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
|
2018-10-05 08:15:30 +08:00
|
|
|
for (auto &stmt : *forStmt) {
|
|
|
|
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
|
|
|
|
if (!opStmt)
|
|
|
|
continue;
|
2018-10-19 02:14:26 +08:00
|
|
|
// Collect DMA finish statements.
|
2018-10-20 00:07:58 +08:00
|
|
|
if (opStmt->isa<DmaWaitOp>()) {
|
2018-10-05 08:15:30 +08:00
|
|
|
dmaFinishStmts.push_back(opStmt);
|
2018-10-19 02:14:26 +08:00
|
|
|
continue;
|
|
|
|
}
|
|
|
|
OpPointer<DmaStartOp> dmaStartOp;
|
2018-10-20 00:07:58 +08:00
|
|
|
if (!(dmaStartOp = opStmt->dyn_cast<DmaStartOp>()))
|
2018-10-19 02:14:26 +08:00
|
|
|
continue;
|
|
|
|
// Only DMAs incoming into higher memory spaces.
|
|
|
|
// TODO(bondhugula): outgoing DMAs.
|
|
|
|
if (!dmaStartOp->isDestMemorySpaceFaster())
|
|
|
|
continue;
|
|
|
|
|
|
|
|
// We only double buffer if the buffer is not live out of loop.
|
|
|
|
const MLValue *memref =
|
|
|
|
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
|
|
|
|
bool escapingUses = false;
|
|
|
|
for (const auto &use : memref->getUses()) {
|
|
|
|
if (!dominates(*forStmt, *use.getOwner())) {
|
|
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
|
|
<< "can't pipeline: buffer is live out of loop\n";);
|
|
|
|
escapingUses = true;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (!escapingUses)
|
|
|
|
dmaStartStmts.push_back(opStmt);
|
|
|
|
}
|
|
|
|
|
|
|
|
// For each start statement, we look for a matching finish statement.
|
|
|
|
for (auto *dmaStartStmt : dmaStartStmts) {
|
|
|
|
for (auto *dmaFinishStmt : dmaFinishStmts) {
|
2018-10-20 00:07:58 +08:00
|
|
|
if (checkTagMatch(dmaStartStmt->cast<DmaStartOp>(),
|
|
|
|
dmaFinishStmt->cast<DmaWaitOp>())) {
|
2018-10-19 02:14:26 +08:00
|
|
|
startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt});
|
|
|
|
break;
|
|
|
|
}
|
2018-10-05 08:15:30 +08:00
|
|
|
}
|
|
|
|
}
|
2018-10-19 02:14:26 +08:00
|
|
|
}
|
2018-10-05 08:15:30 +08:00
|
|
|
|
2018-10-19 02:14:26 +08:00
|
|
|
/// Overlap DMA transfers with computation in this loop. If successful,
|
|
|
|
/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are
|
|
|
|
/// inserted right before where it was.
|
|
|
|
PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
|
|
|
|
auto mayBeConstTripCount = getConstantTripCount(*forStmt);
|
|
|
|
if (!mayBeConstTripCount.hasValue()) {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
|
2018-10-23 04:44:31 +08:00
|
|
|
return success();
|
2018-10-19 02:14:26 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs;
|
|
|
|
findMatchingStartFinishStmts(forStmt, startWaitPairs);
|
|
|
|
|
|
|
|
if (startWaitPairs.empty()) {
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
|
2018-10-23 04:44:31 +08:00
|
|
|
return success();
|
2018-10-19 02:14:26 +08:00
|
|
|
}
|
2018-10-05 08:15:30 +08:00
|
|
|
|
|
|
|
// Double the buffers for the higher memory space memref's.
|
2018-10-19 02:14:26 +08:00
|
|
|
// Identify memref's to replace by scanning through all DMA start statements.
|
|
|
|
// A DMA start statement has two memref's - the one from the higher level of
|
|
|
|
// memory hierarchy is the one to double buffer.
|
2018-10-05 08:15:30 +08:00
|
|
|
// TODO(bondhugula): check whether double-buffering is even necessary.
|
|
|
|
// TODO(bondhugula): make this work with different layouts: assuming here that
|
|
|
|
// the dimension we are adding here for the double buffering is the outermost
|
|
|
|
// dimension.
|
2018-10-19 02:14:26 +08:00
|
|
|
for (auto &pair : startWaitPairs) {
|
|
|
|
auto *dmaStartStmt = pair.first;
|
|
|
|
const MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
|
2018-10-20 00:07:58 +08:00
|
|
|
dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos()));
|
2018-10-13 05:54:54 +08:00
|
|
|
if (!doubleBuffer(oldMemRef, forStmt)) {
|
2018-10-19 02:14:26 +08:00
|
|
|
// Normally, double buffering should not fail because we already checked
|
|
|
|
// that there are no uses outside.
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
|
|
|
|
LLVM_DEBUG(dmaStartStmt->dump());
|
2018-10-23 04:44:31 +08:00
|
|
|
// IR still in a valid state.
|
|
|
|
return success();
|
2018-10-13 05:54:54 +08:00
|
|
|
}
|
2018-10-05 08:15:30 +08:00
|
|
|
}
|
|
|
|
|
2018-10-19 02:14:26 +08:00
|
|
|
// Double the buffers for tag memrefs.
|
|
|
|
for (auto &pair : startWaitPairs) {
|
|
|
|
const auto *dmaFinishStmt = pair.second;
|
|
|
|
const MLValue *oldTagMemRef = cast<MLValue>(
|
2018-10-05 08:15:30 +08:00
|
|
|
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
|
2018-10-13 05:54:54 +08:00
|
|
|
if (!doubleBuffer(oldTagMemRef, forStmt)) {
|
2018-10-19 02:14:26 +08:00
|
|
|
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
|
2018-10-23 04:44:31 +08:00
|
|
|
return success();
|
2018-10-13 05:54:54 +08:00
|
|
|
}
|
2018-10-05 08:15:30 +08:00
|
|
|
}
|
|
|
|
|
2018-10-19 02:14:26 +08:00
|
|
|
// Double buffering would have invalidated all the old DMA start/wait stmts.
|
|
|
|
startWaitPairs.clear();
|
|
|
|
findMatchingStartFinishStmts(forStmt, startWaitPairs);
|
|
|
|
|
2018-10-05 08:15:30 +08:00
|
|
|
// Store delay for statement for later lookup for AffineApplyOp's.
|
2018-10-19 02:14:26 +08:00
|
|
|
DenseMap<const Statement *, unsigned> stmtDelayMap;
|
|
|
|
for (auto &pair : startWaitPairs) {
|
|
|
|
auto *dmaStartStmt = pair.first;
|
2018-10-20 00:07:58 +08:00
|
|
|
assert(dmaStartStmt->isa<DmaStartOp>());
|
2018-10-19 02:14:26 +08:00
|
|
|
stmtDelayMap[dmaStartStmt] = 0;
|
|
|
|
// Set shifts for DMA start stmt's affine operand computation slices to 0.
|
|
|
|
if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
|
|
|
|
stmtDelayMap[slice] = 0;
|
2018-10-13 05:54:54 +08:00
|
|
|
} else {
|
2018-10-19 02:14:26 +08:00
|
|
|
// If a slice wasn't created, the reachable affine_apply op's from its
|
|
|
|
// operands are the ones that go with it.
|
|
|
|
SmallVector<OperationStmt *, 4> affineApplyStmts;
|
|
|
|
SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands());
|
|
|
|
getReachableAffineApplyOps(operands, affineApplyStmts);
|
|
|
|
for (const auto *stmt : affineApplyStmts) {
|
|
|
|
stmtDelayMap[stmt] = 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Everything else (including compute ops and dma finish) are shifted by one.
|
|
|
|
for (const auto &stmt : *forStmt) {
|
|
|
|
if (stmtDelayMap.find(&stmt) == stmtDelayMap.end()) {
|
|
|
|
stmtDelayMap[&stmt] = 1;
|
2018-10-05 08:15:30 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get delays stored in map.
|
|
|
|
std::vector<uint64_t> delays(forStmt->getStatements().size());
|
|
|
|
unsigned s = 0;
|
|
|
|
for (const auto &stmt : *forStmt) {
|
2018-10-19 02:14:26 +08:00
|
|
|
assert(stmtDelayMap.find(&stmt) != stmtDelayMap.end());
|
|
|
|
delays[s++] = stmtDelayMap[&stmt];
|
2018-10-05 08:15:30 +08:00
|
|
|
}
|
2018-09-29 03:17:26 +08:00
|
|
|
|
2018-10-19 02:14:26 +08:00
|
|
|
if (!isStmtwiseShiftValid(*forStmt, delays)) {
|
2018-10-23 04:44:31 +08:00
|
|
|
// Violates dependences.
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
|
|
|
|
return success();
|
2018-10-05 08:15:30 +08:00
|
|
|
}
|
2018-09-29 03:17:26 +08:00
|
|
|
|
2018-10-13 05:54:54 +08:00
|
|
|
if (stmtBodySkew(forStmt, delays)) {
|
2018-10-23 04:44:31 +08:00
|
|
|
LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";);
|
|
|
|
return success();
|
2018-10-13 05:54:54 +08:00
|
|
|
}
|
2018-09-29 03:17:26 +08:00
|
|
|
|
2018-10-19 02:14:26 +08:00
|
|
|
return success();
|
2018-09-29 03:17:26 +08:00
|
|
|
}
|