Remove allocations for memrefs that become dead as a result of double
buffering in the auto DMA overlap pass.

This is done online in the pass.

PiperOrigin-RevId: 222313640
Uday Bondhugula 2018-11-20 15:07:37 -08:00 committed by jpienaar
parent 431f08ba7f
commit b6c03917ad
2 changed files with 34 additions and 26 deletions
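
Before the diffs, a minimal standalone sketch of the pattern this commit adds. The toy Stmt and Value types below are hypothetical stand-ins for MLIR's Statement and MLValue; only the member names use_empty(), getDefiningStmt(), and erase() mirror the real calls visible in the diff:

// Toy model (hypothetical types) of the dead-alloc cleanup done online in
// the pass: after doubleBuffer() has rewritten every use of the old memref
// to the new double buffer, an old memref with no remaining uses and a
// defining alloc gets that alloc erased immediately.
#include <cassert>

struct Stmt {
  bool erased = false;
  void erase() { erased = true; } // stand-in for Statement::erase()
};

struct Value {
  Stmt *definingStmt = nullptr; // null for function live-ins
  int numUses = 0;              // uses left after replacement
  bool use_empty() const { return numUses == 0; }
  Stmt *getDefiningStmt() const { return definingStmt; }
};

// Mirrors the two blocks added in runOnForStmt (DMA buffer and tag memref).
void eraseDeadAlloc(Value *oldMemRef) {
  if (oldMemRef->use_empty())
    if (Stmt *allocStmt = oldMemRef->getDefiningStmt())
      allocStmt->erase();
}

int main() {
  Stmt alloc;
  Value memref{&alloc, /*numUses=*/0}; // all uses now point at the new buffer
  eraseDeadAlloc(&memref);
  assert(alloc.erased);

  Value liveIn; // function live-in: no defining alloc, nothing to erase
  eraseDeadAlloc(&liveIn);
  return 0;
}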


@@ -236,7 +236,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
   // dimension.
   for (auto &pair : startWaitPairs) {
     auto *dmaStartStmt = pair.first;
-    const MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
+    MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
         dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos()));
     if (!doubleBuffer(oldMemRef, forStmt)) {
       // Normally, double buffering should not fail because we already checked
@@ -246,17 +246,27 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
       // IR still in a valid state.
       return success();
     }
+    // If the old memref has no more uses, remove its 'dead' alloc if it was
+    // alloc'ed (note: DMA buffers are rarely function live-in).
+    if (oldMemRef->use_empty())
+      if (auto *allocStmt = oldMemRef->getDefiningStmt())
+        allocStmt->erase();
   }

   // Double the buffers for tag memrefs.
   for (auto &pair : startWaitPairs) {
-    const auto *dmaFinishStmt = pair.second;
-    const MLValue *oldTagMemRef = cast<MLValue>(
+    auto *dmaFinishStmt = pair.second;
+    MLValue *oldTagMemRef = cast<MLValue>(
         dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
     if (!doubleBuffer(oldTagMemRef, forStmt)) {
       LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
       return success();
     }
+    // If the old tag has no more uses, remove its 'dead' alloc if it was
+    // alloc'ed.
+    if (oldTagMemRef->use_empty())
+      if (auto *allocStmt = oldTagMemRef->getDefiningStmt())
+        allocStmt->erase();
   }

   // Double buffering would have invalidated all the old DMA start/wait stmts.
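
Note the const qualifiers dropped above on oldMemRef, dmaFinishStmt, and oldTagMemRef: the new cleanup mutates the IR through these pointers (erasing the defining alloc statement), which the const-qualified accessors of the era's const-correct IR presumably would not permit.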


@@ -7,35 +7,33 @@ mlfunc @loop_nest_dma() {
 // CHECK-NEXT: %0 = alloc() : memref<2x1xf32>
 // CHECK-NEXT: %1 = alloc() : memref<2x32xf32, 1>
 // CHECK-NEXT: %2 = alloc() : memref<256xf32>
-// CHECK-NEXT: %3 = alloc() : memref<32xf32, 1>
-// CHECK-NEXT: %4 = alloc() : memref<1xf32>
 // CHECK-NEXT: %c0_0 = constant 0 : index
 // CHECK-NEXT: %c128 = constant 128 : index
-// CHECK-NEXT: %5 = affine_apply #map0(%c0)
-// CHECK-NEXT: dma_start %2[%c0], %1[%5#0, %c0], %c128, %0[%5#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
+// CHECK-NEXT: %3 = affine_apply #map0(%c0)
+// CHECK-NEXT: dma_start %2[%c0], %1[%3#0, %c0], %c128, %0[%3#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
 // CHECK-NEXT: for %i0 = 1 to 8 {
-// CHECK-NEXT: %6 = affine_apply #map0(%i0)
-// CHECK-NEXT: dma_start %2[%i0], %1[%6#0, %i0], %c128, %0[%6#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
-// CHECK-NEXT: %7 = affine_apply #map1(%i0)
-// CHECK-NEXT: %8 = affine_apply #map2(%7)
-// CHECK-NEXT: %9 = affine_apply #map2(%7)
-// CHECK-NEXT: dma_wait %0[%8, %c0_0], %c128 : memref<2x1xf32>
-// CHECK-NEXT: %10 = load %1[%9, %7] : memref<2x32xf32, 1>
-// CHECK-NEXT: %11 = "compute"(%10) : (f32) -> f32
-// CHECK-NEXT: store %11, %1[%9, %7] : memref<2x32xf32, 1>
+// CHECK-NEXT: %4 = affine_apply #map0(%i0)
+// CHECK-NEXT: dma_start %2[%i0], %1[%4#0, %i0], %c128, %0[%4#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
+// CHECK-NEXT: %5 = affine_apply #map1(%i0)
+// CHECK-NEXT: %6 = affine_apply #map2(%5)
+// CHECK-NEXT: %7 = affine_apply #map2(%5)
+// CHECK-NEXT: dma_wait %0[%6, %c0_0], %c128 : memref<2x1xf32>
+// CHECK-NEXT: %8 = load %1[%7, %5] : memref<2x32xf32, 1>
+// CHECK-NEXT: %9 = "compute"(%8) : (f32) -> f32
+// CHECK-NEXT: store %9, %1[%7, %5] : memref<2x32xf32, 1>
 // CHECK-NEXT: for %i1 = 0 to 128 {
-// CHECK-NEXT: "do_more_compute"(%7, %i1) : (index, index) -> ()
+// CHECK-NEXT: "do_more_compute"(%5, %i1) : (index, index) -> ()
 // CHECK-NEXT: }
 // CHECK-NEXT: }
-// CHECK-NEXT: %12 = affine_apply #map1(%c8)
-// CHECK-NEXT: %13 = affine_apply #map2(%12)
-// CHECK-NEXT: %14 = affine_apply #map2(%12)
-// CHECK-NEXT: dma_wait %0[%13, %c0_0], %c128 : memref<2x1xf32>
-// CHECK-NEXT: %15 = load %1[%14, %12] : memref<2x32xf32, 1>
-// CHECK-NEXT: %16 = "compute"(%15) : (f32) -> f32
-// CHECK-NEXT: store %16, %1[%14, %12] : memref<2x32xf32, 1>
+// CHECK-NEXT: %10 = affine_apply #map1(%c8)
+// CHECK-NEXT: %11 = affine_apply #map2(%10)
+// CHECK-NEXT: %12 = affine_apply #map2(%10)
+// CHECK-NEXT: dma_wait %0[%11, %c0_0], %c128 : memref<2x1xf32>
+// CHECK-NEXT: %13 = load %1[%12, %10] : memref<2x32xf32, 1>
+// CHECK-NEXT: %14 = "compute"(%13) : (f32) -> f32
+// CHECK-NEXT: store %14, %1[%12, %10] : memref<2x32xf32, 1>
 // CHECK-NEXT: for %i2 = 0 to 128 {
-// CHECK-NEXT: "do_more_compute"(%12, %i2) : (index, index) -> ()
+// CHECK-NEXT: "do_more_compute"(%10, %i2) : (index, index) -> ()
 // CHECK-NEXT: }
 // CHECK-NEXT: return
@@ -109,7 +107,7 @@ mlfunc @loop_dma_nested(%arg0 : memref<512x32xvector<8xf32>, #map0>, %arg1 : mem
 // CHECK: dma_wait %2[
 // epilogue for DMA overlap on %arg2
-// CHECK: dma_wait %0[%37, %c0_2], %c256 : memref<2x2xi32>
+// CHECK: dma_wait %0[%31, %c0_2], %c256 : memref<2x2xi32>
 // Within the epilogue for arg2's DMA, we have the DMAs on %arg1, %arg2 nested.
 // CHECK: dma_start %arg0[
 // CHECK: dma_start %arg1[
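
The renumbering in the first test hunk is exactly this cleanup at work. In @loop_nest_dma, the original single buffers

  %3 = alloc() : memref<32xf32, 1>  // old fast-memory buffer
  %4 = alloc() : memref<1xf32>      // old DMA tag buffer

lose all their uses once every access is rewritten to the double buffers %1 : memref<2x32xf32, 1> and %0 : memref<2x1xf32>, so the pass erases them on the spot and every later SSA value number in the CHECK lines shifts down by two (%5 becomes %3, ..., %16 becomes %14). The %37-to-%31 shift in @loop_dma_nested is consistent with six erased allocs there, plausibly a buffer and a tag for each of the three DMA'd memrefs.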