forked from OSchip/llvm-project
DMA placement update - hoist loop-invariant DMAs
- hoist DMAs past all loops immediately surrounding the region that the latter is invariant on - do this at DMA generation time itself PiperOrigin-RevId: 234628447
This commit is contained in:
parent
25016dc4c6
commit
5021dc4fa0
|
@ -291,11 +291,11 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
|
|||
}
|
||||
|
||||
const FlatAffineConstraints *cst = region.getConstraints();
|
||||
// 'outerIVs' holds the values that this memory region is symbolic/parametric
|
||||
// on; this would correspond to loop IVs surrounding the level at which the
|
||||
// DMA generation is being done.
|
||||
SmallVector<Value *, 8> outerIVs;
|
||||
cst->getIdValues(rank, cst->getNumIds(), &outerIVs);
|
||||
// 'regionSymbols' hold values that this memory region is symbolic/parametric
|
||||
// on; these typically include loop IVs surrounding the level at which the DMA
|
||||
// generation is being done or other valid symbols in MLIR.
|
||||
SmallVector<Value *, 8> regionSymbols;
|
||||
cst->getIdValues(rank, cst->getNumIds(), &regionSymbols);
|
||||
|
||||
// Construct the index expressions for the fast memory buffer. The index
|
||||
// expression for a particular dimension of the fast buffer is obtained by
|
||||
|
@ -331,7 +331,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
|
|||
// corresponding dimension on the memory region (stored in 'offset').
|
||||
auto map = top.getAffineMap(
|
||||
cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset, {});
|
||||
memIndices.push_back(b->create<AffineApplyOp>(loc, map, outerIVs));
|
||||
memIndices.push_back(b->create<AffineApplyOp>(loc, map, regionSymbols));
|
||||
}
|
||||
// The fast buffer is DMAed into at location zero; addressing is relative.
|
||||
bufIndices.push_back(zeroIndex);
|
||||
|
@ -377,7 +377,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
|
|||
SmallVector<StrideInfo, 4> strideInfos;
|
||||
getMultiLevelStrides(region, fastBufferShape, &strideInfos);
|
||||
|
||||
// TODO(bondhugula): use all stride level once DmaStartOp is extended for
|
||||
// TODO(bondhugula): use all stride levels once DmaStartOp is extended for
|
||||
// multi-level strides.
|
||||
if (strideInfos.size() > 1) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
|
||||
|
@ -437,13 +437,14 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
|
|||
SmallVector<AffineExpr, 4> remapExprs;
|
||||
remapExprs.reserve(rank);
|
||||
for (unsigned i = 0; i < rank; i++) {
|
||||
// The starting operands of indexRemap will be outerIVs (the loops
|
||||
// surrounding the depth at which this DMA is being done); then those
|
||||
// corresponding to the memref's original indices follow.
|
||||
auto dimExpr = b->getAffineDimExpr(outerIVs.size() + i);
|
||||
// The starting operands of indexRemap will be regionSymbols (the symbols on
|
||||
// which the memref region is parametric); then those corresponding to
|
||||
// the memref's original indices follow.
|
||||
auto dimExpr = b->getAffineDimExpr(regionSymbols.size() + i);
|
||||
remapExprs.push_back(dimExpr - offsets[i]);
|
||||
}
|
||||
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
|
||||
auto indexRemap =
|
||||
b->getAffineMap(regionSymbols.size() + rank, 0, remapExprs, {});
|
||||
|
||||
// Record the begin since it may be invalidated by memref replacement.
|
||||
Block::iterator prev;
|
||||
|
@ -454,7 +455,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
|
|||
// *Only* those uses within the range [begin, end) of 'block' are replaced.
|
||||
replaceAllMemRefUsesWith(memref, fastMemRef,
|
||||
/*extraIndices=*/{}, indexRemap,
|
||||
/*extraOperands=*/outerIVs,
|
||||
/*extraOperands=*/regionSymbols,
|
||||
/*domInstFilter=*/&*begin,
|
||||
/*postDomInstFilter=*/&*postDomFilter);
|
||||
|
||||
|
@ -544,6 +545,44 @@ bool DmaGeneration::runOnBlock(Block *block) {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Given a memref region, determine the lowest depth at which transfers can be
|
||||
/// placed for it, and return the corresponding block, start and end positions
|
||||
/// in the block for placing incoming (read) and outgoing (write) DMAs
|
||||
/// respectively. The lowest depth depends on whether the region being accessed
|
||||
/// is invariant with respect to one or more immediately surrounding loops.
|
||||
static void findHighestBlockForPlacement(
|
||||
const MemRefRegion &region, const Block &block,
|
||||
const Block::iterator &begin, const Block::iterator &end,
|
||||
Block **dmaPlacementBlock, Block::iterator *dmaPlacementReadStart,
|
||||
Block::iterator *dmaPlacementWriteStart) {
|
||||
const auto *cst = region.getConstraints();
|
||||
SmallVector<Value *, 4> symbols;
|
||||
cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);
|
||||
|
||||
SmallVector<OpPointer<AffineForOp>, 4> enclosingFors;
|
||||
getLoopIVs(*block.begin(), &enclosingFors);
|
||||
// Walk up loop parents till we find an IV on which this region is
|
||||
// symbolic/variant.
|
||||
auto it = enclosingFors.rbegin();
|
||||
for (auto e = enclosingFors.rend(); it != e; ++it) {
|
||||
// TODO(bondhugula): also need to be checking this for region symbols that
|
||||
// aren't loop IVs, whether we are within their resp. defs' dominance scope.
|
||||
if (llvm::is_contained(symbols, (*it)->getInductionVar()))
|
||||
break;
|
||||
}
|
||||
|
||||
if (it != enclosingFors.rbegin()) {
|
||||
auto lastInvariantIV = *std::prev(it);
|
||||
*dmaPlacementReadStart = Block::iterator(lastInvariantIV->getInstruction());
|
||||
*dmaPlacementWriteStart = std::next(*dmaPlacementReadStart);
|
||||
*dmaPlacementBlock = lastInvariantIV->getInstruction()->getBlock();
|
||||
} else {
|
||||
*dmaPlacementReadStart = *const_cast<Block::iterator *>(&begin);
|
||||
*dmaPlacementWriteStart = *const_cast<Block::iterator *>(&end);
|
||||
*dmaPlacementBlock = const_cast<Block *>(&block);
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates DMAs for a contiguous sequence of instructions in `block` in the
|
||||
/// iterator range [begin, end). Returns the total size of the DMA buffers used.
|
||||
// Since we generate alloc's and dealloc's for all DMA buffers (before and
|
||||
|
@ -562,6 +601,8 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
|
|||
// surrounding the region of this block.
|
||||
unsigned dmaDepth = getNestingDepth(*begin);
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Generating DMAs at depth " << dmaDepth << "\n");
|
||||
|
||||
readRegions.clear();
|
||||
writeRegions.clear();
|
||||
fastBufferMap.clear();
|
||||
|
@ -663,13 +704,25 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
|
|||
[&](const SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4>
|
||||
&regions) {
|
||||
for (const auto &regionEntry : regions) {
|
||||
// For each region, hoist DMA transfer past all invariant 'for's.
|
||||
Block::iterator dmaPlacementReadStart, dmaPlacementWriteStart;
|
||||
Block *dmaPlacementBlock;
|
||||
findHighestBlockForPlacement(
|
||||
*regionEntry.second, *block, begin, end, &dmaPlacementBlock,
|
||||
&dmaPlacementReadStart, &dmaPlacementWriteStart);
|
||||
|
||||
uint64_t sizeInBytes;
|
||||
Block::iterator nBegin, nEnd;
|
||||
bool iRet = generateDma(*regionEntry.second, block, begin, end,
|
||||
bool iRet = generateDma(*regionEntry.second, dmaPlacementBlock,
|
||||
dmaPlacementReadStart, dmaPlacementWriteStart,
|
||||
&sizeInBytes, &nBegin, &nEnd);
|
||||
if (iRet) {
|
||||
// dmaPlacementStart/End (or begin/end) may be invalidated; use
|
||||
// nBegin, nEnd to reset.
|
||||
if (dmaPlacementBlock == block) {
|
||||
begin = nBegin;
|
||||
end = nEnd;
|
||||
}
|
||||
totalDmaBuffersSizeInBytes += sizeInBytes;
|
||||
}
|
||||
ret = ret & iRet;
|
||||
|
|
|
@ -514,3 +514,55 @@ func @load_store_same_memref(%arg0: memref<256x1024xf32>) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ----
|
||||
|
||||
// This is a 3-d loop nest tiled by 4 x 4 x 4. Under %i, %j, %k, the size of a
|
||||
// tile of arg0, arg1, and arg2 accessed is 4 KB (each), i.e., 12 KB in total.
|
||||
// With fast mem capacity set to 16 KB, the DMAs if placed under %k will fit.
|
||||
// However, the region of arg2 accessed is invariant w.r.t the %k loop unlike
|
||||
// %arg0 and %arg1. So, its DMA can be hoisted one level up and placed under
|
||||
// %j, while the DMAs for arg0 and arg1 appear right under the %k loop.
|
||||
|
||||
#map0 = (d0) -> (d0)
|
||||
#map1 = (d0) -> (d0 + 4)
|
||||
// FAST-MEM-16KB-LABEL: func @simple_matmul
|
||||
func @simple_matmul(%arg0: memref<8x8xvector<64xf32>>, %arg1: memref<8x8xvector<64xf32>>, %arg2: memref<8x8xvector<64xf32>>) -> memref<8x8xvector<64xf32>> {
|
||||
for %i = 0 to 8 step 4 {
|
||||
for %j = 0 to 8 step 4 {
|
||||
for %k = 0 to 8 step 4 {
|
||||
for %ii = #map0(%i) to #map1(%i) {
|
||||
for %jj = #map0(%j) to #map1(%j) {
|
||||
for %kk = #map0(%k) to #map1(%k) {
|
||||
%5 = load %arg0[%ii, %kk] : memref<8x8xvector<64xf32>>
|
||||
%6 = load %arg1[%kk, %jj] : memref<8x8xvector<64xf32>>
|
||||
%7 = load %arg2[%ii, %jj] : memref<8x8xvector<64xf32>>
|
||||
%8 = mulf %5, %6 : vector<64xf32>
|
||||
%9 = addf %7, %8 : vector<64xf32>
|
||||
store %9, %arg2[%ii, %jj] : memref<8x8xvector<64xf32>>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return %arg2 : memref<8x8xvector<64xf32>>
|
||||
}
|
||||
// FAST-MEM-16KB: for %i0 = 0 to 8 step 4 {
|
||||
// FAST-MEM-16KB: for %i1 = 0 to 8 step 4 {
|
||||
// FAST-MEM-16KB: dma_start %arg2
|
||||
// FAST-MEM-16KB: dma_wait
|
||||
// FAST-MEM-16KB: for %i2 = 0 to 8 step 4 {
|
||||
// FAST-MEM-16KB: dma_start %arg0
|
||||
// FAST-MEM-16KB: dma_wait
|
||||
// FAST-MEM-16KB: dma_start %arg1
|
||||
// FAST-MEM-16KB: dma_wait
|
||||
// FAST-MEM-16KB: for %i3 = #map2(%i0) to #map3(%i0) {
|
||||
// FAST-MEM-16KB-NEXT: for %i4 = #map2(%i1) to #map3(%i1) {
|
||||
// FAST-MEM-16KB-NEXT: for %i5 = #map2(%i2) to #map3(%i2) {
|
||||
// FAST-MEM-16KB: }
|
||||
// FAST-MEM-16KB: }
|
||||
// FAST-MEM-16KB: }
|
||||
// FAST-MEM-16KB: }
|
||||
// FAST-MEM-16KB: dma_start %2[%c0, %c0], %arg2
|
||||
// FAST-MEM-16KB: dma_wait
|
||||
|
|
Loading…
Reference in New Issue