[mlir][nfc] Move `getInnermostParallelLoops` to SCF/Transforms/Utils.h.

This commit is contained in:
Alexander Belyaev 2021-01-26 16:59:19 +01:00
parent feb0b4ec0a
commit 80966447a2
3 changed files with 31 additions and 24 deletions

View File

@ -17,6 +17,7 @@
namespace mlir {
class FuncOp;
class Operation;
class OpBuilder;
class ValueRange;
@ -57,5 +58,11 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
/// region is inlined into a new FuncOp that is captured by the pointer.
void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
/// Get a list of innermost parallel loops contained in `rootOp`. Innermost parallel
/// loops are those that do not contain further parallel loops themselves.
bool getInnermostParallelLoops(Operation *rootOp,
SmallVectorImpl<scf::ParallelOp> &result);
} // end namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_H_

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
using namespace mlir;
@ -126,29 +127,6 @@ void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
op.erase();
}
/// Get a list of most nested parallel loops.
static bool getInnermostPloops(Operation *rootOp,
SmallVectorImpl<ParallelOp> &result) {
assert(rootOp != nullptr && "Root operation must not be a nullptr.");
bool rootEnclosesPloops = false;
for (Region &region : rootOp->getRegions()) {
for (Block &block : region.getBlocks()) {
for (Operation &op : block) {
bool enclosesPloops = getInnermostPloops(&op, result);
rootEnclosesPloops |= enclosesPloops;
if (auto ploop = dyn_cast<ParallelOp>(op)) {
rootEnclosesPloops = true;
// Collect ploop if it is an innermost one.
if (!enclosesPloops)
result.push_back(ploop);
}
}
}
}
return rootEnclosesPloops;
}
namespace {
struct ParallelLoopTiling
: public SCFParallelLoopTilingBase<ParallelLoopTiling> {
@ -159,7 +137,7 @@ struct ParallelLoopTiling
void runOnFunction() override {
SmallVector<ParallelOp, 2> innermostPloops;
getInnermostPloops(getFunction().getOperation(), innermostPloops);
getInnermostParallelLoops(getFunction().getOperation(), innermostPloops);
for (ParallelOp ploop : innermostPloops) {
// FIXME: Add reduction support.
if (ploop.getNumReductions() == 0)

View File

@ -123,3 +123,25 @@ void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
if (elseFn && !ifOp.elseRegion().empty())
*elseFn = outline(ifOp.elseRegion(), elseFnName);
}
bool mlir::getInnermostParallelLoops(Operation *rootOp,
SmallVectorImpl<scf::ParallelOp> &result) {
assert(rootOp != nullptr && "Root operation must not be a nullptr.");
bool rootEnclosesPloops = false;
for (Region &region : rootOp->getRegions()) {
for (Block &block : region.getBlocks()) {
for (Operation &op : block) {
bool enclosesPloops = getInnermostParallelLoops(&op, result);
rootEnclosesPloops |= enclosesPloops;
if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
rootEnclosesPloops = true;
// Collect parallel loop if it is an innermost one.
if (!enclosesPloops)
result.push_back(ploop);
}
}
}
}
return rootEnclosesPloops;
}