forked from OSchip/llvm-project
Remove allocations for memrefs that become dead as a result of double
buffering in the auto DMA overlap pass. This is done online in the pass.

PiperOrigin-RevId: 222313640
This commit is contained in:
parent
431f08ba7f
commit
b6c03917ad
|
@ -236,7 +236,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
|
|||
// dimension.
|
||||
for (auto &pair : startWaitPairs) {
|
||||
auto *dmaStartStmt = pair.first;
|
||||
const MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
|
||||
MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
|
||||
dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos()));
|
||||
if (!doubleBuffer(oldMemRef, forStmt)) {
|
||||
// Normally, double buffering should not fail because we already checked
|
||||
|
@ -246,17 +246,27 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
|
|||
// IR still in a valid state.
|
||||
return success();
|
||||
}
|
||||
// If the old memref has no more uses, remove its 'dead' alloc if it was
|
||||
// alloc'ed (note: DMA buffers are rarely function live-in).
|
||||
if (oldMemRef->use_empty())
|
||||
if (auto *allocStmt = oldMemRef->getDefiningStmt())
|
||||
allocStmt->erase();
|
||||
}
|
||||
|
||||
// Double the buffers for tag memrefs.
|
||||
for (auto &pair : startWaitPairs) {
|
||||
const auto *dmaFinishStmt = pair.second;
|
||||
const MLValue *oldTagMemRef = cast<MLValue>(
|
||||
auto *dmaFinishStmt = pair.second;
|
||||
MLValue *oldTagMemRef = cast<MLValue>(
|
||||
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
|
||||
if (!doubleBuffer(oldTagMemRef, forStmt)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
|
||||
return success();
|
||||
}
|
||||
// If the old tag has no more uses, remove its 'dead' alloc if it was
|
||||
// alloc'ed.
|
||||
if (oldTagMemRef->use_empty())
|
||||
if (auto *allocStmt = oldTagMemRef->getDefiningStmt())
|
||||
allocStmt->erase();
|
||||
}
|
||||
|
||||
// Double buffering would have invalidated all the old DMA start/wait stmts.
|
||||
|
|
|
@ -7,35 +7,33 @@ mlfunc @loop_nest_dma() {
|
|||
// CHECK-NEXT: %0 = alloc() : memref<2x1xf32>
|
||||
// CHECK-NEXT: %1 = alloc() : memref<2x32xf32, 1>
|
||||
// CHECK-NEXT: %2 = alloc() : memref<256xf32>
|
||||
// CHECK-NEXT: %3 = alloc() : memref<32xf32, 1>
|
||||
// CHECK-NEXT: %4 = alloc() : memref<1xf32>
|
||||
// CHECK-NEXT: %c0_0 = constant 0 : index
|
||||
// CHECK-NEXT: %c128 = constant 128 : index
|
||||
// CHECK-NEXT: %5 = affine_apply #map0(%c0)
|
||||
// CHECK-NEXT: dma_start %2[%c0], %1[%5#0, %c0], %c128, %0[%5#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
|
||||
// CHECK-NEXT: %3 = affine_apply #map0(%c0)
|
||||
// CHECK-NEXT: dma_start %2[%c0], %1[%3#0, %c0], %c128, %0[%3#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
|
||||
// CHECK-NEXT: for %i0 = 1 to 8 {
|
||||
// CHECK-NEXT: %6 = affine_apply #map0(%i0)
|
||||
// CHECK-NEXT: dma_start %2[%i0], %1[%6#0, %i0], %c128, %0[%6#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
|
||||
// CHECK-NEXT: %7 = affine_apply #map1(%i0)
|
||||
// CHECK-NEXT: %8 = affine_apply #map2(%7)
|
||||
// CHECK-NEXT: %9 = affine_apply #map2(%7)
|
||||
// CHECK-NEXT: dma_wait %0[%8, %c0_0], %c128 : memref<2x1xf32>
|
||||
// CHECK-NEXT: %10 = load %1[%9, %7] : memref<2x32xf32, 1>
|
||||
// CHECK-NEXT: %11 = "compute"(%10) : (f32) -> f32
|
||||
// CHECK-NEXT: store %11, %1[%9, %7] : memref<2x32xf32, 1>
|
||||
// CHECK-NEXT: %4 = affine_apply #map0(%i0)
|
||||
// CHECK-NEXT: dma_start %2[%i0], %1[%4#0, %i0], %c128, %0[%4#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
|
||||
// CHECK-NEXT: %5 = affine_apply #map1(%i0)
|
||||
// CHECK-NEXT: %6 = affine_apply #map2(%5)
|
||||
// CHECK-NEXT: %7 = affine_apply #map2(%5)
|
||||
// CHECK-NEXT: dma_wait %0[%6, %c0_0], %c128 : memref<2x1xf32>
|
||||
// CHECK-NEXT: %8 = load %1[%7, %5] : memref<2x32xf32, 1>
|
||||
// CHECK-NEXT: %9 = "compute"(%8) : (f32) -> f32
|
||||
// CHECK-NEXT: store %9, %1[%7, %5] : memref<2x32xf32, 1>
|
||||
// CHECK-NEXT: for %i1 = 0 to 128 {
|
||||
// CHECK-NEXT: "do_more_compute"(%7, %i1) : (index, index) -> ()
|
||||
// CHECK-NEXT: "do_more_compute"(%5, %i1) : (index, index) -> ()
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %12 = affine_apply #map1(%c8)
|
||||
// CHECK-NEXT: %13 = affine_apply #map2(%12)
|
||||
// CHECK-NEXT: %14 = affine_apply #map2(%12)
|
||||
// CHECK-NEXT: dma_wait %0[%13, %c0_0], %c128 : memref<2x1xf32>
|
||||
// CHECK-NEXT: %15 = load %1[%14, %12] : memref<2x32xf32, 1>
|
||||
// CHECK-NEXT: %16 = "compute"(%15) : (f32) -> f32
|
||||
// CHECK-NEXT: store %16, %1[%14, %12] : memref<2x32xf32, 1>
|
||||
// CHECK-NEXT: %10 = affine_apply #map1(%c8)
|
||||
// CHECK-NEXT: %11 = affine_apply #map2(%10)
|
||||
// CHECK-NEXT: %12 = affine_apply #map2(%10)
|
||||
// CHECK-NEXT: dma_wait %0[%11, %c0_0], %c128 : memref<2x1xf32>
|
||||
// CHECK-NEXT: %13 = load %1[%12, %10] : memref<2x32xf32, 1>
|
||||
// CHECK-NEXT: %14 = "compute"(%13) : (f32) -> f32
|
||||
// CHECK-NEXT: store %14, %1[%12, %10] : memref<2x32xf32, 1>
|
||||
// CHECK-NEXT: for %i2 = 0 to 128 {
|
||||
// CHECK-NEXT: "do_more_compute"(%12, %i2) : (index, index) -> ()
|
||||
// CHECK-NEXT: "do_more_compute"(%10, %i2) : (index, index) -> ()
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return
|
||||
|
||||
|
@ -109,7 +107,7 @@ mlfunc @loop_dma_nested(%arg0 : memref<512x32xvector<8xf32>, #map0>, %arg1 : mem
|
|||
// CHECK: dma_wait %2[
|
||||
|
||||
// epilogue for DMA overlap on %arg2
|
||||
// CHECK: dma_wait %0[%37, %c0_2], %c256 : memref<2x2xi32>
|
||||
// CHECK: dma_wait %0[%31, %c0_2], %c256 : memref<2x2xi32>
|
||||
// Within the epilogue for arg2's DMA, we have the DMAs on %arg1, %arg2 nested.
|
||||
// CHECK: dma_start %arg0[
|
||||
// CHECK: dma_start %arg1[
|
||||
|
|
Loading…
Reference in New Issue