fix simplify-affine-structures bug

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#157

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/157 from bondhugula:quickfix bd1fcd79825fc0bd5b4a3e688153fa0993ab703d
PiperOrigin-RevId: 273316498
This commit is contained in:
Uday Bondhugula 2019-10-07 10:03:38 -07:00 committed by A. Unique TensorFlower
parent 9f11b0e12f
commit 89e7a76a1c
2 changed files with 20 additions and 9 deletions

View File

@ -102,9 +102,16 @@ void SimplifyAffineStructures::runOnFunction() {
}
});
// Turn memrefs' non-identity layouts maps into ones with identity.
func.walk([](AllocOp op) { normalizeMemRef(op); });
// Turn memrefs' non-identity layouts maps into ones with identity. Collect
// alloc ops first and then process since normalizeMemRef replaces/erases ops
// during memref rewriting.
SmallVector<AllocOp, 4> allocOps;
func.walk([&](AllocOp op) { allocOps.push_back(op); });
for (auto allocOp : allocOps) {
normalizeMemRef(allocOp);
}
}
static PassRegistration<SimplifyAffineStructures>
pass("simplify-affine-structures", "Simplify affine expressions");
pass("simplify-affine-structures",
"Simplify affine expressions in maps/sets and normalize memrefs");

View File

@ -22,10 +22,12 @@ func @permute() {
// CHECK-NEXT: dealloc [[MEM]]
// CHECK-NEXT: return
// CHECK-LABEL: func @shift()
func @shift() {
// CHECK-NOT: memref<64xf32, (d0) -> (d0 + 1)>
// CHECK-LABEL: func @shift
func @shift(%idx : index) {
// CHECK-NEXT: alloc() : memref<65xf32>
%A = alloc() : memref<64xf32, (d0) -> (d0 + 1)>
// CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
affine.load %A[%idx] : memref<64xf32, (d0) -> (d0 + 1)>
affine.for %i = 0 to 64 {
affine.load %A[%i] : memref<64xf32, (d0) -> (d0 + 1)>
// CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
@ -59,10 +61,12 @@ func @invalid_map() {
}
// A tiled layout.
// CHECK-LABEL: func @data_tiling()
func @data_tiling() {
// CHECK-LABEL: func @data_tiling
func @data_tiling(%idx : index) {
// CHECK: alloc() : memref<8x32x8x16xf32>
%A = alloc() : memref<64x512xf32, (d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>
// CHECK: %{{.*}} = alloc() : memref<8x32x8x16xf32>
// CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16]
affine.load %A[%idx, %idx] : memref<64x512xf32, (d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>
return
}