[OPENMP] Add support for cancel constructs in `target teams distribute

parallel for`.

Add support for cancel/cancellation point directives inside `target
teams distribute parallel for` directives.

llvm-svn: 318881
This commit is contained in:
Alexey Bataev 2017-11-22 21:12:03 +00:00
parent fd872e9637
commit 16e798873e
7 changed files with 38 additions and 10 deletions

View File

@ -957,8 +957,13 @@ public:
T->getStmtClass() == OMPTargetSimdDirectiveClass || T->getStmtClass() == OMPTargetSimdDirectiveClass ||
T->getStmtClass() == OMPTeamsDistributeDirectiveClass || T->getStmtClass() == OMPTeamsDistributeDirectiveClass ||
T->getStmtClass() == OMPTeamsDistributeSimdDirectiveClass || T->getStmtClass() == OMPTeamsDistributeSimdDirectiveClass ||
T->getStmtClass() == OMPTeamsDistributeParallelForSimdDirectiveClass || T->getStmtClass() ==
T->getStmtClass() == OMPTeamsDistributeParallelForDirectiveClass; OMPTeamsDistributeParallelForSimdDirectiveClass ||
T->getStmtClass() == OMPTeamsDistributeParallelForDirectiveClass ||
T->getStmtClass() ==
OMPTargetTeamsDistributeParallelForDirectiveClass ||
T->getStmtClass() ==
OMPTargetTeamsDistributeParallelForSimdDirectiveClass;
} }
}; };
@ -3799,6 +3804,8 @@ public:
class OMPTargetTeamsDistributeParallelForDirective final class OMPTargetTeamsDistributeParallelForDirective final
: public OMPLoopDirective { : public OMPLoopDirective {
friend class ASTStmtReader; friend class ASTStmtReader;
/// true if the construct has inner cancel directive.
bool HasCancel = false;
/// Build directive with the given start and end location. /// Build directive with the given start and end location.
/// ///
@ -3814,7 +3821,8 @@ class OMPTargetTeamsDistributeParallelForDirective final
: OMPLoopDirective(this, : OMPLoopDirective(this,
OMPTargetTeamsDistributeParallelForDirectiveClass, OMPTargetTeamsDistributeParallelForDirectiveClass,
OMPD_target_teams_distribute_parallel_for, StartLoc, OMPD_target_teams_distribute_parallel_for, StartLoc,
EndLoc, CollapsedNum, NumClauses) {} EndLoc, CollapsedNum, NumClauses),
HasCancel(false) {}
/// Build an empty directive. /// Build an empty directive.
/// ///
@ -3826,7 +3834,11 @@ class OMPTargetTeamsDistributeParallelForDirective final
: OMPLoopDirective( : OMPLoopDirective(
this, OMPTargetTeamsDistributeParallelForDirectiveClass, this, OMPTargetTeamsDistributeParallelForDirectiveClass,
OMPD_target_teams_distribute_parallel_for, SourceLocation(), OMPD_target_teams_distribute_parallel_for, SourceLocation(),
SourceLocation(), CollapsedNum, NumClauses) {} SourceLocation(), CollapsedNum, NumClauses),
HasCancel(false) {}
/// Set cancel state.
void setHasCancel(bool Has) { HasCancel = Has; }
public: public:
/// Creates directive with a list of \a Clauses. /// Creates directive with a list of \a Clauses.
@ -3838,11 +3850,12 @@ public:
/// \param Clauses List of clauses. /// \param Clauses List of clauses.
/// \param AssociatedStmt Statement, associated with the directive. /// \param AssociatedStmt Statement, associated with the directive.
/// \param Exprs Helper expressions for CodeGen. /// \param Exprs Helper expressions for CodeGen.
/// \param HasCancel true if this directive has inner cancel directive.
/// ///
static OMPTargetTeamsDistributeParallelForDirective * static OMPTargetTeamsDistributeParallelForDirective *
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses, unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses,
Stmt *AssociatedStmt, const HelperExprs &Exprs); Stmt *AssociatedStmt, const HelperExprs &Exprs, bool HasCancel);
/// Creates an empty directive with the place for \a NumClauses clauses. /// Creates an empty directive with the place for \a NumClauses clauses.
/// ///
@ -3854,6 +3867,9 @@ public:
CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned CollapsedNum, CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned CollapsedNum,
EmptyShell); EmptyShell);
/// Return true if current directive has inner cancel directive.
bool hasCancel() const { return HasCancel; }
static bool classof(const Stmt *T) { static bool classof(const Stmt *T) {
return T->getStmtClass() == return T->getStmtClass() ==
OMPTargetTeamsDistributeParallelForDirectiveClass; OMPTargetTeamsDistributeParallelForDirectiveClass;

View File

@ -1624,7 +1624,7 @@ OMPTargetTeamsDistributeParallelForDirective *
OMPTargetTeamsDistributeParallelForDirective::Create( OMPTargetTeamsDistributeParallelForDirective::Create(
const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
const HelperExprs &Exprs) { const HelperExprs &Exprs, bool HasCancel) {
auto Size = auto Size =
llvm::alignTo(sizeof(OMPTargetTeamsDistributeParallelForDirective), llvm::alignTo(sizeof(OMPTargetTeamsDistributeParallelForDirective),
alignof(OMPClause *)); alignof(OMPClause *));
@ -1670,6 +1670,7 @@ OMPTargetTeamsDistributeParallelForDirective::Create(
Dir->setCombinedCond(Exprs.DistCombinedFields.Cond); Dir->setCombinedCond(Exprs.DistCombinedFields.Cond);
Dir->setCombinedNextLowerBound(Exprs.DistCombinedFields.NLB); Dir->setCombinedNextLowerBound(Exprs.DistCombinedFields.NLB);
Dir->setCombinedNextUpperBound(Exprs.DistCombinedFields.NUB); Dir->setCombinedNextUpperBound(Exprs.DistCombinedFields.NUB);
Dir->HasCancel = HasCancel;
return Dir; return Dir;
} }

View File

@ -2014,6 +2014,9 @@ emitInnerParallelForWhenCombined(CodeGenFunction &CGF,
HasCancel = D->hasCancel(); HasCancel = D->hasCancel();
else if (const auto *D = dyn_cast<OMPDistributeParallelForDirective>(&S)) else if (const auto *D = dyn_cast<OMPDistributeParallelForDirective>(&S))
HasCancel = D->hasCancel(); HasCancel = D->hasCancel();
else if (const auto *D =
dyn_cast<OMPTargetTeamsDistributeParallelForDirective>(&S))
HasCancel = D->hasCancel();
} }
CodeGenFunction::OMPCancelStackRAII CancelRegion(CGF, S.getDirectiveKind(), CodeGenFunction::OMPCancelStackRAII CancelRegion(CGF, S.getDirectiveKind(),
HasCancel); HasCancel);
@ -3949,7 +3952,8 @@ CodeGenFunction::getOMPCancelDestination(OpenMPDirectiveKind Kind) {
Kind == OMPD_parallel_sections || Kind == OMPD_parallel_for || Kind == OMPD_parallel_sections || Kind == OMPD_parallel_for ||
Kind == OMPD_distribute_parallel_for || Kind == OMPD_distribute_parallel_for ||
Kind == OMPD_target_parallel_for || Kind == OMPD_target_parallel_for ||
Kind == OMPD_teams_distribute_parallel_for); Kind == OMPD_teams_distribute_parallel_for ||
Kind == OMPD_target_teams_distribute_parallel_for);
return OMPCancelStack.getExitBlock(); return OMPCancelStack.getExitBlock();
} }

