diff --git a/clang/lib/CodeGen/CGCXX.cpp b/clang/lib/CodeGen/CGCXX.cpp index a119c5af932e..20b2bdcd5e18 100644 --- a/clang/lib/CodeGen/CGCXX.cpp +++ b/clang/lib/CodeGen/CGCXX.cpp @@ -782,8 +782,8 @@ public: return i->second; // FIXME: temporal botch, is this data here, by the time we need it? - // FIXME: Locate the containing virtual base first. - return 42; + assert(false && "FIXME: Locate the containing virtual base first"); + return 0; } bool OverrideMethod(const CXXMethodDecl *MD, llvm::Constant *m, @@ -888,18 +888,26 @@ public: const CXXRecordDecl *RD = i->first; int64_t Offset = i->second; for (method_iter mi = RD->method_begin(), me = RD->method_end(); mi != me; - ++mi) - if (mi->isVirtual()) { - const CXXMethodDecl *MD = *mi; + ++mi) { + if (!mi->isVirtual()) + continue; + + const CXXMethodDecl *MD = *mi; + llvm::Constant *m = 0; +// if (const CXXDestructorDecl *Dtor = dyn_cast(MD)) +// m = wrap(CGM.GetAddrOfCXXDestructor(Dtor, Dtor_Complete)); +// else { const FunctionProtoType *FPT = MD->getType()->getAs(); const llvm::Type *Ty = CGM.getTypes().GetFunctionType(CGM.getTypes().getFunctionInfo(MD), FPT->isVariadic()); - llvm::Constant *m = wrap(CGM.GetAddrOfFunction(MD, Ty)); - OverrideMethod(MD, m, MorallyVirtual, Offset); - } + m = wrap(CGM.GetAddrOfFunction(MD, Ty)); +// } + + OverrideMethod(MD, m, MorallyVirtual, Offset); + } } } @@ -1322,6 +1330,36 @@ llvm::Constant *CodeGenModule::BuildCovariantThunk(const CXXMethodDecl *MD, return m; } +llvm::Value * +CodeGenFunction::GetVirtualCXXBaseClassOffset(llvm::Value *This, + const CXXRecordDecl *ClassDecl, + const CXXRecordDecl *BaseClassDecl) { + // FIXME: move to Context + if (vtableinfo == 0) + vtableinfo = new VtableInfo(CGM); + + const llvm::Type *Int8PtrTy = + llvm::Type::getInt8Ty(VMContext)->getPointerTo(); + + llvm::Value *VTablePtr = Builder.CreateBitCast(This, + Int8PtrTy->getPointerTo()); + VTablePtr = Builder.CreateLoad(VTablePtr, "vtable"); + + llvm::Value *VBaseOffsetPtr = + Builder.CreateConstGEP1_64(VTablePtr, + vtableinfo->VBlookup(ClassDecl, BaseClassDecl), + "vbase.offset.ptr"); + const llvm::Type *PtrDiffTy = + ConvertType(getContext().getPointerDiffType()); + + VBaseOffsetPtr = Builder.CreateBitCast(VBaseOffsetPtr, + PtrDiffTy->getPointerTo()); + + llvm::Value *VBaseOffset = Builder.CreateLoad(VBaseOffsetPtr, "vbase.offset"); + + return VBaseOffset; +} + llvm::Value * CodeGenFunction::BuildVirtualCall(const CXXMethodDecl *MD, llvm::Value *&This, const llvm::Type *Ty) { diff --git a/clang/lib/CodeGen/CGCXXClass.cpp b/clang/lib/CodeGen/CGCXXClass.cpp index 9c8174bc22f7..ff879f5786ff 100644 --- a/clang/lib/CodeGen/CGCXXClass.cpp +++ b/clang/lib/CodeGen/CGCXXClass.cpp @@ -12,61 +12,35 @@ //===----------------------------------------------------------------------===// #include "CodeGenFunction.h" +#include "clang/AST/CXXInheritance.h" #include "clang/AST/RecordLayout.h" + using namespace clang; using namespace CodeGen; -static bool -GetNestedPaths(llvm::SmallVectorImpl &NestedBasePaths, - const CXXRecordDecl *ClassDecl, - const CXXRecordDecl *BaseClassDecl) { - for (CXXRecordDecl::base_class_const_iterator i = ClassDecl->bases_begin(), - e = ClassDecl->bases_end(); i != e; ++i) { - if (i->isVirtual()) - continue; - const CXXRecordDecl *Base = - cast(i->getType()->getAs()->getDecl()); - if (Base == BaseClassDecl) { - NestedBasePaths.push_back(BaseClassDecl); - return true; - } - } - // BaseClassDecl not an immediate base of ClassDecl. - for (CXXRecordDecl::base_class_const_iterator i = ClassDecl->bases_begin(), - e = ClassDecl->bases_end(); i != e; ++i) { - if (i->isVirtual()) - continue; - const CXXRecordDecl *Base = - cast(i->getType()->getAs()->getDecl()); - if (GetNestedPaths(NestedBasePaths, Base, BaseClassDecl)) { - NestedBasePaths.push_back(Base); - return true; - } - } - return false; -} +static uint64_t +ComputeNonVirtualBaseClassOffset(ASTContext &Context, CXXBasePaths &Paths, + unsigned Start) { + uint64_t Offset = 0; -static uint64_t ComputeBaseClassOffset(ASTContext &Context, - const CXXRecordDecl *ClassDecl, - const CXXRecordDecl *BaseClassDecl) { - uint64_t Offset = 0; + const CXXBasePath &Path = Paths.front(); + for (unsigned i = Start, e = Path.size(); i != e; ++i) { + const CXXBasePathElement& Element = Path[i]; - llvm::SmallVector NestedBasePaths; - GetNestedPaths(NestedBasePaths, ClassDecl, BaseClassDecl); - assert(NestedBasePaths.size() > 0 && - "AddressCXXOfBaseClass - inheritence path failed"); - NestedBasePaths.push_back(ClassDecl); + // Get the layout. + const ASTRecordLayout &Layout = Context.getASTRecordLayout(Element.Class); - for (unsigned i = NestedBasePaths.size() - 1; i > 0; i--) { - const CXXRecordDecl *DerivedClass = NestedBasePaths[i]; - const CXXRecordDecl *BaseClass = NestedBasePaths[i-1]; - const ASTRecordLayout &Layout = - Context.getASTRecordLayout(DerivedClass); - - Offset += Layout.getBaseClassOffset(BaseClass) / 8; - } + const CXXBaseSpecifier *BS = Element.Base; + assert(!BS->isVirtual() && "Should not see virtual bases here!"); - return Offset; + const CXXRecordDecl *Base = + cast(BS->getType()->getAs()->getDecl()); + + // Add the offset. + Offset += Layout.getBaseClassOffset(Base) / 8; + } + + return Offset; } llvm::Constant * @@ -75,12 +49,15 @@ CodeGenModule::GetCXXBaseClassOffset(const CXXRecordDecl *ClassDecl, if (ClassDecl == BaseClassDecl) return 0; - QualType BTy = - getContext().getCanonicalType( - getContext().getTypeDeclType(const_cast(BaseClassDecl))); + CXXBasePaths Paths(/*FindAmbiguities=*/false, + /*RecordPaths=*/true, /*DetectVirtual=*/false); + if (!const_cast(ClassDecl)-> + isDerivedFrom(const_cast(BaseClassDecl), Paths)) { + assert(false && "Class must be derived from the passed in base class!"); + return 0; + } - uint64_t Offset = ComputeBaseClassOffset(getContext(), - ClassDecl, BaseClassDecl); + uint64_t Offset = ComputeNonVirtualBaseClassOffset(getContext(), Paths, 0); if (!Offset) return 0; @@ -90,19 +67,63 @@ CodeGenModule::GetCXXBaseClassOffset(const CXXRecordDecl *ClassDecl, return llvm::ConstantInt::get(PtrDiffTy, Offset); } +static llvm::Value *GetCXXBaseClassOffset(CodeGenFunction &CGF, + llvm::Value *BaseValue, + const CXXRecordDecl *ClassDecl, + const CXXRecordDecl *BaseClassDecl) { + CXXBasePaths Paths(/*FindAmbiguities=*/false, + /*RecordPaths=*/true, /*DetectVirtual=*/true); + if (!const_cast(ClassDecl)-> + isDerivedFrom(const_cast(BaseClassDecl), Paths)) { + assert(false && "Class must be derived from the passed in base class!"); + return 0; + } + + unsigned Start = 0; + llvm::Value *VirtualOffset = 0; + if (const RecordType *RT = Paths.getDetectedVirtual()) { + const CXXRecordDecl *VBase = cast(RT->getDecl()); + + VirtualOffset = + CGF.GetVirtualCXXBaseClassOffset(BaseValue, ClassDecl, VBase); + + const CXXBasePath &Path = Paths.front(); + unsigned e = Path.size(); + for (Start = 0; Start != e; ++Start) { + const CXXBasePathElement& Element = Path[Start]; + + if (Element.Class == VBase) + break; + } + } + + uint64_t Offset = + ComputeNonVirtualBaseClassOffset(CGF.getContext(), Paths, Start); + + if (!Offset) + return VirtualOffset; + + const llvm::Type *PtrDiffTy = + CGF.ConvertType(CGF.getContext().getPointerDiffType()); + llvm::Value *NonVirtualOffset = llvm::ConstantInt::get(PtrDiffTy, Offset); + + if (VirtualOffset) + return CGF.Builder.CreateAdd(VirtualOffset, NonVirtualOffset); + + return NonVirtualOffset; +} + llvm::Value * CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue, const CXXRecordDecl *ClassDecl, const CXXRecordDecl *BaseClassDecl, bool NullCheckValue) { - llvm::Constant *Offset = CGM.GetCXXBaseClassOffset(ClassDecl, BaseClassDecl); - QualType BTy = getContext().getCanonicalType( getContext().getTypeDeclType(const_cast(BaseClassDecl))); const llvm::Type *BasePtrTy = llvm::PointerType::getUnqual(ConvertType(BTy)); - if (!Offset) { + if (ClassDecl == BaseClassDecl) { // Just cast back. return Builder.CreateBitCast(BaseValue, BasePtrTy); } @@ -125,10 +146,15 @@ CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue, const llvm::Type *Int8PtrTy = llvm::PointerType::getUnqual(llvm::Type::getInt8Ty(VMContext)); + + llvm::Value *Offset = + GetCXXBaseClassOffset(*this, BaseValue, ClassDecl, BaseClassDecl); - // Apply the offset. - BaseValue = Builder.CreateBitCast(BaseValue, Int8PtrTy); - BaseValue = Builder.CreateGEP(BaseValue, Offset, "add.ptr"); + if (Offset) { + // Apply the offset. + BaseValue = Builder.CreateBitCast(BaseValue, Int8PtrTy); + BaseValue = Builder.CreateGEP(BaseValue, Offset, "add.ptr"); + } // Cast back. BaseValue = Builder.CreateBitCast(BaseValue, BasePtrTy); diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 34b3860a48f8..42de9fb62e68 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -589,6 +589,11 @@ public: const CXXRecordDecl *BaseClassDecl, bool NullCheckValue); + llvm::Value * + GetVirtualCXXBaseClassOffset(llvm::Value *This, + const CXXRecordDecl *ClassDecl, + const CXXRecordDecl *BaseClassDecl); + void EmitClassAggrMemberwiseCopy(llvm::Value *DestValue, llvm::Value *SrcValue, const ArrayType *Array, diff --git a/clang/test/CodeGenCXX/virtual-base-cast.cpp b/clang/test/CodeGenCXX/virtual-base-cast.cpp new file mode 100644 index 000000000000..9a728a82248c --- /dev/null +++ b/clang/test/CodeGenCXX/virtual-base-cast.cpp @@ -0,0 +1,9 @@ +// RUN: clang-cc -emit-llvm-only %s + +struct A { virtual ~A(); }; +struct B : A { virtual ~B(); }; +struct C : virtual B { virtual ~C(); }; + +void f(C *c) { + A* a = c; +} \ No newline at end of file