[WebAssembly] Make Emscripten EH work with Emscripten SjLj

When Emscripten EH is mixed with Emscripten SjLj, we do not currently
handle every combination correctly. There are three cases:
1. The current function calls `setjmp` and there is an `invoke` to a
   function that can either throw or longjmp. In this case, we have to
   check both for exception and longjmp. We are currently handling this
   case correctly:
   0c0eb76782/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp (L1058-L1090)
   When inserting routines for functions that can longjmp, which we do
   only for setjmp-calling functions, we check if the function was
   previously an `invoke` and handle it correctly.

2. The current function does NOT call `setjmp` and there is an `invoke`
   to a function that can either throw or longjmp. Because there is no
   `setjmp` call, we haven't been doing any check for functions that can
   longjmp. But in that case, for `invoke`, we only check for an
   exception and if it is not an exception we reset `__THREW__` to 0,
   which can silently swallow the longjmp:
   0c0eb76782/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp (L70-L80)
   This CL fixes this.

3. The current function calls `setjmp` and there is no `invoke`. Because
   it is not an `invoke`, we haven't been doing any check for functions
   that can throw, and only insert longjmp-checking routines for
   functions that can longjmp. But in that case, if a longjmpable
   function throws, we only check for a longjmp so if it is not a
   longjmp we reset `__THREW__` to 0, which can silently swallow the
   exception:
   0c0eb76782/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp (L156-L169)
   This CL fixes this.

To do that, this moves around some code, so we register necessary
functions for both EH and SjLj and precompute some data (the set of
functions that contains `setjmp`) before doing actual EH or SjLj
transformation.

This CL makes the 2nd and 3rd tests in
https://github.com/emscripten-core/emscripten/pull/14732 work.

Reviewed By: dschuff

Differential Revision: https://reviews.llvm.org/D106525
This commit is contained in:
Heejin Ahn 2021-07-16 23:37:09 -07:00
parent 41b17c444d
commit c285a11efd
4 changed files with 249 additions and 63 deletions

View File

@ -25,7 +25,8 @@ class ModulePass;
class FunctionPass;
// LLVM IR passes.
ModulePass *createWebAssemblyLowerEmscriptenEHSjLj(bool DoEH, bool DoSjLj);
ModulePass *createWebAssemblyLowerEmscriptenEHSjLj(bool EnableEH,
bool EnableSjLj);
ModulePass *createWebAssemblyLowerGlobalDtors();
ModulePass *createWebAssemblyAddMissingPrototypes();
ModulePass *createWebAssemblyFixFunctionBitcasts();

View File

