diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index a712a81ee71b..88c993d0f791 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -12798,7 +12798,7 @@ static void checkDeclInTargetContext(SourceLocation SL, SourceRange SR, Sema &SemaRef, Decl *D) { if (!D) return; - Decl *LD = nullptr; + const Decl *LD = nullptr; if (isa(D)) { LD = cast(D)->getDefinition(); } else if (isa(D)) { @@ -12814,22 +12814,29 @@ static void checkDeclInTargetContext(SourceLocation SL, SourceRange SR, ML->DeclarationMarkedOpenMPDeclareTarget(D, A); return; } - - } else if (isa(D)) { + } else if (auto *F = dyn_cast(D)) { const FunctionDecl *FD = nullptr; - if (cast(D)->hasBody(FD)) - LD = const_cast(FD); - - // If the definition is associated with the current declaration in the - // target region (it can be e.g. a lambda) that is legal and we do not need - // to do anything else. - if (LD == D) { - Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit( - SemaRef.Context, OMPDeclareTargetDeclAttr::MT_To); - D->addAttr(A); - if (ASTMutationListener *ML = SemaRef.Context.getASTMutationListener()) - ML->DeclarationMarkedOpenMPDeclareTarget(D, A); - return; + if (cast(D)->hasBody(FD)) { + LD = FD; + // If the definition is associated with the current declaration in the + // target region (it can be e.g. a lambda) that is legal and we do not + // need to do anything else. + if (LD == D) { + Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit( + SemaRef.Context, OMPDeclareTargetDeclAttr::MT_To); + D->addAttr(A); + if (ASTMutationListener *ML = SemaRef.Context.getASTMutationListener()) + ML->DeclarationMarkedOpenMPDeclareTarget(D, A); + return; + } + } else if (F->isFunctionTemplateSpecialization() && + F->getTemplateSpecializationKind() == + TSK_ImplicitInstantiation) { + // Check if the function is implicitly instantiated from the template + // defined in the declare target region. + const FunctionTemplateDecl *FTD = F->getPrimaryTemplate(); + if (FTD && FTD->hasAttr()) + return; } } if (!LD) @@ -12841,7 +12848,7 @@ static void checkDeclInTargetContext(SourceLocation SL, SourceRange SR, SemaRef.Diag(LD->getLocation(), diag::warn_omp_not_in_target_context); SemaRef.Diag(SL, diag::note_used_here) << SR; } else { - DeclContext *DC = LD->getDeclContext(); + const DeclContext *DC = LD->getDeclContext(); while (DC) { if (isa(DC) && cast(DC)->hasAttr()) @@ -12894,7 +12901,8 @@ void Sema::checkDeclIsAllowedInOpenMPTarget(Expr *E, Decl *D, if ((E || !VD->getType()->isIncompleteType()) && !checkValueDeclInTarget(SL, SR, *this, DSAStack, VD)) { // Mark decl as declared target to prevent further diagnostic. - if (isa(VD) || isa(VD)) { + if (isa(VD) || isa(VD) || + isa(VD)) { Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit( Context, OMPDeclareTargetDeclAttr::MT_To); VD->addAttr(A); @@ -12914,10 +12922,21 @@ void Sema::checkDeclIsAllowedInOpenMPTarget(Expr *E, Decl *D, return; } } + if (auto *FTD = dyn_cast(D)) { + if (FTD->hasAttr() && + (FTD->getAttr()->getMapType() == + OMPDeclareTargetDeclAttr::MT_Link)) { + assert(IdLoc.isValid() && "Source location is expected"); + Diag(IdLoc, diag::err_omp_function_in_link_clause); + Diag(FTD->getLocation(), diag::note_defined_here) << FTD; + return; + } + } if (!E) { // Checking declaration inside declare target region. if (!D->hasAttr() && - (isa(D) || isa(D))) { + (isa(D) || isa(D) || + isa(D))) { Attr *A = OMPDeclareTargetDeclAttr::CreateImplicit( Context, OMPDeclareTargetDeclAttr::MT_To); D->addAttr(A); diff --git a/clang/test/OpenMP/declare_target_messages.cpp b/clang/test/OpenMP/declare_target_messages.cpp index 4615dbdae487..3286a29dc418 100644 --- a/clang/test/OpenMP/declare_target_messages.cpp +++ b/clang/test/OpenMP/declare_target_messages.cpp @@ -33,6 +33,33 @@ struct NonT { typedef int sint; +template +T bla1() { return 0; } + +#pragma omp declare target +template +T bla2() { return 0; } +#pragma omp end declare target + +template<> +float bla2() { return 1.0; } + +#pragma omp declare target +void blub2() { + bla2(); + bla2(); +} +#pragma omp end declare target + +void t2() { +#pragma omp target + { + bla2(); + bla2(); + } +} + + #pragma omp declare target // expected-note {{to match this '#pragma omp declare target'}} #pragma omp threadprivate(a) // expected-note {{defined as threadprivate or thread local}} extern int b;