forked from OSchip/llvm-project
[NFC][Flang][OpenMP] Refactor OpenMP.cpp::genOpenMPReduction
This patch serves two main purposes: Firstly, to split some of the logic into a seperate method to try and improve readability On top of this, it aims to make creating the reductions more generic. That way, subsequent patches adding reductions shouldn't need to add a significant amount of extra logic checks, such as checking for specific operators. Reviewed By: awarzynski Differential Revision: https://reviews.llvm.org/D131161
This commit is contained in:
parent
ba9dc5f577
commit
5784199dd1
|
@ -15,6 +15,15 @@
|
|||
|
||||
#include <cinttypes>
|
||||
|
||||
namespace mlir {
|
||||
class Value;
|
||||
class Operation;
|
||||
} // namespace mlir
|
||||
|
||||
namespace fir {
|
||||
class FirOpBuilder;
|
||||
} // namespace fir
|
||||
|
||||
namespace Fortran {
|
||||
namespace parser {
|
||||
struct OpenMPConstruct;
|
||||
|
@ -41,6 +50,11 @@ void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
|
|||
void genOpenMPReduction(AbstractConverter &,
|
||||
const Fortran::parser::OmpClauseList &clauseList);
|
||||
|
||||
void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
|
||||
mlir::Value);
|
||||
|
||||
mlir::Operation *getReductionInChain(mlir::Value, mlir::Value);
|
||||
|
||||
} // namespace lower
|
||||
} // namespace Fortran
|
||||
|
||||
|
|
|
@ -1633,42 +1633,19 @@ void Fortran::lower::genOpenMPReduction(
|
|||
if (const auto *name{
|
||||
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
|
||||
if (const auto *symbol{name->symbol}) {
|
||||
mlir::Value symVal = converter.getSymbolAddress(*symbol);
|
||||
mlir::Type redType =
|
||||
symVal.getType().cast<fir::ReferenceType>().getEleTy();
|
||||
if (!redType.isIntOrIndex())
|
||||
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
|
||||
mlir::Type reductionType =
|
||||
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
|
||||
if (!reductionType.isIntOrIndex())
|
||||
continue;
|
||||
for (mlir::OpOperand &use1 : symVal.getUses()) {
|
||||
if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) {
|
||||
mlir::Value loadVal = load.getRes();
|
||||
for (mlir::OpOperand &use2 : loadVal.getUses()) {
|
||||
if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>(
|
||||
use2.getOwner())) {
|
||||
mlir::Value addRes = add.getResult();
|
||||
for (mlir::OpOperand &use3 : addRes.getUses()) {
|
||||
if (auto store =
|
||||
mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) {
|
||||
if (store.getMemref() == symVal) {
|
||||
// Chain found! Now replace load->reduction->store
|
||||
// with the OpenMP reduction operation.
|
||||
mlir::OpBuilder::InsertPoint insertPtDel =
|
||||
firOpBuilder.saveInsertionPoint();
|
||||
firOpBuilder.setInsertionPoint(add);
|
||||
if (add.getLhs() == loadVal) {
|
||||
firOpBuilder.create<mlir::omp::ReductionOp>(
|
||||
add.getLoc(), add.getRhs(), symVal);
|
||||
} else {
|
||||
firOpBuilder.create<mlir::omp::ReductionOp>(
|
||||
add.getLoc(), add.getLhs(), symVal);
|
||||
}
|
||||
store.erase();
|
||||
add.erase();
|
||||
load.erase();
|
||||
firOpBuilder.restoreInsertionPoint(insertPtDel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
|
||||
|
||||
if (auto loadOp =
|
||||
mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
|
||||
mlir::Value loadVal = loadOp.getRes();
|
||||
if (auto reductionOp = getReductionInChain(reductionVal, loadVal)) {
|
||||
updateReduction(reductionOp, firOpBuilder, loadVal, reductionVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1679,3 +1656,42 @@ void Fortran::lower::genOpenMPReduction(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Checks whether loadVal is used in an operation,
|
||||
// the result of which is then stored into reductionVal.
|
||||
// If yes, then the operation corresponding to the reduction is returned.
|
||||
// loadVal is assumed to be the value of a load operation
|
||||
// reductionVal is the results of an OpenMP reduction operation.
|
||||
mlir::Operation *Fortran::lower::getReductionInChain(mlir::Value reductionVal,
|
||||
mlir::Value loadVal) {
|
||||
for (mlir::OpOperand &loadUse : loadVal.getUses()) {
|
||||
if (auto reductionOp = loadUse.getOwner()) {
|
||||
for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) {
|
||||
if (auto store =
|
||||
mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) {
|
||||
if (store.getMemref() == reductionVal) {
|
||||
store.erase();
|
||||
return reductionOp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Fortran::lower::updateReduction(mlir::Operation *op,
|
||||
fir::FirOpBuilder &firOpBuilder,
|
||||
mlir::Value loadVal, mlir::Value reductionVal) {
|
||||
mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint();
|
||||
firOpBuilder.setInsertionPoint(op);
|
||||
|
||||
if (op->getOperand(0) == loadVal)
|
||||
firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), op->getOperand(1),
|
||||
reductionVal);
|
||||
else
|
||||
firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), op->getOperand(0),
|
||||
reductionVal);
|
||||
|
||||
firOpBuilder.restoreInsertionPoint(insertPtDel);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue