forked from OSchip/llvm-project
[mlir][nfc] Move `getInnermostParallelLoops` to SCF/Transforms/Utils.h.
This commit is contained in:
parent
feb0b4ec0a
commit
80966447a2
|
@ -17,6 +17,7 @@
|
|||
|
||||
namespace mlir {
|
||||
class FuncOp;
|
||||
class Operation;
|
||||
class OpBuilder;
|
||||
class ValueRange;
|
||||
|
||||
|
@ -57,5 +58,11 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
|
|||
/// region is inlined into a new FuncOp that is captured by the pointer.
|
||||
void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
|
||||
StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
|
||||
|
||||
/// Get a list of innermost parallel loops contained in `rootOp`. Innermost parallel
|
||||
/// loops are those that do not contain further parallel loops themselves.
|
||||
bool getInnermostParallelLoops(Operation *rootOp,
|
||||
SmallVectorImpl<scf::ParallelOp> &result);
|
||||
|
||||
} // end namespace mlir
|
||||
#endif // MLIR_DIALECT_SCF_UTILS_H_
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/SCF/Passes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -126,29 +127,6 @@ void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
|
|||
op.erase();
|
||||
}
|
||||
|
||||
/// Get a list of most nested parallel loops.
|
||||
static bool getInnermostPloops(Operation *rootOp,
|
||||
SmallVectorImpl<ParallelOp> &result) {
|
||||
assert(rootOp != nullptr && "Root operation must not be a nullptr.");
|
||||
bool rootEnclosesPloops = false;
|
||||
for (Region ®ion : rootOp->getRegions()) {
|
||||
for (Block &block : region.getBlocks()) {
|
||||
for (Operation &op : block) {
|
||||
bool enclosesPloops = getInnermostPloops(&op, result);
|
||||
rootEnclosesPloops |= enclosesPloops;
|
||||
if (auto ploop = dyn_cast<ParallelOp>(op)) {
|
||||
rootEnclosesPloops = true;
|
||||
|
||||
// Collect ploop if it is an innermost one.
|
||||
if (!enclosesPloops)
|
||||
result.push_back(ploop);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return rootEnclosesPloops;
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ParallelLoopTiling
|
||||
: public SCFParallelLoopTilingBase<ParallelLoopTiling> {
|
||||
|
@ -159,7 +137,7 @@ struct ParallelLoopTiling
|
|||
|
||||
void runOnFunction() override {
|
||||
SmallVector<ParallelOp, 2> innermostPloops;
|
||||
getInnermostPloops(getFunction().getOperation(), innermostPloops);
|
||||
getInnermostParallelLoops(getFunction().getOperation(), innermostPloops);
|
||||
for (ParallelOp ploop : innermostPloops) {
|
||||
// FIXME: Add reduction support.
|
||||
if (ploop.getNumReductions() == 0)
|
||||
|
|
|
@ -123,3 +123,25 @@ void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
|
|||
if (elseFn && !ifOp.elseRegion().empty())
|
||||
*elseFn = outline(ifOp.elseRegion(), elseFnName);
|
||||
}
|
||||
|
||||
bool mlir::getInnermostParallelLoops(Operation *rootOp,
|
||||
SmallVectorImpl<scf::ParallelOp> &result) {
|
||||
assert(rootOp != nullptr && "Root operation must not be a nullptr.");
|
||||
bool rootEnclosesPloops = false;
|
||||
for (Region ®ion : rootOp->getRegions()) {
|
||||
for (Block &block : region.getBlocks()) {
|
||||
for (Operation &op : block) {
|
||||
bool enclosesPloops = getInnermostParallelLoops(&op, result);
|
||||
rootEnclosesPloops |= enclosesPloops;
|
||||
if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
|
||||
rootEnclosesPloops = true;
|
||||
|
||||
// Collect parallel loop if it is an innermost one.
|
||||
if (!enclosesPloops)
|
||||
result.push_back(ploop);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return rootEnclosesPloops;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue