diff --git a/llvm/include/llvm/IR/Statepoint.h b/llvm/include/llvm/IR/Statepoint.h index 4e6fb4171a38..6cbffc39e706 100644 --- a/llvm/include/llvm/IR/Statepoint.h +++ b/llvm/include/llvm/IR/Statepoint.h @@ -54,8 +54,6 @@ bool isGCResult(const ImmutableCallSite &CS); /// concrete subtypes. This is structured analogous to CallSite /// rather than the IntrinsicInst.h helpers since we want to support /// invokable statepoints in the near future. -/// TODO: This does not currently allow the if(Statepoint S = ...) -/// idiom used with CallSites. Consider refactoring to support. template class StatepointBase { CallSiteTy StatepointCS; @@ -63,11 +61,15 @@ class StatepointBase { void *operator new(size_t s) = delete; protected: - explicit StatepointBase(InstructionTy *I) : StatepointCS(I) { - assert(isStatepoint(I)); + explicit StatepointBase(InstructionTy *I) { + if (isStatepoint(I)) { + StatepointCS = CallSiteTy(I); + assert(StatepointCS && "isStatepoint implies CallSite"); + } } - explicit StatepointBase(CallSiteTy CS) : StatepointCS(CS) { - assert(isStatepoint(CS)); + explicit StatepointBase(CallSiteTy CS) { + if (isStatepoint(CS)) + StatepointCS = CS; } public: @@ -82,23 +84,31 @@ public: CallArgsBeginPos = 5, }; + operator bool() const { + // We do not assign non-statepoint CallSites to StatepointCS. + return (bool)StatepointCS; + } + /// Return the underlying CallSite. - CallSiteTy getCallSite() { return StatepointCS; } + CallSiteTy getCallSite() const { + assert(*this && "check validity first!"); + return StatepointCS; + } uint64_t getFlags() const { - return cast(StatepointCS.getArgument(FlagsPos)) + return cast(getCallSite().getArgument(FlagsPos)) ->getZExtValue(); } /// Return the ID associated with this statepoint. uint64_t getID() { - const Value *IDVal = StatepointCS.getArgument(IDPos); + const Value *IDVal = getCallSite().getArgument(IDPos); return cast(IDVal)->getZExtValue(); } /// Return the number of patchable bytes associated with this statepoint. uint32_t getNumPatchBytes() { - const Value *NumPatchBytesVal = StatepointCS.getArgument(NumPatchBytesPos); + const Value *NumPatchBytesVal = getCallSite().getArgument(NumPatchBytesPos); uint64_t NumPatchBytes = cast(NumPatchBytesVal)->getZExtValue(); assert(isInt<32>(NumPatchBytes) && "should fit in 32 bits!"); @@ -107,7 +117,7 @@ public: /// Return the value actually being called or invoked. ValueTy *getActualCallee() { - return StatepointCS.getArgument(ActualCalleePos); + return getCallSite().getArgument(ActualCalleePos); } /// Return the type of the value returned by the call underlying the @@ -120,17 +130,17 @@ public: /// Number of arguments to be passed to the actual callee. int getNumCallArgs() { - const Value *NumCallArgsVal = StatepointCS.getArgument(NumCallArgsPos); + const Value *NumCallArgsVal = getCallSite().getArgument(NumCallArgsPos); return cast(NumCallArgsVal)->getZExtValue(); } typename CallSiteTy::arg_iterator call_args_begin() { - assert(CallArgsBeginPos <= (int)StatepointCS.arg_size()); - return StatepointCS.arg_begin() + CallArgsBeginPos; + assert(CallArgsBeginPos <= (int)getCallSite().arg_size()); + return getCallSite().arg_begin() + CallArgsBeginPos; } typename CallSiteTy::arg_iterator call_args_end() { auto I = call_args_begin() + getNumCallArgs(); - assert((StatepointCS.arg_end() - I) >= 0); + assert((getCallSite().arg_end() - I) >= 0); return I; } @@ -146,12 +156,12 @@ public: } typename CallSiteTy::arg_iterator gc_transition_args_begin() { auto I = call_args_end() + 1; - assert((StatepointCS.arg_end() - I) >= 0); + assert((getCallSite().arg_end() - I) >= 0); return I; } typename CallSiteTy::arg_iterator gc_transition_args_end() { auto I = gc_transition_args_begin() + getNumTotalGCTransitionArgs(); - assert((StatepointCS.arg_end() - I) >= 0); + assert((getCallSite().arg_end() - I) >= 0); return I; } @@ -170,12 +180,12 @@ public: typename CallSiteTy::arg_iterator vm_state_begin() { auto I = gc_transition_args_end() + 1; - assert((StatepointCS.arg_end() - I) >= 0); + assert((getCallSite().arg_end() - I) >= 0); return I; } typename CallSiteTy::arg_iterator vm_state_end() { auto I = vm_state_begin() + getNumTotalVMSArgs(); - assert((StatepointCS.arg_end() - I) >= 0); + assert((getCallSite().arg_end() - I) >= 0); return I; } @@ -186,7 +196,7 @@ public: typename CallSiteTy::arg_iterator gc_args_begin() { return vm_state_end(); } typename CallSiteTy::arg_iterator gc_args_end() { - return StatepointCS.arg_end(); + return getCallSite().arg_end(); } /// range adapter for gc arguments