@ -216,6 +216,7 @@ namespace {
class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
bool EnableEH; // Enable exception handling
bool EnableSjLj; // Enable setjmp/longjmp handling
bool DoSjLj; // Whether we actually perform setjmp/longjmp handling
GlobalVariable *ThrewGV = nullptr;
GlobalVariable *ThrewValueGV = nullptr;
@ -234,6 +235,8 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
StringMap<Function *> InvokeWrappers;
// Set of allowed function names for exception handling
std::set<std::string> EHAllowlistSet;
// Functions that contain calls to setjmp
SmallPtrSet<Function *, 8> SetjmpUsers;
StringRef getPassName() const override {
return "WebAssembly Lower Emscripten Exceptions";
@ -252,6 +255,10 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
bool areAllExceptionsAllowed() const { return EHAllowlistSet.empty(); }
bool canLongjmp(Module &M, const Value *Callee) const;
bool isEmAsmCall(Module &M, const Value *Callee) const;
// Returns true if exception handling should be applied to function F:
// EH must be enabled, and either every function is allowed (the allowlist is
// empty) or F's name appears in the EH allowlist.
bool supportsException(const Function *F) const {
  if (!EnableEH)
    return false;
  if (areAllExceptionsAllowed())
    return true;
  return EHAllowlistSet.count(std::string(F->getName())) != 0;
}
void rebuildSSA(Function &F);
@ -287,7 +294,7 @@ static bool canThrow(const Value *V) {
return false;
StringRef Name = F->getName();
// leave setjmp and longjmp (mostly) alone, we process them properly later
if (Name == "setjmp" || Name == "longjmp")
if (Name == "setjmp" || Name == "longjmp" || Name == "emscripten_longjmp")
return false;
return !F->doesNotThrow();
}
@ -693,7 +700,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
Function *LongjmpF = M.getFunction("longjmp");
bool SetjmpUsed = SetjmpF && !SetjmpF->use_empty();
bool LongjmpUsed = LongjmpF && !LongjmpF->use_empty();
bool DoSjLj = EnableSjLj && (SetjmpUsed || LongjmpUsed);
DoSjLj = EnableSjLj && (SetjmpUsed || LongjmpUsed);
auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
assert(TPC && "Expected a TargetPassConfig");
@ -718,7 +725,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
bool Changed = false;
// Exception handling
// Function registration for exception handling
if (EnableEH) {
// Register __resumeException function
FunctionType *ResumeFTy =
@ -729,26 +736,15 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
FunctionType *EHTypeIDTy =
FunctionType::get(IRB.getInt32Ty(), IRB.getInt8PtrTy(), false);
EHTypeIDF = getEmscriptenFunction(EHTypeIDTy, "llvm_eh_typeid_for", &M);
for (Function &F : M) {
if (F.isDeclaration())
continue;
Changed |= runEHOnFunction(F);
}
}
// Setjmp/longjmp handling
// Function registration and data pre-gathering for setjmp/longjmp handling
if (DoSjLj) {
Changed = true; // We have setjmp or longjmp somewhere
// Register emscripten_longjmp function
FunctionType *FTy = FunctionType::get(
IRB.getVoidTy(), {getAddrIntType(&M), IRB.getInt32Ty()}, false);
EmLongjmpF = getEmscriptenFunction(FTy, "emscripten_longjmp", &M);
if (LongjmpF)
replaceLongjmpWithEmscriptenLongjmp(LongjmpF, EmLongjmpF);
if (SetjmpF) {
// Register saveSetjmp function
FunctionType *SetjmpFTy = SetjmpF->getFunctionType();
@ -765,16 +761,33 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
false);
TestSetjmpF = getEmscriptenFunction(FTy, "testSetjmp", &M);
// Only traverse functions that use setjmp in order not to insert
// unnecessary prep / cleanup code in every function
SmallPtrSet<Function *, 8> SetjmpUsers;
// Precompute setjmp users
for (User *U : SetjmpF->users()) {
auto *UI = cast<Instruction>(U);
SetjmpUsers.insert(UI->getFunction());
}
}
}
// Exception handling transformation
if (EnableEH) {
for (Function &F : M) {
if (F.isDeclaration())
continue;
Changed |= runEHOnFunction(F);
}
}
// Setjmp/longjmp handling transformation
if (DoSjLj) {
Changed = true; // We have setjmp or longjmp somewhere
if (LongjmpF)
replaceLongjmpWithEmscriptenLongjmp(LongjmpF, EmLongjmpF);
// Only traverse functions that use setjmp in order not to insert
// unnecessary prep / cleanup code in every function
if (SetjmpF)
for (Function *F : SetjmpUsers)
runSjLjOnFunction(*F);
}
}
if (!Changed) {
@ -802,8 +815,6 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runEHOnFunction(Function &F) {
bool Changed = false;
SmallVector<Instruction *, 64> ToErase;
SmallPtrSet<LandingPadInst *, 32> LandingPads;
bool AllowExceptions = areAllExceptionsAllowed() ||
EHAllowlistSet.count(std::string(F.getName()));
for (BasicBlock &BB : F) {
auto *II = dyn_cast<InvokeInst>(BB.getTerminator());
@ -813,12 +824,51 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runEHOnFunction(Function &F) {
LandingPads.insert(II->getLandingPadInst());
IRB.SetInsertPoint(II);
bool NeedInvoke = AllowExceptions && canThrow(II->getCalledOperand());
const Value *Callee = II->getCalledOperand();
bool NeedInvoke = supportsException(&F) && canThrow(Callee);
if (NeedInvoke) {
// Wrap invoke with invoke wrapper and generate preamble/postamble
Value *Threw = wrapInvoke(II);
ToErase.push_back(II);
// If setjmp/longjmp handling is enabled, the thrown value may be a longjmp
// rather than an exception. If the current function contains calls to
// setjmp, it will be appropriately handled in runSjLjOnFunction. But even
// if the function does not contain setjmp calls, we shouldn't silently
// ignore longjmps; we should rethrow them so they can be correctly
// handled somewhere up the call chain where setjmp is.
// __THREW__'s value is 0 when nothing happened, 1 when an exception is
// thrown, other values when longjmp is thrown.
//
// if (%__THREW__.val == 0 || %__THREW__.val == 1)
// goto %tail
// else
// goto %longjmp.rethrow
//
// longjmp.rethrow: ;; This is longjmp. Rethrow it
// %__threwValue.val = __threwValue
// emscripten_longjmp(%__THREW__.val, %__threwValue.val);
//
// tail: ;; Nothing happened or an exception is thrown
// ... Continue exception handling ...
if (DoSjLj && !SetjmpUsers.count(&F) && canLongjmp(M, Callee)) {
BasicBlock *Tail = BasicBlock::Create(C, "tail", &F);
BasicBlock *RethrowBB = BasicBlock::Create(C, "longjmp.rethrow", &F);
Value *CmpEqOne =
IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one");
Value *CmpEqZero =
IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 0), "cmp.eq.zero");
Value *Or = IRB.CreateOr(CmpEqZero, CmpEqOne, "or");
IRB.CreateCondBr(Or, Tail, RethrowBB);
IRB.SetInsertPoint(RethrowBB);
Value *ThrewValue = IRB.CreateLoad(IRB.getInt32Ty(), ThrewValueGV,
ThrewValueGV->getName() + ".val");
IRB.CreateCall(EmLongjmpF, {Threw, ThrewValue});
IRB.CreateUnreachable();
IRB.SetInsertPoint(Tail);
}
// Insert a branch based on __THREW__ variable
Value *Cmp = IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp");
IRB.CreateCondBr(Cmp, II->getUnwindDest(), II->getNormalDest());
@ -1098,6 +1148,46 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) {
Threw = wrapInvoke(CI);
ToErase.push_back(CI);
Tail = SplitBlock(BB, CI->getNextNode());
// If exception handling is enabled, the thrown value may be an exception
// rather than a longjmp, in which case we shouldn't silently ignore it;
// we should rethrow it.
// __THREW__'s value is 0 when nothing happened, 1 when an exception is
// thrown, other values when longjmp is thrown.
//
// if (%__THREW__.val == 1)
// goto %eh.rethrow
// else
// goto %normal
//
// eh.rethrow: ;; Rethrow exception
// %exn = call @__cxa_find_matching_catch_2() ;; Retrieve thrown ptr
// __resumeException(%exn)
//
// normal:
// <-- Insertion point. Will insert sjlj handling code from here
// goto %tail
//
// tail:
// ...
if (supportsException(&F) && canThrow(Callee)) {
IRB.SetInsertPoint(CI);
// We will add a new conditional branch. So remove the branch created
// when we split the BB
ToErase.push_back(BB->getTerminator());
BasicBlock *NormalBB = BasicBlock::Create(C, "normal", &F);
BasicBlock *RethrowBB = BasicBlock::Create(C, "eh.rethrow", &F);
Value *CmpEqOne =
IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one");
IRB.CreateCondBr(CmpEqOne, RethrowBB, NormalBB);
IRB.SetInsertPoint(RethrowBB);
CallInst *Exn = IRB.CreateCall(getFindMatchingCatch(M, 0), {}, "exn");
IRB.CreateCall(ResumeF, {Exn});
IRB.CreateUnreachable();
IRB.SetInsertPoint(NormalBB);
IRB.CreateBr(Tail);
BB = NormalBB; // New insertion point to insert testSetjmp()
}
}
// We need to replace the terminator in Tail - SplitBlock makes BB go

View File

@ -0,0 +1,132 @@
; RUN: opt < %s -wasm-lower-em-ehsjlj -S | FileCheck %s
; RUN: llc < %s
; Tests for cases when exception handling and setjmp/longjmp handling are mixed.
target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
target triple = "wasm32-unknown-unknown"
%struct.__jmp_buf_tag = type { [6 x i32], i32, [32 x i32] }
; There is a function call (@foo) that can either throw an exception or longjmp
; and there is also a setjmp call. When @foo throws, we have to check both for
; exception and longjmp and jump to exception or longjmp handling BB depending
; on the result.
define void @setjmp_longjmp_exception() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
; CHECK-LABEL: @setjmp_longjmp_exception
entry:
%buf = alloca [1 x %struct.__jmp_buf_tag], align 16
%arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0
%call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0
invoke void @foo()
to label %try.cont unwind label %lpad
; CHECK: entry.split:
; CHECK: %[[CMP0:.*]] = icmp ne i32 %__THREW__.val, 0
; CHECK-NEXT: %__threwValue.val = load i32, i32* @__threwValue
; CHECK-NEXT: %[[CMP1:.*]] = icmp ne i32 %__threwValue.val, 0
; CHECK-NEXT: %[[CMP:.*]] = and i1 %[[CMP0]], %[[CMP1]]
; CHECK-NEXT: br i1 %[[CMP]], label %if.then1, label %if.else1
; This is the exception checking part. %if.else1 leads here
; CHECK: entry.split.split:
; CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %__THREW__.val, 1
; CHECK-NEXT: br i1 %[[CMP]], label %lpad, label %try.cont
; longjmp checking part
; CHECK: if.then1:
; CHECK: call i32 @testSetjmp
lpad: ; preds = %entry
%0 = landingpad { i8*, i32 }
catch i8* null
%1 = extractvalue { i8*, i32 } %0, 0
%2 = extractvalue { i8*, i32 } %0, 1
%3 = call i8* @__cxa_begin_catch(i8* %1) #2
call void @__cxa_end_catch()
br label %try.cont
try.cont: ; preds = %entry, %lpad
ret void
}
; @foo can either throw an exception or longjmp. Because this function doesn't
; have any setjmp calls, we only handle exceptions in this function. But because
; sjlj is enabled, we check if the thrown value is longjmp and if so rethrow it
; by calling @emscripten_longjmp.
define void @rethrow_longjmp() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
; CHECK-LABEL: @rethrow_longjmp
entry:
invoke void @foo()
to label %try.cont unwind label %lpad
; CHECK: entry:
; CHECK: %cmp.eq.one = icmp eq i32 %__THREW__.val, 1
; CHECK-NEXT: %cmp.eq.zero = icmp eq i32 %__THREW__.val, 0
; CHECK-NEXT: %or = or i1 %cmp.eq.zero, %cmp.eq.one
; CHECK-NEXT: br i1 %or, label %tail, label %longjmp.rethrow
; CHECK: tail:
; CHECK-NEXT: %cmp = icmp eq i32 %__THREW__.val, 1
; CHECK-NEXT: br i1 %cmp, label %lpad, label %try.cont
; CHECK: longjmp.rethrow:
; CHECK-NEXT: %__threwValue.val = load i32, i32* @__threwValue, align 4
; CHECK-NEXT: call void @emscripten_longjmp(i32 %__THREW__.val, i32 %__threwValue.val)
; CHECK-NEXT: unreachable
lpad: ; preds = %entry
%0 = landingpad { i8*, i32 }
catch i8* null
%1 = extractvalue { i8*, i32 } %0, 0
%2 = extractvalue { i8*, i32 } %0, 1
%3 = call i8* @__cxa_begin_catch(i8* %1) #5
call void @__cxa_end_catch()
br label %try.cont
try.cont: ; preds = %entry, %lpad
ret void
}
; This function contains a setjmp call and no invoke, so we only handle longjmp
; here. But @foo can also throw an exception, so we check if an exception is
; thrown and if so rethrow it by calling @__resumeException.
define void @rethrow_exception() {
; CHECK-LABEL: @rethrow_exception
entry:
%buf = alloca [1 x %struct.__jmp_buf_tag], align 16
%arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0
%call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0
%cmp = icmp ne i32 %call, 0
br i1 %cmp, label %return, label %if.end
if.end: ; preds = %entry
call void @foo()
br label %return
; CHECK: if.end:
; CHECK: %cmp.eq.one = icmp eq i32 %__THREW__.val, 1
; CHECK-NEXT: br i1 %cmp.eq.one, label %eh.rethrow, label %normal
; CHECK: normal:
; CHECK-NEXT: icmp ne i32 %__THREW__.val, 0
; CHECK: eh.rethrow:
; CHECK-NEXT: %exn = call i8* @__cxa_find_matching_catch_2()
; CHECK-NEXT: call void @__resumeException(i8* %exn)
; CHECK-NEXT: unreachable
return: ; preds = %entry, %if.end
ret void
}
declare void @foo()
; Function Attrs: returns_twice
declare i32 @setjmp(%struct.__jmp_buf_tag*)
; Function Attrs: noreturn
declare void @longjmp(%struct.__jmp_buf_tag*, i32)
declare i32 @__gxx_personality_v0(...)
declare i8* @__cxa_begin_catch(i8*)
declare void @__cxa_end_catch()
attributes #0 = { returns_twice }
attributes #1 = { noreturn }

