diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 82489acfe417..821f19df31d9 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -228,28 +228,31 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, return false; } - // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which - // this memref region is symbolic. - SmallVector, 4> outerIVs; - getLoopIVs(*inst, &outerIVs); - assert(loopDepth <= outerIVs.size() && "invalid loop depth"); - outerIVs.resize(loopDepth); - for (auto *operand : accessValueMap.getOperands()) { - OpPointer iv; - if ((iv = getForInductionVarOwner(operand)) && - llvm::is_contained(outerIVs, iv) == false) { - cst.projectOut(operand); - } - } - // Project out any local variables (these would have been added for any - // mod/divs). - cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds()); - // Set all identifiers appearing after the first 'rank' identifiers as // symbolic identifiers - so that the ones correspoding to the memref // dimensions are the dimensional identifiers for the memref region. cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank); + // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which + // this memref region is symbolic. + SmallVector, 4> enclosingIVs; + getLoopIVs(*inst, &enclosingIVs); + assert(loopDepth <= enclosingIVs.size() && "invalid loop depth"); + enclosingIVs.resize(loopDepth); + SmallVector ids; + cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids); + for (auto *id : ids) { + OpPointer iv; + if ((iv = getForInductionVarOwner(id)) && + llvm::is_contained(enclosingIVs, iv) == false) { + cst.projectOut(id); + } + } + + // Project out any local variables (these would have been added for any + // mod/divs). + cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds()); + // Constant fold any symbolic identifiers. cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(), /*num=*/cst.getNumSymbolIds()); diff --git a/mlir/test/Transforms/dma-generate.mlir b/mlir/test/Transforms/dma-generate.mlir index a954bdb96a13..bf6062e78e86 100644 --- a/mlir/test/Transforms/dma-generate.mlir +++ b/mlir/test/Transforms/dma-generate.mlir @@ -429,3 +429,26 @@ func @dma_mixed_loop_blocks() { // CHECK-NEXT: %3 = load [[BUF]][%c0_0, %c0_0] : memref<256x256xvector<8xf32>, 1> // CHECK: for %i1 = 0 to 256 { // CHECK-NEXT: %4 = load [[BUF]][%i0, %i1] : memref<256x256xvector<8xf32>, 1> + +// ----- + +// CHECK-LABEL: func @relative_loop_bounds +func @relative_loop_bounds(%arg0: memref<1024xf32>) { + for %i0 = 0 to 1024 { + for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0 + 4)(%i0) { + %0 = constant 0.0 : f32 + store %0, %arg0[%i2] : memref<1024xf32> + } + } + return +} +// CHECK: [[BUF:%[0-9]+]] = alloc() : memref<1024xf32, 1> +// CHECK-NEXT: [[MEM:%[0-9]+]] = alloc() : memref<1xi32> +// CHECK-NEXT: for %i0 = 0 to 1024 { +// CHECK-NEXT: for %i1 = {{#map[0-9]+}}(%i0) to {{#map[0-9]+}}(%i0) { +// CHECK-NEXT: %cst = constant 0.000000e+00 : f32 +// CHECK-NEXT: store %cst, [[BUF]][%i1] : memref<1024xf32, 1> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: dma_start [[BUF]][%c0], %arg0[%c0], %c1024, [[MEM]][%c0] : memref<1024xf32, 1>, memref<1024xf32>, memref<1xi32> +// CHECK-NEXT: dma_wait [[MEM]][%c0], %c1024 : memref<1xi32>