diff --git a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h index 5426482ff526..877aa40e1a3a 100644 --- a/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h +++ b/llvm/include/llvm/Transforms/Scalar/MemCpyOptimizer.h @@ -61,8 +61,9 @@ private: bool processMemSet(MemSetInst *SI, BasicBlock::iterator &BBI); bool processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI); bool processMemMove(MemMoveInst *M); - bool performCallSlotOptzn(Instruction *cpy, Value *cpyDst, Value *cpySrc, - uint64_t cpyLen, Align cpyAlign, CallInst *C); + bool performCallSlotOptzn(Instruction *cpyLoad, Instruction *cpyStore, + Value *cpyDst, Value *cpySrc, uint64_t cpyLen, + Align cpyAlign, CallInst *C); bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep); bool processMemSetMemCpyDependence(MemCpyInst *MemCpy, MemSetInst *MemSet); bool performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, MemSetInst *MemSet); diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 68fcf91b3464..49f76d37ec0d 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -659,8 +659,6 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { if (C) { // Check that nothing touches the dest of the "copy" between // the call and the store. - Value *CpyDest = SI->getPointerOperand()->stripPointerCasts(); - bool CpyDestIsLocal = isa(CpyDest); MemoryLocation StoreLoc = MemoryLocation::get(SI); for (BasicBlock::iterator I = --SI->getIterator(), E = C->getIterator(); I != E; --I) { @@ -668,18 +666,12 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { C = nullptr; break; } - // The store to dest may never happen if an exception can be thrown - // between the load and the store. - if (I->mayThrow() && !CpyDestIsLocal) { - C = nullptr; - break; - } } } if (C) { bool changed = performCallSlotOptzn( - LI, SI->getPointerOperand()->stripPointerCasts(), + LI, SI, SI->getPointerOperand()->stripPointerCasts(), LI->getPointerOperand()->stripPointerCasts(), DL.getTypeStoreSize(SI->getOperand(0)->getType()), commonAlignment(SI->getAlign(), LI->getAlign()), C); @@ -754,7 +746,8 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { /// Takes a memcpy and a call that it depends on, /// and checks for the possibility of a call slot optimization by having /// the call write its result directly into the destination of the memcpy. -bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, +bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, + Instruction *cpyStore, Value *cpyDest, Value *cpySrc, uint64_t cpyLen, Align cpyAlign, CallInst *C) { // The general transformation to keep in mind is @@ -785,7 +778,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, if (!srcArraySize) return false; - const DataLayout &DL = cpy->getModule()->getDataLayout(); + const DataLayout &DL = cpyLoad->getModule()->getDataLayout(); uint64_t srcSize = DL.getTypeAllocSize(srcAlloca->getAllocatedType()) * srcArraySize->getZExtValue(); @@ -795,6 +788,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, // Check that accessing the first srcSize bytes of dest will not cause a // trap. Otherwise the transform is invalid since it might cause a trap // to occur earlier than it otherwise would. + // TODO: Use isDereferenceablePointer() API instead. if (AllocaInst *A = dyn_cast(cpyDest)) { // The destination is an alloca. Check it is larger than srcSize. ConstantInt *destArraySize = dyn_cast(A->getArraySize()); @@ -807,10 +801,6 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, if (destSize < srcSize) return false; } else if (Argument *A = dyn_cast(cpyDest)) { - // The store to dest may never happen if the call can throw. - if (C->mayThrow()) - return false; - if (A->getDereferenceableBytes() < srcSize) { // If the destination is an sret parameter then only accesses that are // outside of the returned struct type can trap. @@ -833,6 +823,30 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, return false; } + // Make sure that nothing can observe cpyDest being written early. There are + // a number of cases to consider: + // 1. cpyDest cannot be accessed between C and cpyStore as a precondition of + // the transform. + // 2. C itself may not access cpyDest (prior to the transform). This is + // checked further below. + // 3. If cpyDest is accessible to the caller of this function (potentially + // captured and not based on an alloca), we need to ensure that we cannot + // unwind between C and cpyStore. This is checked here. + // 4. If cpyDest is potentially captured, there may be accesses to it from + // another thread. In this case, we need to check that cpyStore is + // guaranteed to be executed if C is. As it is a non-atomic access, it + // renders accesses from other threads undefined. + // TODO: This is currently not checked. + if (!isa(cpyDest)) { + assert(C->getParent() == cpyStore->getParent() && + "call and copy must be in the same block"); + for (const Instruction &I : make_range(C->getIterator(), + cpyStore->getIterator())) { + if (I.mayThrow()) + return false; + } + } + // Check that dest points to memory that is at least as aligned as src. Align srcAlign = srcAlloca->getAlign(); bool isDestSufficientlyAligned = srcAlign <= cpyAlign; @@ -867,7 +881,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, if (IT->isLifetimeStartOrEnd()) continue; - if (U != C && U != cpy) + if (U != C && U != cpyLoad) return false; } @@ -941,7 +955,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, LLVMContext::MD_noalias, LLVMContext::MD_invariant_group, LLVMContext::MD_access_group}; - combineMetadata(C, cpy, KnownIDs, true); + combineMetadata(C, cpyLoad, KnownIDs, true); ++NumCallSlot; return true; @@ -1242,7 +1256,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { // of conservatively taking the minimum? Align Alignment = std::min(M->getDestAlign().valueOrOne(), M->getSourceAlign().valueOrOne()); - if (performCallSlotOptzn(M, M->getDest(), M->getSource(), + if (performCallSlotOptzn(M, M, M->getDest(), M->getSource(), CopySize->getZExtValue(), Alignment, C)) { eraseInstruction(M); ++NumMemCpyInstr; diff --git a/llvm/test/Transforms/MemCpyOpt/callslot.ll b/llvm/test/Transforms/MemCpyOpt/callslot.ll index a4cfd53f4d24..3b495e5f3fa7 100644 --- a/llvm/test/Transforms/MemCpyOpt/callslot.ll +++ b/llvm/test/Transforms/MemCpyOpt/callslot.ll @@ -91,10 +91,9 @@ define void @throw_between_call_and_mempy(i8* dereferenceable(16) %dest.i8) { ; CHECK-LABEL: @throw_between_call_and_mempy( ; CHECK-NEXT: [[SRC:%.*]] = alloca [16 x i8], align 1 ; CHECK-NEXT: [[SRC_I8:%.*]] = bitcast [16 x i8]* [[SRC]] to i8* -; CHECK-NEXT: [[DEST_I81:%.*]] = bitcast i8* [[DEST_I8:%.*]] to [16 x i8]* -; CHECK-NEXT: [[DEST_I812:%.*]] = bitcast [16 x i8]* [[DEST_I81]] to i8* -; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* [[DEST_I812]], i8 0, i64 16, i1 false) +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* [[SRC_I8]], i8 0, i64 16, i1 false) ; CHECK-NEXT: call void @may_throw() [[ATTR2:#.*]] +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* [[DEST_I8:%.*]], i8 0, i64 16, i1 false) ; CHECK-NEXT: ret void ; %src = alloca [16 x i8]