diff --git a/clang/lib/CodeGen/CGCXXClass.cpp b/clang/lib/CodeGen/CGCXXClass.cpp index 1fb8f8308c9e..b2ff2327afbe 100644 --- a/clang/lib/CodeGen/CGCXXClass.cpp +++ b/clang/lib/CodeGen/CGCXXClass.cpp @@ -81,6 +81,22 @@ CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue, uint64_t Offset = ComputeBaseClassOffset(getContext(), ClassDecl, BaseClassDecl); + llvm::BasicBlock *CastNull = 0; + llvm::BasicBlock *CastNotNull = 0; + llvm::BasicBlock *CastEnd = 0; + + if (NullCheckValue) { + CastNull = createBasicBlock("cast.null"); + CastNotNull = createBasicBlock("cast.notnull"); + CastEnd = createBasicBlock("cast.end"); + + llvm::Value *IsNull = + Builder.CreateICmpEQ(BaseValue, + llvm::Constant::getNullValue(BaseValue->getType())); + Builder.CreateCondBr(IsNull, CastNull, CastNotNull); + EmitBlock(CastNotNull); + } + const llvm::Type *LongTy = CGM.getTypes().ConvertType(CGM.getContext().LongTy); const llvm::Type *Int8PtrTy = @@ -99,6 +115,20 @@ CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue, // Cast back. const llvm::Type *BasePtr = llvm::PointerType::getUnqual(ConvertType(BTy)); BaseValue = Builder.CreateBitCast(BaseValue, BasePtr); + + if (NullCheckValue) { + Builder.CreateBr(CastEnd); + EmitBlock(CastNull); + Builder.CreateBr(CastEnd); + EmitBlock(CastEnd); + + llvm::PHINode *PHI = Builder.CreatePHI(BaseValue->getType()); + PHI->reserveOperandSpace(2); + PHI->addIncoming(BaseValue, CastNotNull); + PHI->addIncoming(llvm::Constant::getNullValue(BaseValue->getType()), + CastNull); + BaseValue = PHI; + } return BaseValue; } diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 7adbc9fd3b30..aee54464349a 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -669,8 +669,16 @@ Value *ScalarExprEmitter::EmitCastExpr(const Expr *E, QualType DestTy, CXXRecordDecl *BaseClassDecl = cast(BaseClassTy->getDecl()); Value *Src = Visit(const_cast(E)); + + // FIXME: This should be true, but that leads to a failure in virt.cpp + bool NullCheckValue = false; + + // We always assume that 'this' is never null. + if (isa(E)) + NullCheckValue = false; + return CGF.GetAddressCXXOfBaseClass(Src, DerivedClassDecl, BaseClassDecl, - /*NullCheckValue=*/true); + NullCheckValue); } }