View File

@ -2593,7 +2593,8 @@ static bool checkNestingOfRegions(Sema &SemaRef, DSAStackTy *Stack,
(ParentRegion == OMPD_for || ParentRegion == OMPD_parallel_for || (ParentRegion == OMPD_for || ParentRegion == OMPD_parallel_for ||
ParentRegion == OMPD_target_parallel_for || ParentRegion == OMPD_target_parallel_for ||
ParentRegion == OMPD_distribute_parallel_for || ParentRegion == OMPD_distribute_parallel_for ||
ParentRegion == OMPD_teams_distribute_parallel_for)) || ParentRegion == OMPD_teams_distribute_parallel_for ||
ParentRegion == OMPD_target_teams_distribute_parallel_for)) ||
(CancelRegion == OMPD_taskgroup && ParentRegion == OMPD_task) || (CancelRegion == OMPD_taskgroup && ParentRegion == OMPD_task) ||
(CancelRegion == OMPD_sections && (CancelRegion == OMPD_sections &&
(ParentRegion == OMPD_section || ParentRegion == OMPD_sections || (ParentRegion == OMPD_section || ParentRegion == OMPD_sections ||
@ -7324,7 +7325,8 @@ StmtResult Sema::ActOnOpenMPTargetTeamsDistributeParallelForDirective(
getCurFunction()->setHasBranchProtectedScope(); getCurFunction()->setHasBranchProtectedScope();
return OMPTargetTeamsDistributeParallelForDirective::Create( return OMPTargetTeamsDistributeParallelForDirective::Create(
Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B,
DSAStack->isCancelRegion());
} }
StmtResult Sema::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective( StmtResult Sema::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective(

View File

@ -2978,6 +2978,7 @@ void ASTStmtReader::VisitOMPTargetTeamsDistributeDirective(
void ASTStmtReader::VisitOMPTargetTeamsDistributeParallelForDirective( void ASTStmtReader::VisitOMPTargetTeamsDistributeParallelForDirective(
OMPTargetTeamsDistributeParallelForDirective *D) { OMPTargetTeamsDistributeParallelForDirective *D) {
VisitOMPLoopDirective(D); VisitOMPLoopDirective(D);
D->setHasCancel(Record.readInt());
} }
void ASTStmtReader::VisitOMPTargetTeamsDistributeParallelForSimdDirective( void ASTStmtReader::VisitOMPTargetTeamsDistributeParallelForSimdDirective(

View File

@ -2636,6 +2636,7 @@ void ASTStmtWriter::VisitOMPTargetTeamsDistributeDirective(
void ASTStmtWriter::VisitOMPTargetTeamsDistributeParallelForDirective( void ASTStmtWriter::VisitOMPTargetTeamsDistributeParallelForDirective(
OMPTargetTeamsDistributeParallelForDirective *D) { OMPTargetTeamsDistributeParallelForDirective *D) {
VisitOMPLoopDirective(D); VisitOMPLoopDirective(D);
Record.push_back(D->hasCancel() ? 1 : 0);
Code = serialization::STMT_OMP_TARGET_TEAMS_DISTRIBUTE_PARALLEL_FOR_DIRECTIVE; Code = serialization::STMT_OMP_TARGET_TEAMS_DISTRIBUTE_PARALLEL_FOR_DIRECTIVE;
} }

View File

@ -24,8 +24,10 @@ protected:
public: public:
S7(typename T::type v) : a(v) { S7(typename T::type v) : a(v) {
#pragma omp target teams distribute parallel for private(a) private(this->a) private(T::a) #pragma omp target teams distribute parallel for private(a) private(this->a) private(T::a)
for (int k = 0; k < a.a; ++k) for (int k = 0; k < a.a; ++k) {
++this->a.a; ++this->a.a;
#pragma omp cancel for
}
} }
S7 &operator=(S7 &s) { S7 &operator=(S7 &s) {
#pragma omp target teams distribute parallel for private(a) private(this->a) #pragma omp target teams distribute parallel for private(a) private(this->a)
@ -43,6 +45,7 @@ public:
} }
}; };
// CHECK: #pragma omp target teams distribute parallel for private(this->a) private(this->a) private(T::a) // CHECK: #pragma omp target teams distribute parallel for private(this->a) private(this->a) private(T::a)
// CHECK: #pragma omp cancel for
// CHECK: #pragma omp target teams distribute parallel for private(this->a) private(this->a) // CHECK: #pragma omp target teams distribute parallel for private(this->a) private(this->a)
// CHECK: #pragma omp target teams distribute parallel for default(none) private(b) firstprivate(argv) shared(d) reduction(+: c) reduction(max: e) num_teams(f) thread_limit(d) // CHECK: #pragma omp target teams distribute parallel for default(none) private(b) firstprivate(argv) shared(d) reduction(+: c) reduction(max: e) num_teams(f) thread_limit(d)