From 89e7a76a1cc77f9b67adb866e3a9e09ec5470790 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Mon, 7 Oct 2019 10:03:38 -0700 Subject: [PATCH] fix simplify-affine-structures bug Signed-off-by: Uday Bondhugula Closes tensorflow/mlir#157 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/157 from bondhugula:quickfix bd1fcd79825fc0bd5b4a3e688153fa0993ab703d PiperOrigin-RevId: 273316498 --- mlir/lib/Transforms/SimplifyAffineStructures.cpp | 13 ++++++++++--- mlir/test/Transforms/memref-normalize.mlir | 16 ++++++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index e243c1bec548..9512ff738aa3 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -102,9 +102,16 @@ void SimplifyAffineStructures::runOnFunction() { } }); - // Turn memrefs' non-identity layouts maps into ones with identity. - func.walk([](AllocOp op) { normalizeMemRef(op); }); + // Turn memrefs' non-identity layouts maps into ones with identity. Collect + // alloc ops first and then process since normalizeMemRef replaces/erases ops + // during memref rewriting. + SmallVector allocOps; + func.walk([&](AllocOp op) { allocOps.push_back(op); }); + for (auto allocOp : allocOps) { + normalizeMemRef(allocOp); + } } static PassRegistration - pass("simplify-affine-structures", "Simplify affine expressions"); + pass("simplify-affine-structures", + "Simplify affine expressions in maps/sets and normalize memrefs"); diff --git a/mlir/test/Transforms/memref-normalize.mlir b/mlir/test/Transforms/memref-normalize.mlir index e9b63624120d..90b363219ee2 100644 --- a/mlir/test/Transforms/memref-normalize.mlir +++ b/mlir/test/Transforms/memref-normalize.mlir @@ -22,10 +22,12 @@ func @permute() { // CHECK-NEXT: dealloc [[MEM]] // CHECK-NEXT: return -// CHECK-LABEL: func @shift() -func @shift() { - // CHECK-NOT: memref<64xf32, (d0) -> (d0 + 1)> +// CHECK-LABEL: func @shift +func @shift(%idx : index) { + // CHECK-NEXT: alloc() : memref<65xf32> %A = alloc() : memref<64xf32, (d0) -> (d0 + 1)> + // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32> + affine.load %A[%idx] : memref<64xf32, (d0) -> (d0 + 1)> affine.for %i = 0 to 64 { affine.load %A[%i] : memref<64xf32, (d0) -> (d0 + 1)> // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32> @@ -59,10 +61,12 @@ func @invalid_map() { } // A tiled layout. -// CHECK-LABEL: func @data_tiling() -func @data_tiling() { +// CHECK-LABEL: func @data_tiling +func @data_tiling(%idx : index) { + // CHECK: alloc() : memref<8x32x8x16xf32> %A = alloc() : memref<64x512xf32, (d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)> - // CHECK: %{{.*}} = alloc() : memref<8x32x8x16xf32> + // CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16] + affine.load %A[%idx, %idx] : memref<64x512xf32, (d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)> return }