diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index c51ad3e1dfa4..768690aa560f 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -307,6 +307,24 @@ def err_different_return_type_for_overriding_virtual_function : Error< def note_overridden_virtual_function : Note< "overridden virtual function is here">; +def err_covariant_return_inaccessible_base : Error< + "return type of virtual function %2 is not covariant with the return type " + "of the function it overrides " + "(conversion from %0 to inaccessible base class %1)">; +def err_covariant_return_ambiguous_derived_to_base_conv : Error< + "return type of virtual function %3 is not covariant with the return type of " + "the function it overrides (ambiguous conversion from derived class " + "%0 to base class %1:%2)">; +def err_covariant_return_not_derived : Error< + "return type of virtual function %0 is not covariant with the return type of " + "the function it overrrides (%1 is not derived from %2)">; +def err_covariant_return_type_different_qualifications : Error< + "return type of virtual function %0 is not covariant with the return type of " + "the function it overrides (%1 has different qualifiers than %2)">; +def err_covariant_return_type_class_type_more_qualified : Error< + "return type of virtual function %0 is not covariant with the return type of " + "the function it overrides (class type %1 is more qualified than class " + "type %2">; // C++ constructors def err_constructor_cannot_be : Error<"constructor cannot be declared '%0'">; def err_invalid_qualified_constructor : Error< diff --git a/clang/lib/Sema/SemaDeclCXX.cpp b/clang/lib/Sema/SemaDeclCXX.cpp index bdd3cc2be14a..5a5c3c774b47 100644 --- a/clang/lib/Sema/SemaDeclCXX.cpp +++ b/clang/lib/Sema/SemaDeclCXX.cpp @@ -2701,12 +2701,71 @@ bool Sema::CheckOverridingFunctionReturnType(const CXXMethodDecl *New, CNewTy.getCVRQualifiers() == COldTy.getCVRQualifiers()) return false; - // FIXME: Check covariance. + // Check if the return types are covariant + QualType NewClassTy, OldClassTy; + + /// Both types must be pointers or references to classes. + if (PointerType *NewPT = dyn_cast(NewTy)) { + if (PointerType *OldPT = dyn_cast(OldTy)) { + NewClassTy = NewPT->getPointeeType(); + OldClassTy = OldPT->getPointeeType(); + } + } else if (ReferenceType *NewRT = dyn_cast(NewTy)) { + if (ReferenceType *OldRT = dyn_cast(OldTy)) { + NewClassTy = NewRT->getPointeeType(); + OldClassTy = OldRT->getPointeeType(); + } + } + + // The return types aren't either both pointers or references to a class type. + if (NewClassTy.isNull()) { + Diag(New->getLocation(), + diag::err_different_return_type_for_overriding_virtual_function) + << New->getDeclName() << NewTy << OldTy; + Diag(Old->getLocation(), diag::note_overridden_virtual_function); + + return true; + } - Diag(New->getLocation(), - diag::err_different_return_type_for_overriding_virtual_function) + if (NewClassTy.getUnqualifiedType() != OldClassTy.getUnqualifiedType()) { + // Check if the new class derives from the old class. + if (!IsDerivedFrom(NewClassTy, OldClassTy)) { + Diag(New->getLocation(), + diag::err_covariant_return_not_derived) + << New->getDeclName() << NewTy << OldTy; + Diag(Old->getLocation(), diag::note_overridden_virtual_function); + return true; + } + + // Check if we the conversion from derived to base is valid. + if (CheckDerivedToBaseConversion(NewClassTy, OldClassTy, + diag::err_covariant_return_inaccessible_base, + diag::err_covariant_return_ambiguous_derived_to_base_conv, + // FIXME: Should this point to the return type? + New->getLocation(), SourceRange(), New->getDeclName())) { + Diag(Old->getLocation(), diag::note_overridden_virtual_function); + return true; + } + } + + // The qualifiers of the return types must be the same. + if (CNewTy.getCVRQualifiers() != COldTy.getCVRQualifiers()) { + Diag(New->getLocation(), + diag::err_covariant_return_type_different_qualifications) << New->getDeclName() << NewTy << OldTy; - Diag(Old->getLocation(), diag::note_overridden_virtual_function); - - return true; + Diag(Old->getLocation(), diag::note_overridden_virtual_function); + return true; + }; + + + // The new class type must have the same or less qualifiers as the old type. + if (NewClassTy.isMoreQualifiedThan(OldClassTy)) { + Diag(New->getLocation(), + diag::err_covariant_return_type_class_type_more_qualified) + << New->getDeclName() << NewTy << OldTy; + Diag(Old->getLocation(), diag::note_overridden_virtual_function); + return true; + }; + + return false; } diff --git a/clang/test/SemaCXX/virtual-override.cpp b/clang/test/SemaCXX/virtual-override.cpp index c1b95ccbdf09..1a917fee0319 100644 --- a/clang/test/SemaCXX/virtual-override.cpp +++ b/clang/test/SemaCXX/virtual-override.cpp @@ -1,4 +1,4 @@ -// RUN: clang-cc -fsyntax-only -verify %s +// RUN: clang-cc -fsyntax-only -faccess-control -verify %s namespace T1 { @@ -11,3 +11,83 @@ class B : A { }; } + +namespace T2 { + +struct a { }; +struct b { }; + +class A { + virtual a* f(); // expected-note{{overridden virtual function is here}} +}; + +class B : A { + virtual b* f(); // expected-error{{return type of virtual function 'f' is not covariant with the return type of the function it overrrides ('struct T2::b *' is not derived from 'struct T2::a *')}} +}; + +} + +namespace T3 { + +struct a { }; +struct b : private a { }; // expected-note{{'private' inheritance specifier here}} + +class A { + virtual a* f(); // expected-note{{overridden virtual function is here}} +}; + +class B : A { + virtual b* f(); // expected-error{{return type of virtual function 'f' is not covariant with the return type of the function it overrides (conversion from 'struct T3::b' to inaccessible base class 'struct T3::a')}} +}; + +} + +namespace T4 { + +struct a { }; +struct a1 : a { }; +struct b : a, a1 { }; + +class A { + virtual a* f(); // expected-note{{overridden virtual function is here}} +}; + +class B : A { + virtual b* f(); // expected-error{{return type of virtual function 'f' is not covariant with the return type of the function it overrides (ambiguous conversion from derived class 'struct T4::b' to base class 'struct T4::a':\n\ + struct T4::b -> struct T4::a\n\ + struct T4::b -> struct T4::a1 -> struct T4::a)}} +}; + +} + +namespace T5 { + +struct a { }; + +class A { + virtual a* const f(); + virtual a* const g(); // expected-note{{overridden virtual function is here}} +}; + +class B : A { + virtual a* const f(); + virtual a* g(); // expected-error{{return type of virtual function 'g' is not covariant with the return type of the function it overrides ('struct T5::a *' has different qualifiers than 'struct T5::a *const')}} +}; + +} + +namespace T6 { + +struct a { }; + +class A { + virtual const a* f(); + virtual a* g(); // expected-note{{overridden virtual function is here}} +}; + +class B : A { + virtual a* f(); + virtual const a* g(); // expected-error{{return type of virtual function 'g' is not covariant with the return type of the function it overrides (class type 'struct T6::a const *' is more qualified than class type 'struct T6::a *'}} +}; + +}