forked from OSchip/llvm-project
[Clang][OpenMP] Allow loop-transformations with template parameters.
Clang would reject #pragma omp for #pragma omp tile sizes(P) for (int i = 0; i < 128; ++i) {} where P is a template parameter, but the loop itself is not template-dependent. Because P is context-dependent, the TransformedStmt cannot be generated and is therefore nullptr (until the template is instantiated by TreeTransform). The OMPForDirective would still expect a loop in the dependent context and trigger an error. Fix by introducing a NumGeneratedLoops field to OMPLoopTransformationDirective. This is used to distinguish the case where no TransformedStmt will be generated at all (e.g. #pragma omp unroll full) from the case where template instantiation is needed. In the latter case, delay resolving the iteration space until template instantiation, as is already done when the for-loop itself is template-dependent. A more radical solution would always delay the iteration space analysis until template instantiation, but would also break many test cases. Reviewed By: ABataev Differential Revision: https://reviews.llvm.org/D111124
This commit is contained in:
parent
fdf4c03522
commit
2130117f92
|
@ -959,6 +959,9 @@ public:
|
|||
class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
|
||||
friend class ASTStmtReader;
|
||||
|
||||
/// Number of loops generated by this loop transformation.
|
||||
unsigned NumGeneratedLoops = 0;
|
||||
|
||||
protected:
|
||||
explicit OMPLoopTransformationDirective(StmtClass SC,
|
||||
OpenMPDirectiveKind Kind,
|
||||
|
@ -967,10 +970,16 @@ protected:
|
|||
unsigned NumAssociatedLoops)
|
||||
: OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
|
||||
|
||||
/// Set the number of loops generated by this loop transformation.
|
||||
// Protected: in this change it is called only by the OMPTileDirective
// constructor, OMPUnrollDirective::Create, and ASTStmtReader (a friend).
void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; }
|
||||
|
||||
public:
|
||||
/// Return the number of associated (consumed) loops.
|
||||
unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
|
||||
|
||||
/// Return the number of loops generated by this loop transformation.
|
||||
unsigned getNumGeneratedLoops() { return NumGeneratedLoops; }
|
||||
|
||||
/// Get the de-sugared statements after the loop transformation.
|
||||
///
|
||||
/// Might be nullptr if either the directive generates no loops and is handled
|
||||
|
@ -5058,7 +5067,9 @@ class OMPTileDirective final : public OMPLoopTransformationDirective {
|
|||
unsigned NumLoops)
|
||||
: OMPLoopTransformationDirective(OMPTileDirectiveClass,
|
||||
llvm::omp::OMPD_tile, StartLoc, EndLoc,
|
||||
NumLoops) {}
|
||||
NumLoops) {
|
||||
setNumGeneratedLoops(3 * NumLoops);
|
||||
}
|
||||
|
||||
void setPreInits(Stmt *PreInits) {
|
||||
Data->getChildren()[PreInitsOffset] = PreInits;
|
||||
|
@ -5163,7 +5174,7 @@ public:
|
|||
static OMPUnrollDirective *
|
||||
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
|
||||
ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
|
||||
Stmt *TransformedStmt, Stmt *PreInits);
|
||||
unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits);
|
||||
|
||||
/// Build an empty '#pragma omp unroll' AST node for deserialization.
|
||||
///
|
||||
|
|
|
@ -138,9 +138,18 @@ bool OMPLoopBasedDirective::doForAllLoops(
|
|||
|
||||
Stmt *TransformedStmt = Dir->getTransformedStmt();
|
||||
if (!TransformedStmt) {
|
||||
// May happen if the loop transformation does not result in a
|
||||
// generated loop (such as full unrolling).
|
||||
break;
|
||||
unsigned NumGeneratedLoops = Dir->getNumGeneratedLoops();
|
||||
if (NumGeneratedLoops == 0) {
|
||||
// May happen if the loop transformation does not result in a
|
||||
// generated loop (such as full unrolling).
|
||||
break;
|
||||
}
|
||||
if (NumGeneratedLoops > 0) {
|
||||
// The loop transformation construct has generated loops, but these
|
||||
// may not have been generated yet due to being in a dependent
|
||||
// context.
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
CurStmt = TransformedStmt;
|
||||
|
@ -419,10 +428,13 @@ OMPTileDirective *OMPTileDirective::CreateEmpty(const ASTContext &C,
|
|||
OMPUnrollDirective *
|
||||
OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc,
|
||||
SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
|
||||
Stmt *AssociatedStmt, Stmt *TransformedStmt,
|
||||
Stmt *PreInits) {
|
||||
Stmt *AssociatedStmt, unsigned NumGeneratedLoops,
|
||||
Stmt *TransformedStmt, Stmt *PreInits) {
|
||||
assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop");
|
||||
|
||||
auto *Dir = createDirective<OMPUnrollDirective>(
|
||||
C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
|
||||
Dir->setNumGeneratedLoops(NumGeneratedLoops);
|
||||
Dir->setTransformedStmt(TransformedStmt);
|
||||
Dir->setPreInits(PreInits);
|
||||
return Dir;
|
||||
|
|
|
@ -12919,10 +12919,12 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
|
|||
Body, OriginalInits))
|
||||
return StmtError();
|
||||
|
||||
unsigned NumGeneratedLoops = PartialClause ? 1 : 0;
|
||||
|
||||
// Delay unrolling to when template is completely instantiated.
|
||||
if (CurContext->isDependentContext())
|
||||
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
|
||||
nullptr, nullptr);
|
||||
NumGeneratedLoops, nullptr, nullptr);
|
||||
|
||||
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
|
||||
|
||||
|
@ -12941,9 +12943,9 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
|
|||
// The generated loop may only be passed to other loop-associated directive
|
||||
// when a partial clause is specified. Without the requirement it is
|
||||
// sufficient to generate loop unroll metadata at code-generation.
|
||||
if (!PartialClause)
|
||||
if (NumGeneratedLoops == 0)
|
||||
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
|
||||
nullptr, nullptr);
|
||||
NumGeneratedLoops, nullptr, nullptr);
|
||||
|
||||
// Otherwise, we need to provide a de-sugared/transformed AST that can be
|
||||
// associated with another loop directive.
|
||||
|
@ -13164,7 +13166,8 @@ StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
|
|||
LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
|
||||
|
||||
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
|
||||
OuterFor, buildPreInits(Context, PreInits));
|
||||
NumGeneratedLoops, OuterFor,
|
||||
buildPreInits(Context, PreInits));
|
||||
}
|
||||
|
||||
OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr,
|
||||
|
|
|
@ -2327,6 +2327,7 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) {
|
|||
/// Deserialize the state shared by all loop-transformation directives:
/// first the OMPLoopBasedDirective part, then the generated-loop count.
void ASTStmtReader::VisitOMPLoopTransformationDirective(
|
||||
OMPLoopTransformationDirective *D) {
|
||||
VisitOMPLoopBasedDirective(D);
|
||||
// Read order must mirror ASTStmtWriter::VisitOMPLoopTransformationDirective,
// which emits this value right after the loop-based part.
D->setNumGeneratedLoops(Record.readUInt32());
|
||||
}
|
||||
|
||||
void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
|
||||
|
|
|
@ -2226,6 +2226,7 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) {
|
|||
/// Serialize the state shared by all loop-transformation directives:
/// first the OMPLoopBasedDirective part, then the generated-loop count.
void ASTStmtWriter::VisitOMPLoopTransformationDirective(
|
||||
OMPLoopTransformationDirective *D) {
|
||||
VisitOMPLoopBasedDirective(D);
|
||||
// Write order must mirror ASTStmtReader::VisitOMPLoopTransformationDirective,
// which reads this value right after the loop-based part.
Record.writeUInt32(D->getNumGeneratedLoops());
|
||||
}
|
||||
|
||||
void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {
|
||||
|
|
|
@ -162,4 +162,25 @@ void tfoo6() {
|
|||
}
|
||||
|
||||
|
||||
// PRINT-LABEL: template <int Tile> void foo7(int start, int stop, int step) {
|
||||
// DUMP-LABEL: FunctionTemplateDecl {{.*}} foo7
|
||||
// Test case: the tile size is the non-type template parameter Tile, while the
// loop bounds and step are ordinary int function parameters.
template <int Tile>
|
||||
void foo7(int start, int stop, int step) {
|
||||
// PRINT: #pragma omp tile sizes(Tile)
|
||||
// DUMP: OMPTileDirective
|
||||
// DUMP-NEXT: OMPSizesClause
|
||||
// DUMP-NEXT: DeclRefExpr {{.*}} 'Tile' 'int'
|
||||
#pragma omp tile sizes(Tile)
|
||||
// PRINT-NEXT: for (int i = start; i < stop; i += step)
|
||||
// DUMP-NEXT: ForStmt
|
||||
for (int i = start; i < stop; i += step)
|
||||
// PRINT-NEXT: body(i);
|
||||
// DUMP: CallExpr
|
||||
body(i);
|
||||
|
||||
}
|
||||
// Instantiates foo7 with Tile = 5.
void tfoo7() {
|
||||
foo7<5>(0, 42, 2);
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
|
|
@ -124,4 +124,26 @@ void unroll_template() {
|
|||
unroll_templated<int,0,1024,1,4>();
|
||||
}
|
||||
|
||||
|
||||
// PRINT-LABEL: template <int Factor> void unroll_templated_factor(int start, int stop, int step) {
|
||||
// DUMP-LABEL: FunctionTemplateDecl {{.*}} unroll_templated_factor
|
||||
// Test case: the partial-unroll factor is the non-type template parameter
// Factor, while the loop bounds and step are ordinary int function parameters.
template <int Factor>
|
||||
void unroll_templated_factor(int start, int stop, int step) {
|
||||
// PRINT: #pragma omp unroll partial(Factor)
|
||||
// DUMP: OMPUnrollDirective
|
||||
// DUMP-NEXT: OMPPartialClause
|
||||
// DUMP-NEXT: DeclRefExpr {{.*}} 'Factor' 'int'
|
||||
#pragma omp unroll partial(Factor)
|
||||
// PRINT-NEXT: for (int i = start; i < stop; i += step)
|
||||
// DUMP-NEXT: ForStmt
|
||||
for (int i = start; i < stop; i += step)
|
||||
// PRINT-NEXT: body(i);
|
||||
// DUMP: CallExpr
|
||||
body(i);
|
||||
|
||||
}
|
||||
// Instantiates unroll_templated_factor with Factor = 4.
void unroll_template_factor() {
|
||||
unroll_templated_factor<4>(0, 42, 2);
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue