diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b5c771e9f276..7a5d87172ddd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -311,25 +311,28 @@ struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
 
   LogicalResult matchAndRewrite(AllocaScopeOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) {
-      bool hasPotentialAlloca =
-          op->walk([&](Operation *alloc) {
-              if (alloc == op)
-                return WalkResult::advance();
-              if (isOpItselfPotentialAutomaticAllocation(alloc))
-                return WalkResult::interrupt();
+    bool hasPotentialAlloca =
+        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
+            if (alloc == op)
               return WalkResult::advance();
-            }).wasInterrupted();
-      if (hasPotentialAlloca)
+            if (isOpItselfPotentialAutomaticAllocation(alloc))
+              return WalkResult::interrupt();
+            if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
+              return WalkResult::skip();
+            return WalkResult::advance();
+          }).wasInterrupted();
+
+    // If this contains no potential allocation, it is always legal to
+    // inline. Otherwise, consider two conditions:
+    if (hasPotentialAlloca) {
+      // If the parent isn't an allocation scope, or we are not the last
+      // non-terminator op in the parent, we will extend the lifetime.
+      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
+        return failure();
+      if (!lastNonTerminatorInRegion(op))
         return failure();
     }
 
-    // Only apply to if this is this last non-terminator
-    // op in the block (lest lifetime be extended) of a one
-    // block region
-    if (!lastNonTerminatorInRegion(op))
-      return failure();
-
     Block *block = &op.getRegion().front();
     Operation *terminator = block->getTerminator();
     ValueRange results = terminator->getOperands();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 96fff29db734..0679cbdfaf01 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -644,6 +644,32 @@ func @scopeMerge4() {
 // CHECK:     return
 // CHECK:   }
 
+func @scopeMerge5() {
+  "test.region"() ({
+    memref.alloca_scope {
+      affine.parallel (%arg) = (0) to (64) {
+        %a = memref.alloca(%arg) : memref<?xi64>
+        "test.use"(%a) : (memref<?xi64>) -> ()
+      }
+    }
+    "test.op"() : () -> ()
+    "test.terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// CHECK:   func @scopeMerge5() {
+// CHECK:     "test.region"() ({
+// CHECK:       affine.parallel (%[[cnt:.+]]) = (0) to (64) {
+// CHECK:         %[[alloc:.+]] = memref.alloca(%[[cnt]]) : memref<?xi64>
+// CHECK:         "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
+// CHECK:       }
+// CHECK:       "test.op"() : () -> ()
+// CHECK:       "test.terminator"() : () -> ()
+// CHECK:     }) : () -> ()
+// CHECK:     return
+// CHECK:   }
+
 func @scopeInline(%arg : memref<index>) {
   %cnt = "test.count"() : () -> index
   "test.region"() ({