[MLIR] Correct block merge bug

Block merging in MLIR will incorrectly merge blocks with operations whose values are used outside of that block. This change forbids this behavior and provides a test where it is illegal to perform such a merge.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D91745
This commit is contained in:
William S. Moses 2020-11-20 19:05:09 +01:00 committed by Alex Zinenko
parent 88e6208562
commit f5c5fd1c50
2 changed files with 42 additions and 18 deletions

View File

@ -464,10 +464,6 @@ private:
/// A set of operand+index pairs that correspond to operands that need to be
/// replaced by arguments when the cluster gets merged.
std::set<std::pair<int, int>> operandsToMerge;
/// A map of operations with external uses to a replacement within the leader
/// block.
DenseMap<Operation *, Operation *> opsToReplace;
};
} // end anonymous namespace
@ -480,7 +476,6 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
// A set of operands that mismatch between the leader and the new block.
SmallVector<std::pair<int, int>, 8> mismatchedOperands;
SmallVector<std::pair<Operation *, Operation *>, 2> newOpsToReplace;
auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
@ -519,9 +514,16 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
return failure();
}
// If the rhs has external uses, it will need to be replaced.
if (rhsIt->isUsedOutsideOfBlock(mergeBlock))
newOpsToReplace.emplace_back(&*rhsIt, &*lhsIt);
// If the lhs or rhs has external uses, the blocks cannot be merged as the
// merged version of this operation will not be either the lhs or rhs
// alone (thus semantically incorrect), but some mix dependening on which
// block preceeded this.
// TODO allow merging of operations when one block does not dominate the
// other
if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
return failure();
}
}
// Make sure that the block sizes are equivalent.
if (lhsIt != lhsE || rhsIt != rhsE)
@ -529,7 +531,6 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
// If we get here, the blocks are equivalent and can be merged.
operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
opsToReplace.insert(newOpsToReplace.begin(), newOpsToReplace.end());
blocksToMerge.insert(blockData.block);
return success();
}
@ -561,10 +562,6 @@ LogicalResult BlockMergeCluster::merge() {
!llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
return failure();
// Replace any necessary operations.
for (std::pair<Operation *, Operation *> &it : opsToReplace)
it.first->replaceAllUsesWith(it.second);
// Collect the iterators for each of the blocks to merge. We will walk all
// of the iterators at once to avoid operand index invalidation.
SmallVector<Block::iterator, 2> blockIterators;

View File

@ -174,26 +174,24 @@ func @contains_regions(%cond : i1) {
return
}
// Check that properly handles back edges and the case where a value from one
// block is used in another.
// Check that properly handles back edges.
// CHECK-LABEL: func @mismatch_loop(
// CHECK-SAME: %[[ARG:.*]]: i1, %[[ARG2:.*]]: i1
func @mismatch_loop(%cond : i1, %cond2 : i1) {
// CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op"
// CHECK: cond_br %{{.*}}, ^bb1(%[[ARG2]] : i1), ^bb2
%cond3 = "foo.op"() : () -> (i1)
cond_br %cond, ^bb2, ^bb3
^bb1:
// CHECK: ^bb1(%[[ARG3:.*]]: i1):
// CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op"
// CHECK-NEXT: cond_br %[[ARG3]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2
%ignored = "foo.op"() : () -> (i1)
cond_br %cond3, ^bb1, ^bb3
^bb2:
%cond3 = "foo.op"() : () -> (i1)
cond_br %cond2, ^bb1, ^bb3
^bb3:
@ -224,3 +222,32 @@ func @mismatch_operand_types(%arg0 : i1, %arg1 : memref<i32>, %arg2 : memref<i1>
store %true, %arg2[] : memref<i1>
br ^bb1
}
// Check that it is illegal to merge blocks containing an operand
// with an external user. Incorrectly performing the optimization
// anyways will result in print(merged, merged) rather than
// distinct operands.
func private @print(%arg0: i32, %arg1: i32)
// CHECK-LABEL: @nomerge
func @nomerge(%arg0: i32, %i: i32) {
%c1_i32 = constant 1 : i32
%icmp = cmpi "slt", %i, %arg0 : i32
cond_br %icmp, ^bb2, ^bb3
^bb2: // pred: ^bb1
%ip1 = addi %i, %c1_i32 : i32
br ^bb4(%ip1 : i32)
^bb7: // pred: ^bb5
%jp1 = addi %j, %c1_i32 : i32
br ^bb4(%jp1 : i32)
^bb4(%j: i32): // 2 preds: ^bb2, ^bb7
%jcmp = cmpi "slt", %j, %arg0 : i32
// CHECK-NOT: call @print(%[[arg1:.+]], %[[arg1]])
call @print(%j, %ip1) : (i32, i32) -> ()
cond_br %jcmp, ^bb7, ^bb3
^bb3: // pred: ^bb1
return
}