From d74d060d6d50993aee1da5b3748e029365fd1bca Mon Sep 17 00:00:00 2001 From: Alexey Bataev Date: Mon, 13 Oct 2014 06:02:40 +0000 Subject: [PATCH] [OPENMP] Codegen for 'if' clause in 'parallel' directive. Adds codegen for 'if' clause. Currently only for 'if' clause used with the 'parallel' directive. If condition evaluates to true, the code executes parallel version of the code by calling __kmpc_fork_call(loc, 1, microtask, captured_struct/*context*/), where loc - debug location, 1 - number of additional parameters after "microtask" argument, microtask - is outlined finction for the code associated with the 'parallel' directive, captured_struct - list of variables captured in this outlined function. If condition evaluates to false, the code executes serial version of the code by executing the following code: global_thread_id.addr = alloca i32 store i32 global_thread_id, global_thread_id.addr zero.addr = alloca i32 store i32 0, zero.addr kmpc_serialized_parallel(loc, global_thread_id); microtask(global_thread_id.addr, zero.addr, captured_struct/*context*/); kmpc_end_serialized_parallel(loc, global_thread_id); Where loc - debug location, global_thread_id - global thread id, returned by __kmpc_global_thread_num() call or passed as a first parameter in microtask() call, global_thread_id.addr - address of the variable, where stored global_thread_id value, zero.addr - implicit bound thread id (should be set to 0 for serial call), microtask() and captured_struct are the same as in parallel call. Also this patch checks if the condition is constant and if it is constant it evaluates its value and then generates either parallel version of the code (if the condition evaluates to true), or the serial version of the code (if the condition evaluates to false). Differential Revision: http://reviews.llvm.org/D4716 llvm-svn: 219597 --- clang/include/clang/AST/StmtOpenMP.h | 7 ++ clang/lib/AST/Stmt.cpp | 15 +++ clang/lib/CodeGen/CGOpenMPRuntime.cpp | 73 ++++++++++++- clang/lib/CodeGen/CGOpenMPRuntime.h | 27 ++++- clang/lib/CodeGen/CGStmtOpenMP.cpp | 62 ++++++++++- clang/test/OpenMP/parallel_if_codegen.cpp | 124 ++++++++++++++++++++++ 6 files changed, 301 insertions(+), 7 deletions(-) create mode 100644 clang/test/OpenMP/parallel_if_codegen.cpp diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h index 6a6abb9b9951..a031d8bc4bf1 100644 --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -128,6 +128,13 @@ public: operator bool() { return Current != End; } }; + /// \brief Gets a single clause of the specified kind \a K associated with the + /// current directive iff there is only one clause of this kind (and assertion + /// is fired if there is more than one clause is associated with the + /// directive). Returns nullptr if no clause of kind \a K is associated with + /// the directive. + const OMPClause *getSingleClause(OpenMPClauseKind K) const; + /// \brief Returns starting location of directive kind. SourceLocation getLocStart() const { return StartLoc; } /// \brief Returns ending location of directive. diff --git a/clang/lib/AST/Stmt.cpp b/clang/lib/AST/Stmt.cpp index d3047faebbe4..ae381361c29f 100644 --- a/clang/lib/AST/Stmt.cpp +++ b/clang/lib/AST/Stmt.cpp @@ -1434,6 +1434,21 @@ OMPFlushClause *OMPFlushClause::CreateEmpty(const ASTContext &C, unsigned N) { return new (Mem) OMPFlushClause(N); } +const OMPClause * +OMPExecutableDirective::getSingleClause(OpenMPClauseKind K) const { + auto ClauseFilter = + [=](const OMPClause *C) -> bool { return C->getClauseKind() == K; }; + OMPExecutableDirective::filtered_clause_iterator I( + clauses(), ClauseFilter); + + if (I) { + auto *Clause = *I; + assert(!++I && "There are at least 2 clauses of the specified kind"); + return Clause; + } + return nullptr; +} + OMPParallelDirective *OMPParallelDirective::Create( const ASTContext &C, SourceLocation StartLoc, diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index b3e5c1123ce2..53d4b6cb4048 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -16,6 +16,7 @@ #include "clang/AST/StmtOpenMP.h" #include "clang/AST/Decl.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/IR/CallSite.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Value.h" @@ -253,7 +254,7 @@ CGOpenMPRuntime::CreateRuntimeFunction(OpenMPRTLFunction Function) { llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty, getKmpc_MicroPointerTy()}; llvm::FunctionType *FnTy = - llvm::FunctionType::get(CGM.VoidTy, TypeParams, true); + llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ true); RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_fork_call"); break; } @@ -261,7 +262,7 @@ CGOpenMPRuntime::CreateRuntimeFunction(OpenMPRTLFunction Function) { // Build kmp_int32 __kmpc_global_thread_num(ident_t *loc); llvm::Type *TypeParams[] = {getIdentTyPointerTy()}; llvm::FunctionType *FnTy = - llvm::FunctionType::get(CGM.Int32Ty, TypeParams, false); + llvm::FunctionType::get(CGM.Int32Ty, TypeParams, /*isVarArg*/ false); RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_global_thread_num"); break; } @@ -295,6 +296,24 @@ CGOpenMPRuntime::CreateRuntimeFunction(OpenMPRTLFunction Function) { RTLFn = CGM.CreateRuntimeFunction(FnTy, /*Name*/ "__kmpc_barrier"); break; } + case OMPRTL__kmpc_serialized_parallel: { + // Build void __kmpc_serialized_parallel(ident_t *loc, kmp_int32 + // global_tid); + llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty}; + llvm::FunctionType *FnTy = + llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false); + RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_serialized_parallel"); + break; + } + case OMPRTL__kmpc_end_serialized_parallel: { + // Build void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32 + // global_tid); + llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty}; + llvm::FunctionType *FnTy = + llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false); + RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_end_serialized_parallel"); + break; + } } return RTLFn; } @@ -314,6 +333,56 @@ void CGOpenMPRuntime::EmitOMPParallelCall(CodeGenFunction &CGF, CGF.EmitRuntimeCall(RTLFn, Args); } +void CGOpenMPRuntime::EmitOMPSerialCall(CodeGenFunction &CGF, + SourceLocation Loc, + llvm::Value *OutlinedFn, + llvm::Value *CapturedStruct) { + auto ThreadID = GetOpenMPThreadID(CGF, Loc); + // Build calls: + // __kmpc_serialized_parallel(&Loc, GTid); + llvm::Value *SerArgs[] = {EmitOpenMPUpdateLocation(CGF, Loc), ThreadID}; + auto RTLFn = + CreateRuntimeFunction(CGOpenMPRuntime::OMPRTL__kmpc_serialized_parallel); + CGF.EmitRuntimeCall(RTLFn, SerArgs); + + // OutlinedFn(>id, &zero, CapturedStruct); + auto ThreadIDAddr = EmitThreadIDAddress(CGF, Loc); + auto Int32Ty = + CGF.getContext().getIntTypeForBitwidth(/*DestWidth*/ 32, /*Signed*/ true); + auto ZeroAddr = CGF.CreateMemTemp(Int32Ty, /*Name*/ ".zero.addr"); + CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C*/ 0)); + llvm::Value *OutlinedFnArgs[] = {ThreadIDAddr, ZeroAddr, CapturedStruct}; + CGF.EmitCallOrInvoke(OutlinedFn, OutlinedFnArgs); + + // __kmpc_end_serialized_parallel(&Loc, GTid); + llvm::Value *EndSerArgs[] = {EmitOpenMPUpdateLocation(CGF, Loc), ThreadID}; + RTLFn = CreateRuntimeFunction( + CGOpenMPRuntime::OMPRTL__kmpc_end_serialized_parallel); + CGF.EmitRuntimeCall(RTLFn, EndSerArgs); +} + +// If we’re inside an (outlined) parallel region, use the region info’s +// thread-ID variable (it is passed in a first argument of the outlined function +// as "kmp_int32 *gtid"). Otherwise, if we're not inside parallel region, but in +// regular serial code region, get thread ID by calling kmp_int32 +// kmpc_global_thread_num(ident_t *loc), stash this thread ID in a temporary and +// return the address of that temp. +llvm::Value *CGOpenMPRuntime::EmitThreadIDAddress(CodeGenFunction &CGF, + SourceLocation Loc) { + if (auto OMPRegionInfo = + dyn_cast_or_null(CGF.CapturedStmtInfo)) + return CGF.EmitLoadOfLValue(OMPRegionInfo->getThreadIDVariableLValue(CGF), + SourceLocation()).getScalarVal(); + auto ThreadID = GetOpenMPThreadID(CGF, Loc); + auto Int32Ty = + CGF.getContext().getIntTypeForBitwidth(/*DestWidth*/ 32, /*Signed*/ true); + auto ThreadIDTemp = CGF.CreateMemTemp(Int32Ty, /*Name*/ ".threadid_temp."); + CGF.EmitStoreOfScalar(ThreadID, + CGF.MakeNaturalAlignAddrLValue(ThreadIDTemp, Int32Ty)); + + return ThreadIDTemp; +} + llvm::Value *CGOpenMPRuntime::GetCriticalRegionLock(StringRef CriticalName) { SmallString<256> Buffer; llvm::raw_svector_ostream Out(Buffer); diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h index ce822ea7d1b8..04378821d7a6 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.h +++ b/clang/lib/CodeGen/CGOpenMPRuntime.h @@ -74,7 +74,13 @@ public: // kmp_critical_name *crit); OMPRTL__kmpc_end_critical, // Call to void __kmpc_barrier(ident_t *loc, kmp_int32 global_tid); - OMPRTL__kmpc_barrier + OMPRTL__kmpc_barrier, + // Call to void __kmpc_serialized_parallel(ident_t *loc, kmp_int32 + // global_tid); + OMPRTL__kmpc_serialized_parallel, + // Call to void __kmpc_end_serialized_parallel(ident_t *loc, kmp_int32 + // global_tid); + OMPRTL__kmpc_end_serialized_parallel }; private: @@ -156,10 +162,10 @@ private: EmitOpenMPUpdateLocation(CodeGenFunction &CGF, SourceLocation Loc, OpenMPLocationFlags Flags = OMP_IDENT_KMPC); - /// \brief Returns pointer to ident_t type; + /// \brief Returns pointer to ident_t type. llvm::Type *getIdentTyPointerTy(); - /// \brief Returns pointer to kmpc_micro type; + /// \brief Returns pointer to kmpc_micro type. llvm::Type *getKmpc_MicroPointerTy(); /// \brief Returns specified OpenMP runtime function. @@ -167,6 +173,11 @@ private: /// \return Specified function. llvm::Constant *CreateRuntimeFunction(OpenMPRTLFunction Function); + /// \brief Emits address of the word in a memory where current thread id is + /// stored. + virtual llvm::Value *EmitThreadIDAddress(CodeGenFunction &CGF, + SourceLocation Loc); + /// \brief Gets thread id value for the current thread. /// llvm::Value *GetOpenMPThreadID(CodeGenFunction &CGF, SourceLocation Loc); @@ -201,6 +212,16 @@ public: llvm::Value *OutlinedFn, llvm::Value *CapturedStruct); + /// \brief Emits code for serial call of the \a OutlinedFn with variables + /// captured in a record which address is stored in \a CapturedStruct. + /// \param OutlinedFn Outlined function to be run in serial mode. + /// \param CapturedStruct A pointer to the record with the references to + /// variables used in \a OutlinedFn function. + /// + virtual void EmitOMPSerialCall(CodeGenFunction &CGF, SourceLocation Loc, + llvm::Value *OutlinedFn, + llvm::Value *CapturedStruct); + /// \brief Returns corresponding lock object for the specified critical region /// name. If the lock object does not exist it is created, otherwise the /// reference to the existing copy is returned. diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index 2aec28f3051c..a459d07a7236 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -24,6 +24,52 @@ using namespace CodeGen; // OpenMP Directive Emission //===----------------------------------------------------------------------===// +/// \brief Emits code for OpenMP 'if' clause using specified \a CodeGen +/// function. Here is the logic: +/// if (Cond) { +/// CodeGen(true); +/// } else { +/// CodeGen(false); +/// } +static void EmitOMPIfClause(CodeGenFunction &CGF, const Expr *Cond, + const std::function &CodeGen) { + CodeGenFunction::LexicalScope ConditionScope(CGF, Cond->getSourceRange()); + + // If the condition constant folds and can be elided, try to avoid emitting + // the condition and the dead arm of the if/else. + bool CondConstant; + if (CGF.ConstantFoldsToSimpleInteger(Cond, CondConstant)) { + CodeGen(CondConstant); + return; + } + + // Otherwise, the condition did not fold, or we couldn't elide it. Just + // emit the conditional branch. + auto ThenBlock = CGF.createBasicBlock(/*name*/ "omp_if.then"); + auto ElseBlock = CGF.createBasicBlock(/*name*/ "omp_if.else"); + auto ContBlock = CGF.createBasicBlock(/*name*/ "omp_if.end"); + CGF.EmitBranchOnBoolExpr(Cond, ThenBlock, ElseBlock, /*TrueCount*/ 0); + + // Emit the 'then' code. + CGF.EmitBlock(ThenBlock); + CodeGen(/*ThenBlock*/ true); + CGF.EmitBranch(ContBlock); + // Emit the 'else' code if present. + { + // There is no need to emit line number for unconditional branch. + SuppressDebugLocation SDL(CGF.Builder); + CGF.EmitBlock(ElseBlock); + } + CodeGen(/*ThenBlock*/ false); + { + // There is no need to emit line number for unconditional branch. + SuppressDebugLocation SDL(CGF.Builder); + CGF.EmitBranch(ContBlock); + } + // Emit the continuation block for code after the if. + CGF.EmitBlock(ContBlock, /*IsFinished*/ true); +} + void CodeGenFunction::EmitOMPAggregateAssign(LValue OriginalAddr, llvm::Value *PrivateAddr, const Expr *AssignExpr, @@ -142,8 +188,20 @@ void CodeGenFunction::EmitOMPParallelDirective(const OMPParallelDirective &S) { auto CapturedStruct = GenerateCapturedStmtArgument(*CS); auto OutlinedFn = CGM.getOpenMPRuntime().EmitOpenMPOutlinedFunction( S, *CS->getCapturedDecl()->param_begin()); - CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(), OutlinedFn, - CapturedStruct); + if (auto C = S.getSingleClause(/*K*/ OMPC_if)) { + auto Cond = cast(C)->getCondition(); + EmitOMPIfClause(*this, Cond, [&](bool ThenBlock) { + if (ThenBlock) + CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(), + OutlinedFn, CapturedStruct); + else + CGM.getOpenMPRuntime().EmitOMPSerialCall(*this, S.getLocStart(), + OutlinedFn, CapturedStruct); + }); + } else { + CGM.getOpenMPRuntime().EmitOMPParallelCall(*this, S.getLocStart(), + OutlinedFn, CapturedStruct); + } } void CodeGenFunction::EmitOMPLoopBody(const OMPLoopDirective &S, diff --git a/clang/test/OpenMP/parallel_if_codegen.cpp b/clang/test/OpenMP/parallel_if_codegen.cpp new file mode 100644 index 000000000000..54eedbefe509 --- /dev/null +++ b/clang/test/OpenMP/parallel_if_codegen.cpp @@ -0,0 +1,124 @@ +// RUN: %clang_cc1 -verify -fopenmp=libiomp5 -x c++ -triple %itanium_abi_triple -emit-llvm %s -o - | FileCheck %s +// RUN: %clang_cc1 -fopenmp=libiomp5 -x c++ -std=c++11 -triple %itanium_abi_triple -emit-pch -o %t %s +// RUN: %clang_cc1 -fopenmp=libiomp5 -x c++ -triple %itanium_abi_triple -std=c++11 -include-pch %t -verify %s -emit-llvm -o - | FileCheck --check-prefix=CHECK %s +// expected-no-diagnostics +#ifndef HEADER +#define HEADER + +void fn1(); +void fn2(); +void fn3(); +void fn4(); +void fn5(); +void fn6(); + +int Arg; + +// CHECK-LABEL: define void @{{.+}}gtid_test +void gtid_test() { +// CHECK: call void {{.+}}* @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, {{.+}}* [[GTID_TEST_REGION1:@.+]] to void +#pragma omp parallel +#pragma omp parallel if (false) + gtid_test(); +// CHECK: ret void +} + +// CHECK: define internal void [[GTID_TEST_REGION1]](i{{.+}}* [[GTID_PARAM:%.+]], i +// CHECK: store i{{[0-9]+}}* [[GTID_PARAM]], i{{[0-9]+}}** [[GTID_ADDR_REF:%.+]], +// CHECK: [[GTID_ADDR:%.+]] = load i{{[0-9]+}}** [[GTID_ADDR_REF]] +// CHECK: [[GTID:%.+]] = load i{{[0-9]+}}* [[GTID_ADDR]] +// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i{{.+}} [[GTID]]) +// CHECK: [[GTID_ADDR:%.+]] = load i{{[0-9]+}}** [[GTID_ADDR_REF]] +// CHECK: call void [[GTID_TEST_REGION2:@.+]](i{{[0-9]+}}* [[GTID_ADDR]] +// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i{{.+}} [[GTID]]) +// CHECK: ret void + +// CHECK: define internal void [[GTID_TEST_REGION2]]( +// CHECK: call void @{{.+}}gtid_test +// CHECK: ret void + +template +int tmain(T Arg) { +#pragma omp parallel if (true) + fn1(); +#pragma omp parallel if (false) + fn2(); +#pragma omp parallel if (Arg) + fn3(); + return 0; +} + +// CHECK-LABEL: define {{.*}}i{{[0-9]+}} @main() +int main() { +// CHECK: [[GTID:%.+]] = call i32 @__kmpc_global_thread_num( +// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN4:@.+]] to void +#pragma omp parallel if (true) + fn4(); +// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]]) +// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]], +// CHECK: call void [[CAP_FN5:@.+]](i32* [[GTID_ADDR]], +// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]]) +#pragma omp parallel if (false) + fn5(); + +// CHECK: br i1 %{{.+}}, label %[[OMP_THEN:.+]], label %[[OMP_ELSE:.+]] +// CHECK: [[OMP_THEN]] +// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN6:@.+]] to void +// CHECK: br label %[[OMP_END:.+]] +// CHECK: [[OMP_ELSE]] +// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 %0) +// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]], +// CHECK: call void [[CAP_FN6]](i32* [[GTID_ADDR]], +// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]]) +// CHECK: br label %[[OMP_END]] +// CHECK: [[OMP_END]] +#pragma omp parallel if (Arg) + fn6(); + // CHECK: = call i{{.+}} @{{.+}}tmain + return tmain(Arg); +} + +// CHECK: define internal void [[CAP_FN4]] +// CHECK: call void @{{.+}}fn4 +// CHECK: ret void + +// CHECK: define internal void [[CAP_FN5]] +// CHECK: call void @{{.+}}fn5 +// CHECK: ret void + +// CHECK: define internal void [[CAP_FN6]] +// CHECK: call void @{{.+}}fn6 +// CHECK: ret void + +// CHECK-LABEL: define {{.+}} @{{.+}}tmain +// CHECK: [[GTID:%.+]] = call i32 @__kmpc_global_thread_num( +// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN1:@.+]] to void +// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]]) +// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]], +// CHECK: call void [[CAP_FN2:@.+]](i32* [[GTID_ADDR]], +// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]]) +// CHECK: br i1 %{{.+}}, label %[[OMP_THEN:.+]], label %[[OMP_ELSE:.+]] +// CHECK: [[OMP_THEN]] +// CHECK: call void {{.+}} @__kmpc_fork_call(%{{.+}}* @{{.+}}, i{{.+}} 1, void {{.+}}* [[CAP_FN3:@.+]] to void +// CHECK: br label %[[OMP_END:.+]] +// CHECK: [[OMP_ELSE]] +// CHECK: call void @__kmpc_serialized_parallel(%{{.+}}* @{{.+}}, i32 %0) +// CHECK: store i32 [[GTID]], i32* [[GTID_ADDR:%.+]], +// CHECK: call void [[CAP_FN3]](i32* [[GTID_ADDR]], +// CHECK: call void @__kmpc_end_serialized_parallel(%{{.+}}* @{{.+}}, i32 [[GTID]]) +// CHECK: br label %[[OMP_END]] +// CHECK: [[OMP_END]] + +// CHECK: define internal void [[CAP_FN1]] +// CHECK: call void @{{.+}}fn1 +// CHECK: ret void + +// CHECK: define internal void [[CAP_FN2]] +// CHECK: call void @{{.+}}fn2 +// CHECK: ret void + +// CHECK: define internal void [[CAP_FN3]] +// CHECK: call void @{{.+}}fn3 +// CHECK: ret void + +#endif