[Clang][OpenMP] Allow loop-transformations with template parameters.

Clang would reject

    #pragma omp for
    #pragma omp tile sizes(P)
    for (int i = 0; i < 128; ++i) {}

where P is a template parameter, but the loop itself is not
template-dependent. Because P context-dependent, the TransformedStmt
cannot be generated and therefore is nullptr (until the template is
instantiated by TreeTransform). The OMPForDirective would still expect
the a loop is the dependent context and trigger an error.

Fix by introducing a NumGeneratedLoops field to OMPLoopTransformation.
This is used to distinguish the case where no TransformedStmt will be
generated at all (e.g. #pragma omp unroll full) and template
instantiation is needed. In the latter case, delay resolving the
iteration space like when the for-loop itself is template-dependent
until the template instatiation.

A more radical solution would always delay the iteration space analysis
until template instantiation, but would also break many test cases.

Reviewed By: ABataev

Differential Revision: https://reviews.llvm.org/D111124
This commit is contained in:
Michael Kruse 2021-10-06 11:43:29 -05:00
parent fdf4c03522
commit 2130117f92
7 changed files with 82 additions and 11 deletions

View File

@ -959,6 +959,9 @@ public:
class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
friend class ASTStmtReader;
/// Number of loops generated by this loop transformation.
unsigned NumGeneratedLoops = 0;
protected:
explicit OMPLoopTransformationDirective(StmtClass SC,
OpenMPDirectiveKind Kind,
@ -967,10 +970,16 @@ protected:
unsigned NumAssociatedLoops)
: OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
/// Set the number of loops generated by this loop transformation.
void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; }
public:
/// Return the number of associated (consumed) loops.
unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
/// Return the number of loops generated by this loop transformation.
unsigned getNumGeneratedLoops() { return NumGeneratedLoops; }
/// Get the de-sugared statements after after the loop transformation.
///
/// Might be nullptr if either the directive generates no loops and is handled
@ -5058,7 +5067,9 @@ class OMPTileDirective final : public OMPLoopTransformationDirective {
unsigned NumLoops)
: OMPLoopTransformationDirective(OMPTileDirectiveClass,
llvm::omp::OMPD_tile, StartLoc, EndLoc,
NumLoops) {}
NumLoops) {
setNumGeneratedLoops(3 * NumLoops);
}
void setPreInits(Stmt *PreInits) {
Data->getChildren()[PreInitsOffset] = PreInits;
@ -5163,7 +5174,7 @@ public:
static OMPUnrollDirective *
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
Stmt *TransformedStmt, Stmt *PreInits);
unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits);
/// Build an empty '#pragma omp unroll' AST node for deserialization.
///

View File

@ -138,9 +138,18 @@ bool OMPLoopBasedDirective::doForAllLoops(
Stmt *TransformedStmt = Dir->getTransformedStmt();
if (!TransformedStmt) {
// May happen if the loop transformation does not result in a
// generated loop (such as full unrolling).
break;
unsigned NumGeneratedLoops = Dir->getNumGeneratedLoops();
if (NumGeneratedLoops == 0) {
// May happen if the loop transformation does not result in a
// generated loop (such as full unrolling).
break;
}
if (NumGeneratedLoops > 0) {
// The loop transformation construct has generated loops, but these
// may not have been generated yet due to being in a dependent
// context.
return true;
}
}
CurStmt = TransformedStmt;
@ -419,10 +428,13 @@ OMPTileDirective *OMPTileDirective::CreateEmpty(const ASTContext &C,
OMPUnrollDirective *
OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
Stmt *AssociatedStmt, Stmt *TransformedStmt,
Stmt *PreInits) {
Stmt *AssociatedStmt, unsigned NumGeneratedLoops,
Stmt *TransformedStmt, Stmt *PreInits) {
assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop");
auto *Dir = createDirective<OMPUnrollDirective>(
C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
Dir->setNumGeneratedLoops(NumGeneratedLoops);
Dir->setTransformedStmt(TransformedStmt);
Dir->setPreInits(PreInits);
return Dir;

View File

@ -12919,10 +12919,12 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
Body, OriginalInits))
return StmtError();
unsigned NumGeneratedLoops = PartialClause ? 1 : 0;
// Delay unrolling to when template is completely instantiated.
if (CurContext->isDependentContext())
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
nullptr, nullptr);
NumGeneratedLoops, nullptr, nullptr);
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
@ -12941,9 +12943,9 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
// The generated loop may only be passed to other loop-associated directive
// when a partial clause is specified. Without the requirement it is
// sufficient to generate loop unroll metadata at code-generation.
if (!PartialClause)
if (NumGeneratedLoops == 0)
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
nullptr, nullptr);
NumGeneratedLoops, nullptr, nullptr);
// Otherwise, we need to provide a de-sugared/transformed AST that can be
// associated with another loop directive.
@ -13164,7 +13166,8 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
OuterFor, buildPreInits(Context, PreInits));
NumGeneratedLoops, OuterFor,
buildPreInits(Context, PreInits));
}
OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr,

View File

@ -2327,6 +2327,7 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) {
void ASTStmtReader::VisitOMPLoopTransformationDirective(
OMPLoopTransformationDirective *D) {
VisitOMPLoopBasedDirective(D);
D->setNumGeneratedLoops(Record.readUInt32());
}
void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {

View File

@ -2226,6 +2226,7 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) {
void ASTStmtWriter::VisitOMPLoopTransformationDirective(
OMPLoopTransformationDirective *D) {
VisitOMPLoopBasedDirective(D);
Record.writeUInt32(D->getNumGeneratedLoops());
}
void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {

View File

@ -162,4 +162,25 @@ void tfoo6() {
}
// PRINT-LABEL: template <int Tile> void foo7(int start, int stop, int step) {
// DUMP-LABEL: FunctionTemplateDecl {{.*}} foo7
template <int Tile>
void foo7(int start, int stop, int step) {
// PRINT: #pragma omp tile sizes(Tile)
// DUMP: OMPTileDirective
// DUMP-NEXT: OMPSizesClause
// DUMP-NEXT: DeclRefExpr {{.*}} 'Tile' 'int'
#pragma omp tile sizes(Tile)
// PRINT-NEXT: for (int i = start; i < stop; i += step)
// DUMP-NEXT: ForStmt
for (int i = start; i < stop; i += step)
// PRINT-NEXT: body(i);
// DUMP: CallExpr
body(i);
}
void tfoo7() {
foo7<5>(0, 42, 2);
}
#endif

View File

@ -124,4 +124,26 @@ void unroll_template() {
unroll_templated<int,0,1024,1,4>();
}
// PRINT-LABEL: template <int Factor> void unroll_templated_factor(int start, int stop, int step) {
// DUMP-LABEL: FunctionTemplateDecl {{.*}} unroll_templated_factor
template <int Factor>
void unroll_templated_factor(int start, int stop, int step) {
// PRINT: #pragma omp unroll partial(Factor)
// DUMP: OMPUnrollDirective
// DUMP-NEXT: OMPPartialClause
// DUMP-NEXT: DeclRefExpr {{.*}} 'Factor' 'int'
#pragma omp unroll partial(Factor)
// PRINT-NEXT: for (int i = start; i < stop; i += step)
// DUMP-NEXT: ForStmt
for (int i = start; i < stop; i += step)
// PRINT-NEXT: body(i);
// DUMP: CallExpr
body(i);
}
void unroll_template_factor() {
unroll_templated_factor<4>(0, 42, 2);
}
#endif