[mlir] Fix bug in copy removal

A crash could happen due to copy removal. The bug is fixed and two more
test cases are added.

Differential Revision: https://reviews.llvm.org/D87128
This commit is contained in:
Ehsan Toosi 2020-08-24 13:19:50 +02:00
parent df63eedef6
commit 4e9f4d0b9d
2 changed files with 93 additions and 8 deletions

View File

@ -30,16 +30,35 @@ public:
reuseCopySourceAsTarget(copyOp);
reuseCopyTargetAsSource(copyOp);
});
for (std::pair<Value, Value> &pair : replaceList)
pair.first.replaceAllUsesWith(pair.second);
for (Operation *op : eraseList)
op->erase();
}
private:
/// List of operations that need to be removed.
DenseSet<Operation *> eraseList;
llvm::SmallPtrSet<Operation *, 4> eraseList;
/// List of values that need to be replaced with their counterparts.
llvm::SmallDenseSet<std::pair<Value, Value>, 4> replaceList;
/// Returns the allocation operation for `value` in `block` if it exists.
/// nullptr otherwise.
Operation *getAllocationOpInBlock(Value value, Block *block) {
assert(block && "Block cannot be null");
Operation *op = value.getDefiningOp();
if (op && op->getBlock() == block) {
auto effects = dyn_cast<MemoryEffectOpInterface>(op);
if (effects && effects.hasEffect<Allocate>())
return op;
}
return nullptr;
}
/// Returns the deallocation operation for `value` in `block` if it exists.
Operation *getDeallocationInBlock(Value value, Block *block) {
/// nullptr otherwise.
Operation *getDeallocationOpInBlock(Value value, Block *block) {
assert(block && "Block cannot be null");
auto valueUsers = value.getUsers();
auto it = llvm::find_if(valueUsers, [&](Operation *op) {
@ -119,9 +138,10 @@ private:
Value to = copyOp.getTarget();
Operation *copy = copyOp.getOperation();
Block *copyBlock = copy->getBlock();
Operation *fromDefiningOp = from.getDefiningOp();
Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
Operation *toDefiningOp = to.getDefiningOp();
Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
Operation *toDefiningOp = getAllocationOpInBlock(to, copyBlock);
if (!fromDefiningOp || !fromFreeingOp || !toDefiningOp ||
!areOpsInTheSameBlock({fromFreeingOp, toDefiningOp, copy}) ||
hasUsersBetween(to, toDefiningOp, copy) ||
@ -129,7 +149,7 @@ private:
hasMemoryEffectOpBetween(copy, fromFreeingOp))
return;
to.replaceAllUsesWith(from);
replaceList.insert({to, from});
eraseList.insert(copy);
eraseList.insert(toDefiningOp);
eraseList.insert(fromFreeingOp);
@ -169,8 +189,9 @@ private:
Value to = copyOp.getTarget();
Operation *copy = copyOp.getOperation();
Operation *fromDefiningOp = from.getDefiningOp();
Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
Block *copyBlock = copy->getBlock();
Operation *fromDefiningOp = getAllocationOpInBlock(from, copyBlock);
Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
if (!fromDefiningOp || !fromFreeingOp ||
!areOpsInTheSameBlock({fromFreeingOp, fromDefiningOp, copy}) ||
hasUsersBetween(to, fromDefiningOp, copy) ||
@ -178,7 +199,7 @@ private:
hasMemoryEffectOpBetween(copy, fromFreeingOp))
return;
from.replaceAllUsesWith(to);
replaceList.insert({from, to});
eraseList.insert(copy);
eraseList.insert(fromDefiningOp);
eraseList.insert(fromFreeingOp);

View File

@ -283,3 +283,67 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
dealloc %temp : memref<2xf32>
return
}
// -----
// The only redundant copy is linalg.copy(%4, %5)
// CHECK-LABEL: func @loop_alloc
func @loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2xf32>, %arg4: memref<2xf32>) {
// CHECK: %{{.*}} = alloc()
%0 = alloc() : memref<2xf32>
dealloc %0 : memref<2xf32>
// CHECK: %{{.*}} = alloc()
%1 = alloc() : memref<2xf32>
// CHECK: linalg.copy
linalg.copy(%arg3, %1) : memref<2xf32>, memref<2xf32>
%2 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %1) -> (memref<2xf32>) {
%3 = cmpi "eq", %arg5, %arg1 : index
// CHECK: dealloc
dealloc %arg6 : memref<2xf32>
// CHECK: %[[PERCENT4:.*]] = alloc()
%4 = alloc() : memref<2xf32>
// CHECK-NOT: alloc
// CHECK-NOT: linalg.copy
// CHECK-NOT: dealloc
%5 = alloc() : memref<2xf32>
linalg.copy(%4, %5) : memref<2xf32>, memref<2xf32>
dealloc %4 : memref<2xf32>
// CHECK: %[[PERCENT6:.*]] = alloc()
%6 = alloc() : memref<2xf32>
// CHECK: linalg.copy(%[[PERCENT4]], %[[PERCENT6]])
linalg.copy(%5, %6) : memref<2xf32>, memref<2xf32>
scf.yield %6 : memref<2xf32>
}
// CHECK: linalg.copy
linalg.copy(%2, %arg4) : memref<2xf32>, memref<2xf32>
dealloc %2 : memref<2xf32>
return
}
// -----
// The linalg.copy operation can be removed in addition to alloc and dealloc
// operations. All uses of %0 is then replaced with %arg2.
// CHECK-LABEL: func @check_with_affine_dialect
func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) {
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>, %[[RES:.*]]: memref<4xf32>)
// CHECK-NOT: alloc
%0 = alloc() : memref<4xf32>
affine.for %arg3 = 0 to 4 {
%5 = affine.load %arg0[%arg3] : memref<4xf32>
%6 = affine.load %arg1[%arg3] : memref<4xf32>
%7 = cmpf "ogt", %5, %6 : f32
// CHECK: %[[SELECT_RES:.*]] = select
%8 = select %7, %5, %6 : f32
// CHECK-NEXT: affine.store %[[SELECT_RES]], %[[RES]]
affine.store %8, %0[%arg3] : memref<4xf32>
}
// CHECK-NOT: linalg.copy
// CHECK-NOT: dealloc
"linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
dealloc %0 : memref<4xf32>
//CHECK: return
return
}