View File

@ -100,44 +100,6 @@ entry:
; CHECK-NEXT: ret void
}
; Test a case when a function call is within try-catch, after a setjmp
define void @exception_and_longjmp() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
; CHECK-LABEL: @exception_and_longjmp
entry:
%buf = alloca [1 x %struct.__jmp_buf_tag], align 16
%arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0
%call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0
invoke void @foo()
to label %try.cont unwind label %lpad
; CHECK: entry.split:
; CHECK: store [[PTR]] 0, [[PTR]]* @__THREW__
; CHECK-NEXT: call cc{{.*}} void @__invoke_void(void ()* @foo)
; CHECK-NEXT: %[[__THREW__VAL:.*]] = load [[PTR]], [[PTR]]* @__THREW__
; CHECK-NEXT: store [[PTR]] 0, [[PTR]]* @__THREW__
; CHECK-NEXT: %[[CMP0:.*]] = icmp ne [[PTR]] %[[__THREW__VAL]], 0
; CHECK-NEXT: %[[THREWVALUE_VAL:.*]] = load i32, i32* @__threwValue
; CHECK-NEXT: %[[CMP1:.*]] = icmp ne i32 %[[THREWVALUE_VAL]], 0
; CHECK-NEXT: %[[CMP:.*]] = and i1 %[[CMP0]], %[[CMP1]]
; CHECK-NEXT: br i1 %[[CMP]], label %if.then1, label %if.else1
; CHECK: entry.split.split:
; CHECK-NEXT: %[[CMP:.*]] = icmp eq [[PTR]] %[[__THREW__VAL]], 1
; CHECK-NEXT: br i1 %[[CMP]], label %lpad, label %try.cont
lpad: ; preds = %entry
%0 = landingpad { i8*, i32 }
catch i8* null
%1 = extractvalue { i8*, i32 } %0, 0
%2 = extractvalue { i8*, i32 } %0, 1
%3 = call i8* @__cxa_begin_catch(i8* %1) #2
call void @__cxa_end_catch()
br label %try.cont
try.cont: ; preds = %entry, %lpad
ret void
}
; Test SSA validity
define void @ssa(i32 %n) {
; CHECK-LABEL: @ssa
@ -283,7 +245,8 @@ entry:
ret void
}
declare void @foo()
; Function Attrs: nounwind
declare void @foo() #2
; Function Attrs: returns_twice
declare i32 @setjmp(%struct.__jmp_buf_tag*) #0
; Function Attrs: noreturn
@ -311,7 +274,7 @@ attributes #3 = { allocsize(0) }
; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="__resumeException" }
; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="llvm_eh_typeid_for" }
; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="__invoke_void" }
; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="__cxa_find_matching_catch_3" }
; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="__cxa_find_matching_catch_2" }
; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="saveSetjmp" }
; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="testSetjmp" }
; CHECK-DAG: attributes #{{[0-9]+}} = { "wasm-import-module"="env" "wasm-import-name"="emscripten_longjmp" }