[MLIR][SCF] Find all innermost loops for parallel loop tiling

Overcome the assumption that parallel loops are only nested in other parallel
loops.

Differential Revision: https://reviews.llvm.org/D92188
This commit is contained in:
Frederik Gossen 2020-11-27 10:06:29 +01:00
parent 5535696c38
commit 6484567f14
2 changed files with 56 additions and 22 deletions

View File

@ -22,15 +22,15 @@ using namespace mlir::scf;
/// Tile a parallel loop of the form
/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
/// step (%arg4, %arg5)
/// step (%arg4, %arg5)
///
/// into
/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
/// step (%arg4*tileSize[0],
/// %arg5*tileSize[1])
/// step (%arg4*tileSize[0],
/// %arg5*tileSize[1])
/// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0),
/// min(%arg5*tileSize[1], %arg3-%i1))
/// step (%arg4, %arg5)
/// min(%arg5*tileSize[1], %arg3-%i1))
/// step (%arg4, %arg5)
///
/// where the uses of %i0 and %i1 in the loop body are replaced by
/// %i0 + %j0 and %i1 + %j1.
@ -126,17 +126,27 @@ void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
op.erase();
}
/// Get a list of most nested parallel loops. Assumes that ParallelOps are
/// only directly nested.
static bool getInnermostNestedLoops(Block *block,
SmallVectorImpl<ParallelOp> &loops) {
bool hasInnerLoop = false;
for (auto parallelOp : block->getOps<ParallelOp>()) {
hasInnerLoop = true;
if (!getInnermostNestedLoops(parallelOp.getBody(), loops))
loops.push_back(parallelOp);
/// Get a list of most nested parallel loops.
///
/// Recursively visits every operation in every block of every region of
/// `rootOp` and appends to `result` each ParallelOp that does not enclose any
/// other ParallelOp, at any nesting depth and through any kind of
/// intermediate operation.
///
/// `rootOp` itself is never tested for being a ParallelOp; only the
/// operations nested inside it are. `rootOp` must not be null.
///
/// Returns true if at least one ParallelOp is nested anywhere inside
/// `rootOp`, false otherwise.
static bool getInnermostPloops(Operation *rootOp,
SmallVectorImpl<ParallelOp> &result) {
assert(rootOp != nullptr && "Root operation must not be a nullptr.");
bool rootEnclosesPloops = false;
for (Region &region : rootOp->getRegions()) {
for (Block &block : region.getBlocks()) {
for (Operation &op : block) {
// Recurse first: the answer tells us whether `op` contains further ploops,
// which decides below whether `op` itself can be an innermost one.
bool enclosesPloops = getInnermostPloops(&op, result);
rootEnclosesPloops |= enclosesPloops;
if (auto ploop = dyn_cast<ParallelOp>(op)) {
// A ploop directly inside the root counts as enclosed by the root, even
// when it has no ploops nested inside itself.
rootEnclosesPloops = true;
// Collect ploop if it is an innermost one.
if (!enclosesPloops)
result.push_back(ploop);
}
}
}
}
return hasInnerLoop;
return rootEnclosesPloops;
}
namespace {
@ -148,14 +158,12 @@ struct ParallelLoopTiling
}
/// Tiles the innermost parallel loops of the current function.
///
/// Collects every innermost ParallelOp nested under the function operation
/// and calls tileParallelLoop with `tileSizes` on each one that has no
/// reductions; loops with reductions are left untouched.
void runOnFunction() override {
SmallVector<ParallelOp, 2> mostNestedParallelOps;
for (Block &block : getFunction()) {
getInnermostNestedLoops(&block, mostNestedParallelOps);
}
for (ParallelOp pLoop : mostNestedParallelOps) {
// Start the search at the function operation itself so that ploops nested
// inside any other kind of operation are found as well.
SmallVector<ParallelOp, 2> innermostPloops;
getInnermostPloops(getFunction().getOperation(), innermostPloops);
for (ParallelOp ploop : innermostPloops) {
// FIXME: Add reduction support.
if (pLoop.getNumReductions() == 0)
tileParallelLoop(pLoop, tileSizes);
// Only loops without reductions are tiled (see the FIXME above).
if (ploop.getNumReductions() == 0)
tileParallelLoop(ploop, tileSizes);
}
}
};

View File

@ -112,3 +112,29 @@ func @tile_nested_innermost() {
// CHECK: }
// CHECK: return
// CHECK: }
// -----
// Test input: a single scf.parallel nested inside two scf.for loops, i.e. a
// parallel loop whose ancestors are not parallel loops. The expected output
// that follows this function contains two scf.parallel ops inside the two
// scf.for ops.
func @tile_nested_in_non_ploop() {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
scf.for %i = %c0 to %c2 step %c1 {
scf.for %j = %c0 to %c2 step %c1 {
scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
}
}
}
return
}
// CHECK-LABEL: func @tile_nested_in_non_ploop
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.parallel
// CHECK: scf.parallel
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }