[MemCpyOpt] Check all accesses for MemoryUses in writtenBetween.

Currently, writtenBetween can miss clobbers of Loc between Start and End
if End is a MemoryUse.

To guarantee we see all write clobbers of Loc between Start and End
for MemoryUses, only handle the case where Start and End are in the same
block and check all accesses between them. If they are in different
blocks, conservatively assume Loc is clobbered.

This fixes 2 mis-compiles illustrated in
llvm/test/Transforms/MemCpyOpt/memcpy-byval-forwarding-clobbers.ll

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D119929
This commit is contained in:
Florian Hahn 2022-02-21 16:54:02 +00:00
parent 3a3d9ae545
commit 7662d1687b
No known key found for this signature in database
GPG Key ID: CF59919C6547A668
2 changed files with 22 additions and 8 deletions

View File

@ -352,9 +352,25 @@ static bool accessedBetween(AliasAnalysis &AA, MemoryLocation Loc,
// Check for mod of Loc between Start and End, excluding both boundaries.
// Start and End can be in different blocks.
static bool writtenBetween(MemorySSA *MSSA, MemoryLocation Loc,
const MemoryUseOrDef *Start,
static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA,
MemoryLocation Loc, const MemoryUseOrDef *Start,
const MemoryUseOrDef *End) {
if (isa<MemoryUse>(End)) {
// For MemoryUses, getClobberingMemoryAccess may skip non-clobbering writes.
// Manually check read accesses between Start and End, if they are in the
// same block, for clobbers. Otherwise assume Loc is clobbered.
return Start->getBlock() != End->getBlock() ||
any_of(
make_range(std::next(Start->getIterator()), End->getIterator()),
[&AA, Loc](const MemoryAccess &Acc) {
if (isa<MemoryUse>(&Acc))
return false;
Instruction *AccInst =
cast<MemoryUseOrDef>(&Acc)->getMemoryInst();
return isModSet(AA.getModRefInfo(AccInst, Loc));
});
}
// TODO: Only walk until we hit Start.
MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess(
End->getDefiningAccess(), Loc);
@ -1118,7 +1134,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
// then we could still perform the xform by moving M up to the first memcpy.
// TODO: It would be sufficient to check the MDep source up to the memcpy
// size of M, rather than MDep.
if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep),
if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M)))
return false;
@ -1557,7 +1573,7 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) {
// *b = 42;
// foo(*a)
// It would be invalid to transform the second memcpy into foo(*b).
if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep),
if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep),
MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB)))
return false;

View File

@ -13,7 +13,6 @@ declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noa
; %a.2's lifetime ends before the call to @check. Cannot replace
; %a.1 with %a.2 in the call to @check.
; FIXME: Find lifetime.end, prevent optimization.
define i1 @alloca_forwarding_lifetime_end_clobber() {
; CHECK-LABEL: @alloca_forwarding_lifetime_end_clobber(
; CHECK-NEXT: entry:
@ -26,7 +25,7 @@ define i1 @alloca_forwarding_lifetime_end_clobber() {
; CHECK-NEXT: store i8 0, i8* [[BC_A_2]], align 1
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[BC_A_1]], i8* [[BC_A_2]], i64 8, i1 false)
; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 8, i8* [[BC_A_2]])
; CHECK-NEXT: [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_2]])
; CHECK-NEXT: [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_1]])
; CHECK-NEXT: ret i1 [[CALL]]
;
entry:
@ -46,7 +45,6 @@ entry:
; There is a call clobbering %a.2 before the call to @check. Cannot replace
; %a.1 with %a.2 in the call to @check.
; FIXME: Find clobber, prevent optimization.
define i1 @alloca_forwarding_call_clobber() {
; CHECK-LABEL: @alloca_forwarding_call_clobber(
; CHECK-NEXT: entry:
@ -59,7 +57,7 @@ define i1 @alloca_forwarding_call_clobber() {
; CHECK-NEXT: store i8 0, i8* [[BC_A_2]], align 1
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[BC_A_1]], i8* [[BC_A_2]], i64 8, i1 false)
; CHECK-NEXT: call void @clobber(i8* [[BC_A_2]])
; CHECK-NEXT: [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_2]])
; CHECK-NEXT: [[CALL:%.*]] = call i1 @check(i64* byval(i64) align 8 [[A_1]])
; CHECK-NEXT: ret i1 [[CALL]]
;
entry: