From 78d191089182842f6360547c48993e44de9e1437 Mon Sep 17 00:00:00 2001 From: Heejin Ahn Date: Tue, 21 Aug 2018 21:23:07 +0000 Subject: [PATCH] [WebAssembly] Restore __stack_pointer after catch instructions Summary: After the stack is unwound due to a thrown exception, the `__stack_pointer` global can point to an invalid address. This inserts instructions that restore `__stack_pointer` global. Reviewers: jgravelle-google, dschuff Subscribers: mgorny, sbc100, sunfish, llvm-commits Differential Revision: https://reviews.llvm.org/D50980 llvm-svn: 340339 --- llvm/lib/Target/WebAssembly/CMakeLists.txt | 1 + llvm/lib/Target/WebAssembly/WebAssembly.h | 2 + .../WebAssemblyEHRestoreStackPointer.cpp | 78 +++++++++++++++++++ .../WebAssembly/WebAssemblyFrameLowering.cpp | 20 +++-- .../WebAssembly/WebAssemblyFrameLowering.h | 10 ++- .../WebAssembly/WebAssemblyTargetMachine.cpp | 4 + llvm/test/CodeGen/WebAssembly/exception.ll | 5 ++ 7 files changed, 114 insertions(+), 6 deletions(-) create mode 100644 llvm/lib/Target/WebAssembly/WebAssemblyEHRestoreStackPointer.cpp diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index a928f110efe0..549229ad572b 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -20,6 +20,7 @@ add_llvm_target(WebAssemblyCodeGen WebAssemblyCFGStackify.cpp WebAssemblyCFGSort.cpp WebAssemblyLateEHPrepare.cpp + WebAssemblyEHRestoreStackPointer.cpp WebAssemblyExceptionInfo.cpp WebAssemblyExplicitLocals.cpp WebAssemblyFastISel.cpp diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.h b/llvm/lib/Target/WebAssembly/WebAssembly.h index 05b7b21fb597..87975cad02a8 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.h +++ b/llvm/lib/Target/WebAssembly/WebAssembly.h @@ -39,6 +39,7 @@ FunctionPass *createWebAssemblyArgumentMove(); FunctionPass *createWebAssemblySetP2AlignOperands(); // Late passes. +FunctionPass *createWebAssemblyEHRestoreStackPointer(); FunctionPass *createWebAssemblyReplacePhysRegs(); FunctionPass *createWebAssemblyPrepareForLiveIntervals(); FunctionPass *createWebAssemblyOptimizeLiveIntervals(); @@ -63,6 +64,7 @@ void initializeFixFunctionBitcastsPass(PassRegistry &); void initializeOptimizeReturnedPass(PassRegistry &); void initializeWebAssemblyArgumentMovePass(PassRegistry &); void initializeWebAssemblySetP2AlignOperandsPass(PassRegistry &); +void initializeWebAssemblyEHRestoreStackPointerPass(PassRegistry &); void initializeWebAssemblyReplacePhysRegsPass(PassRegistry &); void initializeWebAssemblyPrepareForLiveIntervalsPass(PassRegistry &); void initializeWebAssemblyOptimizeLiveIntervalsPass(PassRegistry &); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyEHRestoreStackPointer.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyEHRestoreStackPointer.cpp new file mode 100644 index 000000000000..244fb84d0b36 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblyEHRestoreStackPointer.cpp @@ -0,0 +1,78 @@ +//===-- WebAssemblyEHRestoreStackPointer.cpp - __stack_pointer restoration ===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// After the stack is unwound due to a thrown exception, the __stack_pointer +/// global can point to an invalid address. This inserts instructions that +/// restore __stack_pointer global. +/// +//===----------------------------------------------------------------------===// + +#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "WebAssembly.h" +#include "WebAssemblySubtarget.h" +#include "WebAssemblyUtilities.h" +#include "llvm/CodeGen/MachineFrameInfo.h" +#include "llvm/MC/MCAsmInfo.h" +using namespace llvm; + +#define DEBUG_TYPE "wasm-eh-restore-stack-pointer" + +namespace { +class WebAssemblyEHRestoreStackPointer final : public MachineFunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + WebAssemblyEHRestoreStackPointer() : MachineFunctionPass(ID) {} + + StringRef getPassName() const override { + return "WebAssembly Restore Stack Pointer for Exception Handling"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + MachineFunctionPass::getAnalysisUsage(AU); + } + + bool runOnMachineFunction(MachineFunction &MF) override; +}; +} // end anonymous namespace + +char WebAssemblyEHRestoreStackPointer::ID = 0; +INITIALIZE_PASS(WebAssemblyEHRestoreStackPointer, DEBUG_TYPE, + "Restore Stack Pointer for Exception Handling", true, false) + +FunctionPass *llvm::createWebAssemblyEHRestoreStackPointer() { + return new WebAssemblyEHRestoreStackPointer(); +} + +bool WebAssemblyEHRestoreStackPointer::runOnMachineFunction( + MachineFunction &MF) { + const auto *FrameLowering = static_cast( + MF.getSubtarget().getFrameLowering()); + if (!FrameLowering->needsPrologForEH(MF)) + return false; + bool Changed = false; + + for (auto &MBB : MF) { + if (!MBB.isEHPad()) + continue; + Changed = true; + + // Insert __stack_pointer restoring instructions at the beginning of each EH + // pad, after the catch instruction. (Catch instructions may have been + // reordered, and catch_all instructions have not been inserted yet, but + // those cases are handled in LateEHPrepare). + auto InsertPos = MBB.begin(); + if (WebAssembly::isCatch(*MBB.begin())) + InsertPos++; + FrameLowering->writeSPToGlobal(WebAssembly::SP32, MF, MBB, InsertPos, + MBB.begin()->getDebugLoc()); + } + return Changed; +} diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.cpp index ace2f0ecc20e..9e33ed75f93a 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.cpp @@ -30,6 +30,7 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineModuleInfoImpls.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/MC/MCAsmInfo.h" #include "llvm/Support/Debug.h" using namespace llvm; @@ -78,13 +79,23 @@ bool WebAssemblyFrameLowering::hasReservedCallFrame( return !MF.getFrameInfo().hasVarSizedObjects(); } +// In function with EH pads, we need to make a copy of the value of +// __stack_pointer global in SP32 register, in order to use it when restoring +// __stack_pointer after an exception is caught. +bool WebAssemblyFrameLowering::needsPrologForEH( + const MachineFunction &MF) const { + auto EHType = MF.getTarget().getMCAsmInfo()->getExceptionHandlingType(); + return EHType == ExceptionHandling::Wasm && + MF.getFunction().hasPersonalityFn() && MF.getFrameInfo().hasCalls(); +} /// Returns true if this function needs a local user-space stack pointer. /// Unlike a machine stack pointer, the wasm user stack pointer is a global /// variable, so it is loaded into a register in the prolog. bool WebAssemblyFrameLowering::needsSP(const MachineFunction &MF, const MachineFrameInfo &MFI) const { - return MFI.getStackSize() || MFI.adjustsStack() || hasFP(MF); + return MFI.getStackSize() || MFI.adjustsStack() || hasFP(MF) || + needsPrologForEH(MF); } /// Returns true if the local user-space stack pointer needs to be written back @@ -97,10 +108,9 @@ bool WebAssemblyFrameLowering::needsSPWriteback( MF.getFunction().hasFnAttribute(Attribute::NoRedZone); } -static void writeSPToGlobal(unsigned SrcReg, MachineFunction &MF, - MachineBasicBlock &MBB, - MachineBasicBlock::iterator &InsertStore, - const DebugLoc &DL) { +void WebAssemblyFrameLowering::writeSPToGlobal( + unsigned SrcReg, MachineFunction &MF, MachineBasicBlock &MBB, + MachineBasicBlock::iterator &InsertStore, const DebugLoc &DL) const { const auto *TII = MF.getSubtarget().getInstrInfo(); const char *ES = "__stack_pointer"; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.h index fe23e418a3f1..e888aaf3aef0 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.h @@ -45,7 +45,15 @@ class WebAssemblyFrameLowering final : public TargetFrameLowering { bool hasFP(const MachineFunction &MF) const override; bool hasReservedCallFrame(const MachineFunction &MF) const override; - private: + bool needsPrologForEH(const MachineFunction &MF) const; + + /// Write SP back to __stack_pointer global. + void writeSPToGlobal(unsigned SrcReg, MachineFunction &MF, + MachineBasicBlock &MBB, + MachineBasicBlock::iterator &InsertStore, + const DebugLoc &DL) const; + +private: bool hasBP(const MachineFunction &MF) const; bool needsSP(const MachineFunction &MF, const MachineFrameInfo &MFI) const; bool needsSPWriteback(const MachineFunction &MF, diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp index 7c10f022cbbc..1cc688599a99 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp @@ -58,6 +58,7 @@ extern "C" void LLVMInitializeWebAssemblyTarget() { initializeOptimizeReturnedPass(PR); initializeWebAssemblyArgumentMovePass(PR); initializeWebAssemblySetP2AlignOperandsPass(PR); + initializeWebAssemblyEHRestoreStackPointerPass(PR); initializeWebAssemblyReplacePhysRegsPass(PR); initializeWebAssemblyPrepareForLiveIntervalsPass(PR); initializeWebAssemblyOptimizeLiveIntervalsPass(PR); @@ -280,6 +281,9 @@ void WebAssemblyPassConfig::addPostRegAlloc() { void WebAssemblyPassConfig::addPreEmitPass() { TargetPassConfig::addPreEmitPass(); + // Restore __stack_pointer global after an exception is thrown. + addPass(createWebAssemblyEHRestoreStackPointer()); + // Now that we have a prologue and epilogue and all frame indices are // rewritten, eliminate SP and FP. This allows them to be stackified, // colored, and numbered with the rest of the registers. diff --git a/llvm/test/CodeGen/WebAssembly/exception.ll b/llvm/test/CodeGen/WebAssembly/exception.ll index 2519ffcc78c0..a0256d238915 100644 --- a/llvm/test/CodeGen/WebAssembly/exception.ll +++ b/llvm/test/CodeGen/WebAssembly/exception.ll @@ -19,9 +19,11 @@ define void @test_throw() { } ; CHECK-LABEL: test_catch_rethrow: +; CHECK: get_global $push{{.+}}=, __stack_pointer@GLOBAL ; CHECK: try ; CHECK: call foo@FUNCTION ; CHECK: i32.catch $push{{.+}}=, 0 +; CHECK: set_global __stack_pointer@GLOBAL ; CHECK-DAG: i32.store __wasm_lpad_context ; CHECK-DAG: i32.store __wasm_lpad_context+4 ; CHECK: i32.call $push{{.+}}=, _Unwind_CallPersonality@FUNCTION @@ -63,6 +65,7 @@ try.cont: ; preds = %entry, %catch ; CHECK: try ; CHECK: call foo@FUNCTION ; CHECK: catch_all +; CHECK: set_global __stack_pointer@GLOBAL ; CHECK: i32.call $push{{.+}}=, _ZN7CleanupD1Ev@FUNCTION ; CHECK: rethrow ; CHECK: end_try @@ -161,10 +164,12 @@ terminate10: ; preds = %ehcleanup7 ; CHECK: call foo@FUNCTION ; CHECK: i32.catch ; CHECK-NOT: get_global $push{{.+}}=, __stack_pointer@GLOBAL +; CHECK: set_global __stack_pointer@GLOBAL ; CHECK: try ; CHECK: call foo@FUNCTION ; CHECK: catch_all ; CHECK-NOT: get_global $push{{.+}}=, __stack_pointer@GLOBAL +; CHECK: set_global __stack_pointer@GLOBAL ; CHECK: call __cxa_end_catch@FUNCTION ; CHECK-NOT: set_global __stack_pointer@GLOBAL, $pop{{.+}} ; CHECK: end_try