Fix + cleanup for getMemRefRegion()

- determine symbols for the memref region correctly

- this wasn't exposed earlier since we didn't have any test cases where the
  portion of the nest being DMAed for was non-hyperrectangular (i.e., bounds of
  one IV  depending on other IVs within that part)

PiperOrigin-RevId: 233493872
This commit is contained in:
Uday Bondhugula 2019-02-11 15:43:26 -08:00 committed by jpienaar
parent 7897257265
commit f5eed89df0
2 changed files with 43 additions and 17 deletions

View File

@ -228,28 +228,31 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
return false;
}
// Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
// this memref region is symbolic.
SmallVector<OpPointer<AffineForOp>, 4> outerIVs;
getLoopIVs(*inst, &outerIVs);
assert(loopDepth <= outerIVs.size() && "invalid loop depth");
outerIVs.resize(loopDepth);
for (auto *operand : accessValueMap.getOperands()) {
OpPointer<AffineForOp> iv;
if ((iv = getForInductionVarOwner(operand)) &&
llvm::is_contained(outerIVs, iv) == false) {
cst.projectOut(operand);
}
}
// Project out any local variables (these would have been added for any
// mod/divs).
cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds());
// Set all identifiers appearing after the first 'rank' identifiers as
// symbolic identifiers - so that the ones correspoding to the memref
// dimensions are the dimensional identifiers for the memref region.
cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);
// Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
// this memref region is symbolic.
SmallVector<OpPointer<AffineForOp>, 4> enclosingIVs;
getLoopIVs(*inst, &enclosingIVs);
assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
enclosingIVs.resize(loopDepth);
SmallVector<Value *, 4> ids;
cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids);
for (auto *id : ids) {
OpPointer<AffineForOp> iv;
if ((iv = getForInductionVarOwner(id)) &&
llvm::is_contained(enclosingIVs, iv) == false) {
cst.projectOut(id);
}
}
// Project out any local variables (these would have been added for any
// mod/divs).
cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds());
// Constant fold any symbolic identifiers.
cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(),
/*num=*/cst.getNumSymbolIds());

View File

@ -429,3 +429,26 @@ func @dma_mixed_loop_blocks() {
// CHECK-NEXT: %3 = load [[BUF]][%c0_0, %c0_0] : memref<256x256xvector<8xf32>, 1>
// CHECK: for %i1 = 0 to 256 {
// CHECK-NEXT: %4 = load [[BUF]][%i0, %i1] : memref<256x256xvector<8xf32>, 1>
// -----
// CHECK-LABEL: func @relative_loop_bounds
func @relative_loop_bounds(%arg0: memref<1024xf32>) {
for %i0 = 0 to 1024 {
for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 4)(%i0) {
%0 = constant 0.0 : f32
store %0, %arg0[%i2] : memref<1024xf32>
}
}
return
}
// CHECK: [[BUF:%[0-9]+]] = alloc() : memref<1024xf32, 1>
// CHECK-NEXT: [[MEM:%[0-9]+]] = alloc() : memref<1xi32>
// CHECK-NEXT: for %i0 = 0 to 1024 {
// CHECK-NEXT: for %i1 = {{#map[0-9]+}}(%i0) to {{#map[0-9]+}}(%i0) {
// CHECK-NEXT: %cst = constant 0.000000e+00 : f32
// CHECK-NEXT: store %cst, [[BUF]][%i1] : memref<1024xf32, 1>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: dma_start [[BUF]][%c0], %arg0[%c0], %c1024, [[MEM]][%c0] : memref<1024xf32, 1>, memref<1024xf32>, memref<1xi32>
// CHECK-NEXT: dma_wait [[MEM]][%c0], %c1024 : memref<1xi32>