diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h index 0560692f3d5b..8531609bb9e0 100644 --- a/clang/include/clang/Sema/HLSLExternalSemaSource.h +++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h @@ -22,7 +22,7 @@ class Sema; class HLSLExternalSemaSource : public ExternalSemaSource { Sema *SemaPtr = nullptr; - NamespaceDecl *HLSLNamespace; + NamespaceDecl *HLSLNamespace = nullptr; CXXRecordDecl *ResourceDecl; using CompletionFunction = std::function; diff --git a/clang/lib/Frontend/FrontendAction.cpp b/clang/lib/Frontend/FrontendAction.cpp index 78c8de78d7ab..b541c59fb9c4 100644 --- a/clang/lib/Frontend/FrontendAction.cpp +++ b/clang/lib/Frontend/FrontendAction.cpp @@ -28,6 +28,7 @@ #include "clang/Lex/PreprocessorOptions.h" #include "clang/Parse/ParseAST.h" #include "clang/Sema/HLSLExternalSemaSource.h" +#include "clang/Sema/MultiplexExternalSemaSource.h" #include "clang/Serialization/ASTDeserializationListener.h" #include "clang/Serialization/ASTReader.h" #include "clang/Serialization/GlobalModuleIndex.h" @@ -1026,9 +1027,15 @@ bool FrontendAction::BeginSourceFile(CompilerInstance &CI, // Setup HLSL External Sema Source if (CI.getLangOpts().HLSL && CI.hasASTContext()) { - IntrusiveRefCntPtr HLSLSema( + IntrusiveRefCntPtr HLSLSema( new HLSLExternalSemaSource()); - CI.getASTContext().setExternalSource(HLSLSema); + if (auto *SemaSource = dyn_cast_if_present( + CI.getASTContext().getExternalSource())) { + IntrusiveRefCntPtr MultiSema( + new MultiplexExternalSemaSource(SemaSource, HLSLSema.get())); + CI.getASTContext().setExternalSource(MultiSema); + } else + CI.getASTContext().setExternalSource(HLSLSema); } FailureCleanup.release(); diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp index 969d65997cdc..681154d52cb3 100644 --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -30,6 +30,7 @@ struct TemplateParameterListBuilder; struct BuiltinTypeDeclBuilder { CXXRecordDecl *Record = nullptr; ClassTemplateDecl *Template = nullptr; + ClassTemplateDecl *PrevTemplate = nullptr; NamespaceDecl *HLSLNamespace = nullptr; llvm::StringMap Fields; @@ -43,48 +44,46 @@ struct BuiltinTypeDeclBuilder { ASTContext &AST = S.getASTContext(); IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier); + LookupResult Result(S, &II, SourceLocation(), Sema::LookupTagName); + CXXRecordDecl *PrevDecl = nullptr; + if (S.LookupQualifiedName(Result, HLSLNamespace)) { + NamedDecl *Found = Result.getFoundDecl(); + if (auto *TD = dyn_cast(Found)) { + PrevDecl = TD->getTemplatedDecl(); + PrevTemplate = TD; + } else + PrevDecl = dyn_cast(Found); + assert(PrevDecl && "Unexpected lookup result type."); + } + + if (PrevDecl && PrevDecl->isCompleteDefinition()) { + Record = PrevDecl; + return; + } + Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::TTK_Class, HLSLNamespace, SourceLocation(), - SourceLocation(), &II, nullptr, true); + SourceLocation(), &II, PrevDecl, true); Record->setImplicit(true); Record->setLexicalDeclContext(HLSLNamespace); Record->setHasExternalLexicalStorage(); - // Don't let anyone derive from built-in types + // Don't let anyone derive from built-in types. Record->addAttr(FinalAttr::CreateImplicit(AST, SourceRange(), AttributeCommonInfo::AS_Keyword, FinalAttr::Keyword_final)); } ~BuiltinTypeDeclBuilder() { - if (HLSLNamespace && !Template) + if (HLSLNamespace && !Template && Record->getDeclContext() == HLSLNamespace) HLSLNamespace->addDecl(Record); } - BuiltinTypeDeclBuilder & - addTemplateArgumentList(llvm::ArrayRef TemplateArgs) { - ASTContext &AST = Record->getASTContext(); - - auto *ParamList = - TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(), - TemplateArgs, SourceLocation(), nullptr); - Template = ClassTemplateDecl::Create( - AST, Record->getDeclContext(), SourceLocation(), - DeclarationName(Record->getIdentifier()), ParamList, Record); - Record->setDescribedClassTemplate(Template); - Template->setImplicit(true); - Template->setLexicalDeclContext(Record->getDeclContext()); - Record->getDeclContext()->addDecl(Template); - - // Requesting the class name specialization will fault in required types. - QualType T = Template->getInjectedClassNameSpecialization(); - T = AST.getInjectedClassNameType(Record, T); - return *this; - } - BuiltinTypeDeclBuilder & addMemberVariable(StringRef Name, QualType Type, AccessSpecifier Access = AccessSpecifier::AS_private) { + if (Record->isCompleteDefinition()) + return *this; assert(Record->isBeingDefined() && "Definition must be started before adding members!"); ASTContext &AST = Record->getASTContext(); @@ -104,6 +103,8 @@ struct BuiltinTypeDeclBuilder { BuiltinTypeDeclBuilder & addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) { + if (Record->isCompleteDefinition()) + return *this; QualType Ty = Record->getASTContext().VoidPtrTy; if (Template) { if (const auto *TTD = dyn_cast( @@ -116,6 +117,8 @@ struct BuiltinTypeDeclBuilder { BuiltinTypeDeclBuilder & annotateResourceClass(HLSLResourceAttr::ResourceClass RC) { + if (Record->isCompleteDefinition()) + return *this; Record->addAttr( HLSLResourceAttr::CreateImplicit(Record->getASTContext(), RC)); return *this; @@ -147,6 +150,8 @@ struct BuiltinTypeDeclBuilder { BuiltinTypeDeclBuilder &addDefaultHandleConstructor(Sema &S, ResourceClass RC) { + if (Record->isCompleteDefinition()) + return *this; ASTContext &AST = Record->getASTContext(); QualType ConstructorType = @@ -197,12 +202,16 @@ struct BuiltinTypeDeclBuilder { } BuiltinTypeDeclBuilder &addArraySubscriptOperators() { + if (Record->isCompleteDefinition()) + return *this; addArraySubscriptOperator(true); addArraySubscriptOperator(false); return *this; } BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) { + if (Record->isCompleteDefinition()) + return *this; assert(Fields.count("h") > 0 && "Subscript operator must be added after the handle."); @@ -279,11 +288,15 @@ struct BuiltinTypeDeclBuilder { } BuiltinTypeDeclBuilder &startDefinition() { + if (Record->isCompleteDefinition()) + return *this; Record->startDefinition(); return *this; } BuiltinTypeDeclBuilder &completeDefinition() { + if (Record->isCompleteDefinition()) + return *this; assert(Record->isBeingDefined() && "Definition must be started before completing it."); @@ -306,6 +319,8 @@ struct TemplateParameterListBuilder { TemplateParameterListBuilder & addTypeParameter(StringRef Name, QualType DefaultValue = QualType()) { + if (Builder.Record->isCompleteDefinition()) + return *this; unsigned Position = static_cast(Params.size()); auto *Decl = TemplateTypeParmDecl::Create( AST, Builder.Record->getDeclContext(), SourceLocation(), @@ -332,6 +347,9 @@ struct TemplateParameterListBuilder { Builder.Record->setDescribedClassTemplate(Builder.Template); Builder.Template->setImplicit(true); Builder.Template->setLexicalDeclContext(Builder.Record->getDeclContext()); + // NOTE: setPreviousDecl before addDecl so new decl replace old decl when + // make visible. + Builder.Template->setPreviousDecl(Builder.PrevTemplate); Builder.Record->getDeclContext()->addDecl(Builder.Template); Params.clear(); @@ -352,12 +370,24 @@ HLSLExternalSemaSource::~HLSLExternalSemaSource() {} void HLSLExternalSemaSource::InitializeSema(Sema &S) { SemaPtr = &S; ASTContext &AST = SemaPtr->getASTContext(); + // If the translation unit has external storage force external decls to load. + if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage()) + (void)AST.getTranslationUnitDecl()->decls_begin(); + IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier); - HLSLNamespace = - NamespaceDecl::Create(AST, AST.getTranslationUnitDecl(), false, - SourceLocation(), SourceLocation(), &HLSL, nullptr); + LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName); + NamespaceDecl *PrevDecl = nullptr; + if (S.LookupQualifiedName(Result, AST.getTranslationUnitDecl())) + PrevDecl = Result.getAsSingle(); + HLSLNamespace = NamespaceDecl::Create(AST, AST.getTranslationUnitDecl(), + false, SourceLocation(), + SourceLocation(), &HLSL, PrevDecl); HLSLNamespace->setImplicit(true); + HLSLNamespace->setHasExternalLexicalStorage(); AST.getTranslationUnitDecl()->addDecl(HLSLNamespace); + + // Force external decls in the HLSL namespace to load from the PCH. + (void)HLSLNamespace->getCanonicalDecl()->decls_begin(); defineTrivialHLSLTypes(); forwardDeclareHLSLTypes(); @@ -443,9 +473,11 @@ void HLSLExternalSemaSource::forwardDeclareHLSLTypes() { .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy) .finalizeTemplateArgs() .Record; - Completions.insert(std::make_pair( - Decl, std::bind(&HLSLExternalSemaSource::completeBufferType, this, - std::placeholders::_1))); + if (!Decl->isCompleteDefinition()) + Completions.insert( + std::make_pair(Decl->getCanonicalDecl(), + std::bind(&HLSLExternalSemaSource::completeBufferType, + this, std::placeholders::_1))); } void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) { @@ -457,6 +489,7 @@ void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) { // declaration and complete that. if (auto TDecl = dyn_cast(Record)) Record = TDecl->getSpecializedTemplate()->getTemplatedDecl(); + Record = Record->getCanonicalDecl(); auto It = Completions.find(Record); if (It == Completions.end()) return; diff --git a/clang/test/AST/HLSL/Inputs/pch.hlsl b/clang/test/AST/HLSL/Inputs/pch.hlsl new file mode 100644 index 000000000000..90d7a19ae424 --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch.hlsl @@ -0,0 +1,4 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} diff --git a/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl b/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl new file mode 100644 index 000000000000..bca281638a5b --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl @@ -0,0 +1,6 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} + +RWBuffer Buf; diff --git a/clang/test/AST/HLSL/pch.hlsl b/clang/test/AST/HLSL/pch.hlsl new file mode 100644 index 000000000000..74254fedf27c --- /dev/null +++ b/clang/test/AST/HLSL/pch.hlsl @@ -0,0 +1,17 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -emit-pch -o %t %S/Inputs/pch.hlsl +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -fsyntax-only -ast-dump-all %s \ +// RUN: | FileCheck %s + +// Make sure PCH works by using function declared in PCH header and declare a RWBuffer in current file. +// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:10:1, col:23> col:23 Buffer 'hlsl::RWBuffer':'hlsl::RWBuffer<>' +hlsl::RWBuffer Buffer; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} 'float2':'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} 'float2 (*)(float2, float2)' +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +} diff --git a/clang/test/AST/HLSL/pch_with_buf.hlsl b/clang/test/AST/HLSL/pch_with_buf.hlsl new file mode 100644 index 000000000000..4e657606cbcb --- /dev/null +++ b/clang/test/AST/HLSL/pch_with_buf.hlsl @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -finclude-default-header -emit-pch -o %t %S/Inputs/pch_with_buf.hlsl +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -fsyntax-only -ast-dump-all %s | FileCheck %s + +// Make sure PCH works by using function declared in PCH header. +// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// Make sure buffer defined in PCH works. +// CHECK:VarDecl 0x{{[0-9a-f]+}} col:17 imported Buf 'RWBuffer':'hlsl::RWBuffer<>' +// Make sure declare a RWBuffer in current file works. +// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:11:1, col:23> col:23 Buf2 'hlsl::RWBuffer':'hlsl::RWBuffer<>' +hlsl::RWBuffer Buf2; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} 'float2':'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} 'float2 (*)(float2, float2)' +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +}