diff --git a/llvm/include/llvm/Transforms/Utils/EscapeEnumerator.h b/llvm/include/llvm/Transforms/Utils/EscapeEnumerator.h index e667796c841b..bb5c6f04dd0c 100644 --- a/llvm/include/llvm/Transforms/Utils/EscapeEnumerator.h +++ b/llvm/include/llvm/Transforms/Utils/EscapeEnumerator.h @@ -19,6 +19,8 @@ namespace llvm { +class DomTreeUpdater; + /// EscapeEnumerator - This is a little algorithm to find all escape points /// from a function so that "finally"-style code can be inserted. In addition /// to finding the existing return and unwind instructions, it also (if @@ -33,12 +35,14 @@ class EscapeEnumerator { bool Done; bool HandleExceptions; + DomTreeUpdater *DTU; + public: EscapeEnumerator(Function &F, const char *N = "cleanup", - bool HandleExceptions = true) + bool HandleExceptions = true, DomTreeUpdater *DTU = nullptr) : F(F), CleanupBBName(N), StateBB(F.begin()), StateE(F.end()), Builder(F.getContext()), Done(false), - HandleExceptions(HandleExceptions) {} + HandleExceptions(HandleExceptions), DTU(DTU) {} IRBuilder<> *Next(); }; diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h index c712dda483e4..dfcf289a30ec 100644 --- a/llvm/include/llvm/Transforms/Utils/Local.h +++ b/llvm/include/llvm/Transforms/Utils/Local.h @@ -352,7 +352,8 @@ unsigned changeToUnreachable(Instruction *I, bool UseLLVMTrap, /// InvokeInst is a terminator instruction. Returns the newly split basic /// block. BasicBlock *changeToInvokeAndSplitBasicBlock(CallInst *CI, - BasicBlock *UnwindEdge); + BasicBlock *UnwindEdge, + DomTreeUpdater *DTU = nullptr); /// Replace 'BB's terminator with one that does not have an unwind successor /// block. Rewrites `invoke` to `call`, etc. Updates any PHIs in unwind diff --git a/llvm/lib/CodeGen/ShadowStackGCLowering.cpp b/llvm/lib/CodeGen/ShadowStackGCLowering.cpp index 45427dc41e6e..f2111a021ef5 100644 --- a/llvm/lib/CodeGen/ShadowStackGCLowering.cpp +++ b/llvm/lib/CodeGen/ShadowStackGCLowering.cpp @@ -17,11 +17,13 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" @@ -67,6 +69,7 @@ public: ShadowStackGCLowering(); bool doInitialization(Module &M) override; + void getAnalysisUsage(AnalysisUsage &AU) const override; bool runOnFunction(Function &F) override; private: @@ -90,6 +93,7 @@ char ShadowStackGCLowering::ID = 0; INITIALIZE_PASS_BEGIN(ShadowStackGCLowering, DEBUG_TYPE, "Shadow Stack GC Lowering", false, false) INITIALIZE_PASS_DEPENDENCY(GCModuleInfo) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(ShadowStackGCLowering, DEBUG_TYPE, "Shadow Stack GC Lowering", false, false) @@ -280,6 +284,10 @@ GetElementPtrInst *ShadowStackGCLowering::CreateGEP(LLVMContext &Context, return dyn_cast(Val); } +void ShadowStackGCLowering::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addPreserved(); +} + /// runOnFunction - Insert code to maintain the shadow stack. bool ShadowStackGCLowering::runOnFunction(Function &F) { // Quick exit for functions that do not use the shadow stack GC. @@ -297,6 +305,10 @@ bool ShadowStackGCLowering::runOnFunction(Function &F) { if (Roots.empty()) return false; + Optional DTU; + if (auto *DTWP = getAnalysisIfAvailable()) + DTU.emplace(DTWP->getDomTree(), DomTreeUpdater::UpdateStrategy::Lazy); + // Build the constant map and figure the type of the shadow stack entry. Value *FrameMap = GetFrameMap(F); Type *ConcreteStackEntryTy = GetConcreteStackEntryType(F); @@ -348,7 +360,8 @@ bool ShadowStackGCLowering::runOnFunction(Function &F) { AtEntry.CreateStore(NewHeadVal, Head); // For each instruction that escapes... - EscapeEnumerator EE(F, "gc_cleanup"); + EscapeEnumerator EE(F, "gc_cleanup", /*HandleExceptions=*/true, + DTU.hasValue() ? DTU.getPointer() : nullptr); while (IRBuilder<> *AtExit = EE.Next()) { // Pop the entry from the shadow stack. Don't reuse CurrentHead from // AtEntry, since that would make the value live for the entire function. diff --git a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp index accedd5b4ee0..91053338df5f 100644 --- a/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp +++ b/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp @@ -90,7 +90,7 @@ IRBuilder<> *EscapeEnumerator::Next() { SmallVector Args; for (unsigned I = Calls.size(); I != 0;) { CallInst *CI = cast(Calls[--I]); - changeToInvokeAndSplitBasicBlock(CI, CleanupBB); + changeToInvokeAndSplitBasicBlock(CI, CleanupBB, DTU); } Builder.SetInsertPoint(RI); diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 477ea458c763..948bf08e7cfe 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -2122,15 +2122,16 @@ void llvm::changeToCall(InvokeInst *II, DomTreeUpdater *DTU) { } BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, - BasicBlock *UnwindEdge) { + BasicBlock *UnwindEdge, + DomTreeUpdater *DTU) { BasicBlock *BB = CI->getParent(); // Convert this function call into an invoke instruction. First, split the // basic block. - BasicBlock *Split = - BB->splitBasicBlock(CI->getIterator(), CI->getName() + ".noexc"); + BasicBlock *Split = SplitBlock(BB, CI, DTU, /*LI=*/nullptr, /*MSSAU*/ nullptr, + CI->getName() + ".noexc"); - // Delete the unconditional branch inserted by splitBasicBlock + // Delete the unconditional branch inserted by SplitBlock BB->getInstList().pop_back(); // Create the new invoke instruction. @@ -2150,6 +2151,9 @@ BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, II->setCallingConv(CI->getCallingConv()); II->setAttributes(CI->getAttributes()); + if (DTU) + DTU->applyUpdates({{DominatorTree::Insert, BB, UnwindEdge}}); + // Make sure that anything using the call now uses the invoke! This also // updates the CallGraph if present, because it uses a WeakTrackingVH. CI->replaceAllUsesWith(II);