forked from OSchip/llvm-project
[flang[OpenACC] Lower wait directive
This patch adds lowering for the `!$acc wait` directive from the PFT to OpenACC dialect. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: PeteSteinfeld Differential Revision: https://reviews.llvm.org/D122399
This commit is contained in:
parent
67eb2f144e
commit
44b0ea44f2
|
@ -898,16 +898,16 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
|
|||
const auto &accClauseList =
|
||||
std::get<Fortran::parser::AccClauseList>(waitConstruct.t);
|
||||
|
||||
mlir::Value ifCond, waitDevnum, async;
|
||||
SmallVector<mlir::Value, 2> waitOperands;
|
||||
mlir::Value ifCond, asyncOperand, waitDevnum, async;
|
||||
SmallVector<mlir::Value> waitOperands;
|
||||
|
||||
// Async clause have optional values but can be present with
|
||||
// no value as well. When there is no value, the op has an attribute to
|
||||
// represent the clause.
|
||||
bool addAsyncAttr = false;
|
||||
|
||||
auto &firOpBuilder = converter.getFirOpBuilder();
|
||||
auto currentLocation = converter.getCurrentLocation();
|
||||
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
||||
mlir::Location currentLocation = converter.getCurrentLocation();
|
||||
Fortran::lower::StatementContext stmtCtx;
|
||||
|
||||
if (waitArgument) { // wait has a value.
|
||||
|
@ -930,35 +930,26 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
|
|||
// Lower clauses values mapped to operands.
|
||||
// Keep track of each group of operands separatly as clauses can appear
|
||||
// more than once.
|
||||
for (const auto &clause : accClauseList.v) {
|
||||
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
|
||||
if (const auto *ifClause =
|
||||
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
|
||||
mlir::Value cond = fir::getBase(converter.genExprValue(
|
||||
*Fortran::semantics::GetExpr(ifClause->v), stmtCtx));
|
||||
ifCond = firOpBuilder.createConvert(currentLocation,
|
||||
firOpBuilder.getI1Type(), cond);
|
||||
genIfClause(converter, ifClause, ifCond, stmtCtx);
|
||||
} else if (const auto *asyncClause =
|
||||
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
|
||||
const auto &asyncClauseValue = asyncClause->v;
|
||||
if (asyncClauseValue) { // async has a value.
|
||||
async = fir::getBase(converter.genExprValue(
|
||||
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
|
||||
} else {
|
||||
addAsyncAttr = true;
|
||||
}
|
||||
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare the operand segement size attribute and the operands value range.
|
||||
SmallVector<mlir::Value, 8> operands;
|
||||
SmallVector<int32_t, 4> operandSegments;
|
||||
SmallVector<mlir::Value> operands;
|
||||
SmallVector<int32_t> operandSegments;
|
||||
addOperands(operands, operandSegments, waitOperands);
|
||||
addOperand(operands, operandSegments, async);
|
||||
addOperand(operands, operandSegments, waitDevnum);
|
||||
addOperand(operands, operandSegments, ifCond);
|
||||
|
||||
auto waitOp = createSimpleOp<mlir::acc::WaitOp>(firOpBuilder, currentLocation,
|
||||
operands, operandSegments);
|
||||
mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>(
|
||||
firOpBuilder, currentLocation, operands, operandSegments);
|
||||
|
||||
if (addAsyncAttr)
|
||||
waitOp.asyncAttr(firOpBuilder.getUnitAttr());
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
! This test checks lowering of OpenACC wait directive.
|
||||
|
||||
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
|
||||
|
||||
subroutine acc_update
|
||||
integer :: async = 1
|
||||
logical :: ifCondition = .TRUE.
|
||||
|
||||
!$acc wait
|
||||
!CHECK: acc.wait{{$}}
|
||||
|
||||
!$acc wait if(.true.)
|
||||
!CHECK: [[IF1:%.*]] = arith.constant true
|
||||
!CHECK: acc.wait if([[IF1]]){{$}}
|
||||
|
||||
!$acc wait if(ifCondition)
|
||||
!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
|
||||
!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
|
||||
!CHECK: acc.wait if([[IF2]]){{$}}
|
||||
|
||||
!$acc wait(1, 2)
|
||||
!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
|
||||
!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32
|
||||
!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32){{$}}
|
||||
|
||||
!$acc wait(1) async
|
||||
!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32
|
||||
!CHECK: acc.wait([[WAIT3]] : i32) attributes {async}
|
||||
|
||||
!$acc wait(1) async(async)
|
||||
!CHECK: [[WAIT3:%.*]] = arith.constant 1 : i32
|
||||
!CHECK: [[ASYNC1:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
|
||||
!CHECK: acc.wait([[WAIT3]] : i32) async([[ASYNC1]] : i32){{$}}
|
||||
|
||||
!$acc wait(devnum: 3: queues: 1, 2)
|
||||
!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
|
||||
!CHECK: [[WAIT2:%.*]] = arith.constant 2 : i32
|
||||
!CHECK: [[DEVNUM:%.*]] = arith.constant 3 : i32
|
||||
!CHECK: acc.wait([[WAIT1]], [[WAIT2]] : i32, i32) wait_devnum([[DEVNUM]] : i32){{$}}
|
||||
|
||||
end subroutine acc_update
|
Loading…
Reference in New Issue