DMA placement update - hoist loops invariant DMAs

- hoist DMAs past all loops immediately surrounding the region that the latter
  is invariant on - do this at DMA generation time itself

PiperOrigin-RevId: 234628447
This commit is contained in:
Uday Bondhugula 2019-02-19 10:33:41 -08:00 committed by jpienaar
parent 25016dc4c6
commit 5021dc4fa0
2 changed files with 121 additions and 16 deletions

View File

@ -291,11 +291,11 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
}
const FlatAffineConstraints *cst = region.getConstraints();
// 'outerIVs' holds the values that this memory region is symbolic/paramteric
// on; this would correspond to loop IVs surrounding the level at which the
// DMA generation is being done.
SmallVector<Value *, 8> outerIVs;
cst->getIdValues(rank, cst->getNumIds(), &outerIVs);
// 'regionSymbols' hold values that this memory region is symbolic/paramteric
// on; these typically include loop IVs surrounding the level at which the DMA
// generation is being done or other valid symbols in MLIR.
SmallVector<Value *, 8> regionSymbols;
cst->getIdValues(rank, cst->getNumIds(), &regionSymbols);
// Construct the index expressions for the fast memory buffer. The index
// expression for a particular dimension of the fast buffer is obtained by
@ -331,7 +331,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
// corresponding dimension on the memory region (stored in 'offset').
auto map = top.getAffineMap(
cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset, {});
memIndices.push_back(b->create<AffineApplyOp>(loc, map, outerIVs));
memIndices.push_back(b->create<AffineApplyOp>(loc, map, regionSymbols));
}
// The fast buffer is DMAed into at location zero; addressing is relative.
bufIndices.push_back(zeroIndex);
@ -377,7 +377,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
SmallVector<StrideInfo, 4> strideInfos;
getMultiLevelStrides(region, fastBufferShape, &strideInfos);
// TODO(bondhugula): use all stride level once DmaStartOp is extended for
// TODO(bondhugula): use all stride levels once DmaStartOp is extended for
// multi-level strides.
if (strideInfos.size() > 1) {
LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
@ -437,13 +437,14 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
SmallVector<AffineExpr, 4> remapExprs;
remapExprs.reserve(rank);
for (unsigned i = 0; i < rank; i++) {
// The starting operands of indexRemap will be outerIVs (the loops
// surrounding the depth at which this DMA is being done); then those
// corresponding to the memref's original indices follow.
auto dimExpr = b->getAffineDimExpr(outerIVs.size() + i);
// The starting operands of indexRemap will be regionSymbols (the symbols on
// which the memref region is parametric); then those corresponding to
// the memref's original indices follow.
auto dimExpr = b->getAffineDimExpr(regionSymbols.size() + i);
remapExprs.push_back(dimExpr - offsets[i]);
}
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
auto indexRemap =
b->getAffineMap(regionSymbols.size() + rank, 0, remapExprs, {});
// Record the begin since it may be invalidated by memref replacement.
Block::iterator prev;
@ -454,7 +455,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
// *Only* those uses within the range [begin, end) of 'block' are replaced.
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
/*extraOperands=*/regionSymbols,
/*domInstFilter=*/&*begin,
/*postDomInstFilter=*/&*postDomFilter);
@ -544,6 +545,44 @@ bool DmaGeneration::runOnBlock(Block *block) {
return true;
}
/// Given a memref region, determine the lowest depth at which transfers can be
/// placed for it, and return the corresponding block, start and end positions
/// in the block for placing incoming (read) and outgoing (write) DMAs
/// respectively. The lowest depth depends on whether the region being accessed
/// is invariant with respect to one or more immediately surrounding loops.
static void findHighestBlockForPlacement(
const MemRefRegion &region, const Block &block,
const Block::iterator &begin, const Block::iterator &end,
Block **dmaPlacementBlock, Block::iterator *dmaPlacementReadStart,
Block::iterator *dmaPlacementWriteStart) {
const auto *cst = region.getConstraints();
SmallVector<Value *, 4> symbols;
cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);
SmallVector<OpPointer<AffineForOp>, 4> enclosingFors;
getLoopIVs(*block.begin(), &enclosingFors);
// Walk up loop parents till we find an IV on which this region is
// symbolic/variant.
auto it = enclosingFors.rbegin();
for (auto e = enclosingFors.rend(); it != e; ++it) {
// TODO(bondhugula): also need to be checking this for regions symbols that
// aren't loop IVs, whether we are within their resp. defs' dominance scope.
if (llvm::is_contained(symbols, (*it)->getInductionVar()))
break;
}
if (it != enclosingFors.rbegin()) {
auto lastInvariantIV = *std::prev(it);
*dmaPlacementReadStart = Block::iterator(lastInvariantIV->getInstruction());
*dmaPlacementWriteStart = std::next(*dmaPlacementReadStart);
*dmaPlacementBlock = lastInvariantIV->getInstruction()->getBlock();
} else {
*dmaPlacementReadStart = *const_cast<Block::iterator *>(&begin);
*dmaPlacementWriteStart = *const_cast<Block::iterator *>(&end);
*dmaPlacementBlock = const_cast<Block *>(&block);
}
}
/// Generates DMAs for a contiguous sequence of instructions in `block` in the
/// iterator range [begin, end). Returns the total size of the DMA buffers used.
// Since we generate alloc's and dealloc's for all DMA buffers (before and
@ -562,6 +601,8 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
// surrounding the region of this block.
unsigned dmaDepth = getNestingDepth(*begin);
LLVM_DEBUG(llvm::dbgs() << "Generating DMAs at depth " << dmaDepth << "\n");
readRegions.clear();
writeRegions.clear();
fastBufferMap.clear();
@ -663,13 +704,25 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
[&](const SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4>
&regions) {
for (const auto &regionEntry : regions) {
// For each region, hoist DMA transfer past all invariant 'for's.
Block::iterator dmaPlacementReadStart, dmaPlacementWriteStart;
Block *dmaPlacementBlock;
findHighestBlockForPlacement(
*regionEntry.second, *block, begin, end, &dmaPlacementBlock,
&dmaPlacementReadStart, &dmaPlacementWriteStart);
uint64_t sizeInBytes;
Block::iterator nBegin, nEnd;
bool iRet = generateDma(*regionEntry.second, block, begin, end,
bool iRet = generateDma(*regionEntry.second, dmaPlacementBlock,
dmaPlacementReadStart, dmaPlacementWriteStart,
&sizeInBytes, &nBegin, &nEnd);
if (iRet) {
begin = nBegin;
end = nEnd;
// dmaPlacmentStart/End (or begin/end) may be invalidated; use
// nBegin, nEnd to reset.
if (dmaPlacementBlock == block) {
begin = nBegin;
end = nEnd;
}
totalDmaBuffersSizeInBytes += sizeInBytes;
}
ret = ret & iRet;

View File

@ -514,3 +514,55 @@ func @load_store_same_memref(%arg0: memref<256x1024xf32>) {
}
return
}
// ----
// This a 3-d loop nest tiled by 4 x 4 x 4. Under %i, %j, %k, the size of a
// tile of arg0, arg1, and arg2 accessed is 4 KB (each), i.e., 12 KB in total.
// With fast mem capacity set to 16 KB, the DMAs if placed under %k will fit.
// However, the region of arg2 accessed is invariant w.r.t the %k loop unlike
// %arg0 and %arg1. So, its DMA can be hoisted one level up and placed under
// %j, while the DMAs for arg0 and arg1 appear right under the %k loop.
#map0 = (d0) -> (d0)
#map1 = (d0) -> (d0 + 4)
// FAST-MEM-16KB-LABEL: func @simple_matmul
func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector<64xf32>>, %arg2: memref<8x8xvector<64xf32>>) -> memref<8x8xvector<64xf32>> {
for %i = 0 to 8 step 4 {
for %j = 0 to 8 step 4 {
for %k = 0 to 8 step 4 {
for %ii = #map0(%i) to #map1(%i) {
for %jj = #map0(%j) to #map1(%j) {
for %kk = #map0(%k) to #map1(%k) {
%5 = load %arg0[%ii, %kk] : memref<8x8xvector<64xf32>>
%6 = load %arg1[%kk, %jj] : memref<8x8xvector<64xf32>>
%7 = load %arg2[%ii, %jj] : memref<8x8xvector<64xf32>>
%8 = mulf %5, %6 : vector<64xf32>
%9 = addf %7, %8 : vector<64xf32>
store %9, %arg2[%ii, %jj] : memref<8x8xvector<64xf32>>
}
}
}
}
}
}
return %arg2 : memref<8x8xvector<64xf32>>
}
// FAST-MEM-16KB: for %i0 = 0 to 8 step 4 {
// FAST-MEM-16KB: for %i1 = 0 to 8 step 4 {
// FAST-MEM-16KB: dma_start %arg2
// FAST-MEM-16KB: dma_wait
// FAST-MEM-16KB: for %i2 = 0 to 8 step 4 {
// FAST-MEM-16KB: dma_start %arg0
// FAST-MEM-16KB: dma_wait
// FAST-MEM-16KB: dma_start %arg1
// FAST-MEM-16KB: dma_wait
// FAST-MEM-16KB: for %i3 = #map2(%i0) to #map3(%i0) {
// FAST-MEM-16KB-NEXT: for %i4 = #map2(%i1) to #map3(%i1) {
// FAST-MEM-16KB-NEXT: for %i5 = #map2(%i2) to #map3(%i2) {
// FAST-MEM-16KB: }
// FAST-MEM-16KB: }
// FAST-MEM-16KB: }
// FAST-MEM-16KB: }
// FAST-MEM-16KB: dma_start %2[%c0, %c0], %arg2
// FAST-MEM-16KB: dma_wait