[flang][OpenMP] Lowering critical construct

This patch adds translation from PFT to FIR for critical construct.

This is part of the upstreaming effort from the fir-dev branch in [1].
[1] https://github.com/flang-compiler/f18-llvm-project

Co-authored-by: kiranchandramohan <kiranchandramohan@gmail.com>

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D122218
This commit is contained in:
Shraiysh Vaishay 2022-03-22 15:17:52 +05:30
parent 5a65f0b4d9
commit ebec5e5c8f
3 changed files with 100 additions and 27 deletions

View File

@ -202,6 +202,49 @@ genOMP(Fortran::lower::AbstractConverter &converter,
}
}
static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
std::string name;
const Fortran::parser::OmpCriticalDirective &cd =
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) {
name =
std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
}
uint64_t hint = 0;
const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
for (const Fortran::parser::OmpClause &clause : clauseList.v)
if (auto hintClause =
std::get_if<Fortran::parser::OmpClause::Hint>(&clause.u)) {
const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
hint = *Fortran::evaluate::ToInt64(*expr);
break;
}
mlir::omp::CriticalOp criticalOp = [&]() {
if (name.empty()) {
return firOpBuilder.create<mlir::omp::CriticalOp>(currentLocation,
FlatSymbolRefAttr());
} else {
mlir::ModuleOp module = firOpBuilder.getModule();
mlir::OpBuilder modBuilder(module.getBodyRegion());
auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
if (!global)
global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
currentLocation, name, hint);
return firOpBuilder.create<mlir::omp::CriticalOp>(
currentLocation, mlir::FlatSymbolRefAttr::get(
firOpBuilder.getContext(), global.sym_name()));
}
}();
createBodyOfOp<omp::CriticalOp>(criticalOp, firOpBuilder, currentLocation);
}
void Fortran::lower::genOpenMPConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
@ -239,7 +282,7 @@ void Fortran::lower::genOpenMPConstruct(
},
[&](const Fortran::parser::OpenMPCriticalConstruct
&criticalConstruct) {
TODO(converter.getCurrentLocation(), "OpenMPCriticalConstruct");
genOMP(converter, eval, criticalConstruct);
},
},
ompConstruct.u);

View File

@ -293,8 +293,6 @@ public:
bool Pre(const parser::OpenMPBlockConstruct &);
void Post(const parser::OpenMPBlockConstruct &);
bool Pre(const parser::OmpCriticalDirective &x);
bool Pre(const parser::OmpEndCriticalDirective &x);
void Post(const parser::OmpBeginBlockDirective &) {
GetContext().withinConstruct = true;
@ -313,7 +311,7 @@ public:
bool Pre(const parser::OpenMPSectionsConstruct &);
void Post(const parser::OpenMPSectionsConstruct &) { PopContext(); }
bool Pre(const parser::OpenMPCriticalConstruct &);
bool Pre(const parser::OpenMPCriticalConstruct &critical);
void Post(const parser::OpenMPCriticalConstruct &) { PopContext(); }
bool Pre(const parser::OpenMPDeclareSimdConstruct &x) {
@ -1376,25 +1374,18 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPSectionsConstruct &x) {
return true;
}
bool OmpAttributeVisitor::Pre(const parser::OmpCriticalDirective &x) {
const auto &name{std::get<std::optional<parser::Name>>(x.t)};
if (name) {
ResolveOmpName(*name, Symbol::Flag::OmpCriticalLock);
}
return true;
}
bool OmpAttributeVisitor::Pre(const parser::OmpEndCriticalDirective &x) {
const auto &name{std::get<std::optional<parser::Name>>(x.t)};
if (name) {
ResolveOmpName(*name, Symbol::Flag::OmpCriticalLock);
}
return true;
}
bool OmpAttributeVisitor::Pre(const parser::OpenMPCriticalConstruct &x) {
const auto &criticalDir{std::get<parser::OmpCriticalDirective>(x.t)};
PushContext(criticalDir.source, llvm::omp::Directive::OMPD_critical);
const auto &beginCriticalDir{std::get<parser::OmpCriticalDirective>(x.t)};
const auto &endCriticalDir{std::get<parser::OmpEndCriticalDirective>(x.t)};
PushContext(beginCriticalDir.source, llvm::omp::Directive::OMPD_critical);
if (const auto &criticalName{
std::get<std::optional<parser::Name>>(beginCriticalDir.t)}) {
ResolveOmpName(*criticalName, Symbol::Flag::OmpCriticalLock);
}
if (const auto &endCriticalName{
std::get<std::optional<parser::Name>>(endCriticalDir.t)}) {
ResolveOmpName(*endCriticalName, Symbol::Flag::OmpCriticalLock);
}
return true;
}
@ -1515,13 +1506,11 @@ void OmpAttributeVisitor::ResolveOmpName(
AddToContextObjectWithDSA(*resolvedSymbol, ompFlag);
}
}
} else if (ompFlagsRequireNewSymbol.test(ompFlag)) {
const auto pair{GetContext().scope.try_emplace(
name.source, Attrs{}, ObjectEntityDetails{})};
} else if (ompFlag == Symbol::Flag::OmpCriticalLock) {
const auto pair{
GetContext().scope.try_emplace(name.source, Attrs{}, UnknownDetails{})};
CHECK(pair.second);
name.symbol = &pair.first->second.get();
} else {
DIE("OpenMP Name resolution failed");
}
}

View File

@ -0,0 +1,41 @@
!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefix="FIRDialect"
!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefix="LLVMDialect"
!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | tco | FileCheck %s --check-prefix="LLVMIR"
subroutine omp_critical()
use omp_lib
integer :: x, y
!FIRDialect: omp.critical.declare @help hint(contended)
!LLVMDialect: omp.critical.declare @help hint(contended)
!FIRDialect: omp.critical(@help)
!LLVMDialect: omp.critical(@help)
!LLVMIR: call void @__kmpc_critical_with_hint({{.*}}, {{.*}}, {{.*}} @{{.*}}help.var, i32 2)
!$OMP CRITICAL(help) HINT(omp_lock_hint_contended)
x = x + y
!FIRDialect: omp.terminator
!LLVMDialect: omp.terminator
!LLVMIR: call void @__kmpc_end_critical({{.*}}, {{.*}}, {{.*}} @{{.*}}help.var)
!$OMP END CRITICAL(help)
! Test that the same name can be used again
! Also test with the zero hint expression
!FIRDialect: omp.critical(@help)
!LLVMDialect: omp.critical(@help)
!LLVMIR: call void @__kmpc_critical_with_hint({{.*}}, {{.*}}, {{.*}} @{{.*}}help.var, i32 2)
!$OMP CRITICAL(help) HINT(omp_lock_hint_none)
x = x - y
!FIRDialect: omp.terminator
!LLVMDialect: omp.terminator
!LLVMIR: call void @__kmpc_end_critical({{.*}}, {{.*}}, {{.*}} @{{.*}}help.var)
!$OMP END CRITICAL(help)
!FIRDialect: omp.critical
!LLVMDialect: omp.critical
!LLVMIR: call void @__kmpc_critical({{.*}}, {{.*}}, {{.*}} @{{.*}}_.var)
!$OMP CRITICAL
y = x + y
!FIRDialect: omp.terminator
!LLVMDialect: omp.terminator
!LLVMIR: call void @__kmpc_end_critical({{.*}}, {{.*}}, {{.*}} @{{.*}}_.var)
!$OMP END CRITICAL
end subroutine omp_critical