[flang][openacc] Lower update directive

This patch upstream the lowering of Update directive that was initially done in
https://github.com/flang-compiler/f18-llvm-project/pull/528

Reviewed By: schweitz

Differential Revision: https://reviews.llvm.org/D90472
This commit is contained in:
Valentin Clement 2020-11-04 15:47:40 -05:00 committed by clementval
parent e6cd3eff17
commit b45ea4451a
1 changed files with 107 additions and 1 deletions

View File

@ -725,6 +725,112 @@ genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
exitDataOp.finalizeAttr(firOpBuilder.getUnitAttr());
}
static void
genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClauseList &accClauseList) {
mlir::Value ifCond, async, waitDevnum;
SmallVector<Value, 2> hostOperands, deviceOperands, waitOperands,
deviceTypeOperands;
// Async and wait clause have optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
// represent the clause.
bool addAsyncAttr = false;
bool addWaitAttr = false;
bool addIfPresentAttr = false;
auto &firOpBuilder = converter.getFirOpBuilder();
auto currentLocation = converter.getCurrentLocation();
// Lower clauses values mapped to operands.
// Keep track of each group of operands separatly as clauses can appear
// more than once.
for (const auto &clause : accClauseList.v) {
if (const auto *ifClause =
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
mlir::Value cond = fir::getBase(
converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v)));
ifCond = firOpBuilder.createConvert(currentLocation,
firOpBuilder.getI1Type(), cond);
} else if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
const auto &asyncClauseValue = asyncClause->v;
if (asyncClauseValue) { // async has a value.
async = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*asyncClauseValue)));
} else {
addAsyncAttr = true;
}
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const std::list<Fortran::parser::ScalarIntExpr> &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
mlir::Value v = fir::getBase(
converter.genExprValue(*Fortran::semantics::GetExpr(value)));
waitOperands.push_back(v);
}
const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
if (waitDevnumValue)
waitDevnum = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*waitDevnumValue)));
} else {
addWaitAttr = true;
}
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
const auto &deviceTypeValue = deviceTypeClause->v;
if (deviceTypeValue) {
for (const auto &scalarIntExpr : *deviceTypeValue) {
mlir::Value expr = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(scalarIntExpr)));
deviceTypeOperands.push_back(expr);
}
} else {
// * was passed as value and will be represented as a -1 constant
// integer.
mlir::Value star = firOpBuilder.createIntegerConstant(
currentLocation, firOpBuilder.getIntegerType(32), /* STAR */ -1);
deviceTypeOperands.push_back(star);
}
} else if (const auto *hostClause =
std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
genObjectList(hostClause->v, converter, hostOperands);
} else if (const auto *deviceClause =
std::get_if<Fortran::parser::AccClause::Device>(&clause.u)) {
genObjectList(deviceClause->v, converter, deviceOperands);
}
}
// Prepare the operand segement size attribute and the operands value range.
SmallVector<mlir::Value, 14> operands;
SmallVector<int32_t, 7> operandSegments;
addOperand(operands, operandSegments, async);
addOperand(operands, operandSegments, waitDevnum);
addOperands(operands, operandSegments, waitOperands);
addOperands(operands, operandSegments, deviceTypeOperands);
addOperand(operands, operandSegments, ifCond);
addOperands(operands, operandSegments, hostOperands);
addOperands(operands, operandSegments, deviceOperands);
auto updateOp = createSimpleOp<mlir::acc::UpdateOp>(
firOpBuilder, currentLocation, operands, operandSegments);
if (addAsyncAttr)
updateOp.asyncAttr(firOpBuilder.getUnitAttr());
if (addWaitAttr)
updateOp.waitAttr(firOpBuilder.getUnitAttr());
if (addIfPresentAttr)
updateOp.ifPresentAttr(firOpBuilder.getUnitAttr());
}
static void
genACC(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
@ -745,7 +851,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
TODO("OpenACC set directive not lowered yet!");
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
TODO("OpenACC update directive not lowered yet!");
genACCUpdateOp(converter, accClauseList);
}
}