[mlir] Async: update condition for dispatching block-aligned compute function

+ compare block size with the unrollable inner dimension
+ reduce nesting in the code and simplify a bit IR building

Reviewed By: cota

Differential Revision: https://reviews.llvm.org/D120075
This commit is contained in:
Eugene Zhulenev 2022-02-17 10:22:18 -08:00
parent fc0aa8424c
commit beff16f7bd
2 changed files with 53 additions and 51 deletions

View File

@ -779,10 +779,10 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
// and we can elide dynamic loop boundaries, and give LLVM an opportunity to
// unroll the loops. The constant `512` is arbitrary, it should depend on
// how many iterations LLVM will typically decide to unroll.
static constexpr int64_t maxIterations = 512;
static constexpr int64_t maxUnrollableIterations = 512;
// The number of inner loops with statically known number of iterations less
// than the `maxIterations` value.
// than the `maxUnrollableIterations` value.
int numUnrollableLoops = 0;
auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };
@ -796,7 +796,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
numIterations[i] = tripCount * innerIterations;
// Update the number of inner loops that we can potentially unroll.
if (innerIterations > 0 && innerIterations <= maxIterations)
if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
numUnrollableLoops++;
}
@ -856,9 +856,6 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
ParallelComputeFunction notUnrollableParallelComputeFunction =
createParallelComputeFunction(op, staticBounds, 0, rewriter);
// Dispatch parallel compute function using async recursive work splitting,
// or by submitting compute task sequentially from a caller thread.
auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
@ -869,42 +866,47 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
// Compute the number of parallel compute blocks.
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
// Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
bool staticShouldUnroll = numUnrollableLoops > 0;
auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
// Dispatch parallel compute function without hints to unroll inner loops.
auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
ParallelComputeFunction compute =
createParallelComputeFunction(op, staticBounds, 0, rewriter);
ImplicitLocOpBuilder b(loc, nestedBuilder);
doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
blockSize, blockCount, tripCounts);
doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
b.create<scf::YieldOp>();
};
if (staticShouldUnroll) {
Value dynamicShouldUnroll = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sge, blockSize,
b.create<arith::ConstantIndexOp>(maxIterations));
// Dispatch parallel compute function with hints for unrolling inner loops.
auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
ParallelComputeFunction compute = createParallelComputeFunction(
op, staticBounds, numUnrollableLoops, rewriter);
ParallelComputeFunction unrollableParallelComputeFunction =
createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
rewriter);
ImplicitLocOpBuilder b(loc, nestedBuilder);
// Align the block size to be a multiple of the statically known
// number of iterations in the inner loops.
Value numIters = b.create<arith::ConstantIndexOp>(
numIterations[op.getNumLoops() - numUnrollableLoops]);
Value alignedBlockSize = b.create<arith::MulIOp>(
b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
tripCounts);
b.create<scf::YieldOp>();
};
auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
ImplicitLocOpBuilder b(loc, nestedBuilder);
// Align the block size to be a multiple of the statically known
// number of iterations in the inner loops.
Value numIters = b.create<arith::ConstantIndexOp>(
numIterations[op.getNumLoops() - numUnrollableLoops]);
Value alignedBlockSize = b.create<arith::MulIOp>(
b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
alignedBlockSize, blockCount, tripCounts);
b.create<scf::YieldOp>();
};
// Dispatch to block aligned compute function only if the computed block
// size is larger than the number of iterations in the unrollable inner
// loops, because otherwise it can reduce the available parallelism.
if (numUnrollableLoops > 0) {
Value numIters = b.create<arith::ConstantIndexOp>(
numIterations[op.getNumLoops() - numUnrollableLoops]);
Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sge, blockSize, numIters);
b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
dispatchNotUnrollable);
b.create<scf::IfOp>(TypeRange(), useBlockAlignedComputeFn,
dispatchBlockAligned, dispatchDefault);
b.create<scf::YieldOp>();
} else {
dispatchNotUnrollable(b, loc);
dispatchDefault(b, loc);
}
};

View File

@ -87,24 +87,6 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
return
}
// CHECK-LABEL: func private @parallel_compute_fn(
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME: %[[TRIP_COUNT0:arg[0-9]+]]: index,
// CHECK-SAME: %[[TRIP_COUNT1:arg[0-9]+]]: index,
// CHECK-SAME: %[[LB0:arg[0-9]+]]: index,
// CHECK-SAME: %[[LB1:arg[0-9]+]]: index,
// CHECK-SAME: %[[UB0:arg[0-9]+]]: index,
// CHECK-SAME: %[[UB1:arg[0-9]+]]: index,
// CHECK-SAME: %[[STEP0:arg[0-9]+]]: index,
// CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
// CHECK-SAME: ) {
// CHECK: scf.for %[[I:arg[0-9]+]]
// CHECK: arith.select
// CHECK: scf.for %[[J:arg[0-9]+]]
// CHECK: memref.store
// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
@ -124,3 +106,21 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
// CHECK: scf.for %[[I:arg[0-9]+]]
// CHECK-NOT: arith.select
// CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
// CHECK-LABEL: func private @parallel_compute_fn(
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME: %[[TRIP_COUNT0:arg[0-9]+]]: index,
// CHECK-SAME: %[[TRIP_COUNT1:arg[0-9]+]]: index,
// CHECK-SAME: %[[LB0:arg[0-9]+]]: index,
// CHECK-SAME: %[[LB1:arg[0-9]+]]: index,
// CHECK-SAME: %[[UB0:arg[0-9]+]]: index,
// CHECK-SAME: %[[UB1:arg[0-9]+]]: index,
// CHECK-SAME: %[[STEP0:arg[0-9]+]]: index,
// CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
// CHECK-SAME: ) {
// CHECK: scf.for %[[I:arg[0-9]+]]
// CHECK: arith.select
// CHECK: scf.for %[[J:arg[0-9]+]]
// CHECK: memref.store