[NFC][Flang][OpenMP] Refactor OpenMP.cpp::genOpenMPReduction

This patch serves two main purposes:
Firstly, to split some of the logic into a seperate method
to try and improve readability

On top of this, it aims to make creating the reductions more generic.
That way, subsequent patches adding reductions shouldn't need
to add a significant amount of extra logic checks, such as checking
for specific operators.

Reviewed By: awarzynski

Differential Revision: https://reviews.llvm.org/D131161
This commit is contained in:
Dylan Fleming 2022-08-08 14:04:45 +00:00
parent ba9dc5f577
commit 5784199dd1
2 changed files with 65 additions and 35 deletions

View File

@ -15,6 +15,15 @@
#include <cinttypes>
namespace mlir {
class Value;
class Operation;
} // namespace mlir
namespace fir {
class FirOpBuilder;
} // namespace fir
namespace Fortran {
namespace parser {
struct OpenMPConstruct;
@ -41,6 +50,11 @@ void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
void genOpenMPReduction(AbstractConverter &,
const Fortran::parser::OmpClauseList &clauseList);
void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
mlir::Value);
mlir::Operation *getReductionInChain(mlir::Value, mlir::Value);
} // namespace lower
} // namespace Fortran

View File

@ -1633,42 +1633,19 @@ void Fortran::lower::genOpenMPReduction(
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const auto *symbol{name->symbol}) {
mlir::Value symVal = converter.getSymbolAddress(*symbol);
mlir::Type redType =
symVal.getType().cast<fir::ReferenceType>().getEleTy();
if (!redType.isIntOrIndex())
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
mlir::Type reductionType =
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
if (!reductionType.isIntOrIndex())
continue;
for (mlir::OpOperand &use1 : symVal.getUses()) {
if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) {
mlir::Value loadVal = load.getRes();
for (mlir::OpOperand &use2 : loadVal.getUses()) {
if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>(
use2.getOwner())) {
mlir::Value addRes = add.getResult();
for (mlir::OpOperand &use3 : addRes.getUses()) {
if (auto store =
mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) {
if (store.getMemref() == symVal) {
// Chain found! Now replace load->reduction->store
// with the OpenMP reduction operation.
mlir::OpBuilder::InsertPoint insertPtDel =
firOpBuilder.saveInsertionPoint();
firOpBuilder.setInsertionPoint(add);
if (add.getLhs() == loadVal) {
firOpBuilder.create<mlir::omp::ReductionOp>(
add.getLoc(), add.getRhs(), symVal);
} else {
firOpBuilder.create<mlir::omp::ReductionOp>(
add.getLoc(), add.getLhs(), symVal);
}
store.erase();
add.erase();
load.erase();
firOpBuilder.restoreInsertionPoint(insertPtDel);
}
}
}
}
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
if (auto loadOp =
mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
mlir::Value loadVal = loadOp.getRes();
if (auto reductionOp = getReductionInChain(reductionVal, loadVal)) {
updateReduction(reductionOp, firOpBuilder, loadVal, reductionVal);
}
}
}
@ -1679,3 +1656,42 @@ void Fortran::lower::genOpenMPReduction(
}
}
}
// Checks whether loadVal is used in an operation,
// the result of which is then stored into reductionVal.
// If yes, then the operation corresponding to the reduction is returned.
// loadVal is assumed to be the value of a load operation
// reductionVal is the results of an OpenMP reduction operation.
mlir::Operation *Fortran::lower::getReductionInChain(mlir::Value reductionVal,
mlir::Value loadVal) {
for (mlir::OpOperand &loadUse : loadVal.getUses()) {
if (auto reductionOp = loadUse.getOwner()) {
for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
if (auto store =
mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
if (store.getMemref() == reductionVal) {
store.erase();
return reductionOp;
}
}
}
}
}
return nullptr;
}
void Fortran::lower::updateReduction(mlir::Operation *op,
fir::FirOpBuilder &firOpBuilder,
mlir::Value loadVal, mlir::Value reductionVal) {
mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
firOpBuilder.setInsertionPoint(op);
if (op->getOperand(0) == loadVal)
firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), op->getOperand(1),
reductionVal);
else
firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), op->getOperand(0),
reductionVal);
firOpBuilder.restoreInsertionPoint(insertPtDel);
}