[mlir] Async: update condition for dispatching block-aligned compute function
+ compare block size with the unrollable inner dimension
+ reduce nesting in the code and simplify the IR building a bit

Reviewed By: cota

Differential Revision: https://reviews.llvm.org/D120075
commit beff16f7bd
parent fc0aa8424c
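In rough C++ terms, the dispatch policy after this change looks like the sketch below. This is a minimal illustration with plain integers standing in for the Values that the pass actually builds as arith/scf ops; the standalone helpers and their names are assumptions made for this sketch, not code from the patch.

#include <cstdint>

// Whether to dispatch the block-aligned (unrollable) compute function:
// only when the per-block work covers at least one full run of the
// statically known inner iterations. Before this change the code compared
// blockSize against the arbitrary constant 512 instead. (Hypothetical
// helper, for illustration only.)
static bool useBlockAlignedComputeFn(int64_t numUnrollableLoops,
                                     int64_t blockSize, int64_t numIters) {
  return numUnrollableLoops > 0 && blockSize >= numIters;
}

// When the aligned function is chosen, the block size is rounded up to a
// multiple of the inner iteration count: ceilDiv(blockSize, n) * n.
static int64_t alignBlockSize(int64_t blockSize, int64_t numIters) {
  return (blockSize + numIters - 1) / numIters * numIters;
}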
@@ -779,10 +779,10 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   // and we can elide dynamic loop boundaries, and give LLVM an opportunity to
   // unroll the loops. The constant `512` is arbitrary, it should depend on
   // how many iterations LLVM will typically decide to unroll.
-  static constexpr int64_t maxIterations = 512;
+  static constexpr int64_t maxUnrollableIterations = 512;
 
   // The number of inner loops with statically known number of iterations less
-  // than the `maxIterations` value.
+  // than the `maxUnrollableIterations` value.
   int numUnrollableLoops = 0;
 
   auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };
@@ -796,7 +796,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     numIterations[i] = tripCount * innerIterations;
 
     // Update the number of inner loops that we can potentially unroll.
-    if (innerIterations > 0 && innerIterations <= maxIterations)
+    if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
       numUnrollableLoops++;
   }
 
@@ -856,9 +856,6 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
   Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
 
-  ParallelComputeFunction notUnrollableParallelComputeFunction =
-      createParallelComputeFunction(op, staticBounds, 0, rewriter);
-
   // Dispatch parallel compute function using async recursive work splitting,
   // or by submitting compute task sequentially from a caller thread.
   auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
@@ -869,42 +866,47 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     // Compute the number of parallel compute blocks.
     Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
 
-    // Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
-    bool staticShouldUnroll = numUnrollableLoops > 0;
-    auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
+    // Dispatch parallel compute function without hints to unroll inner loops.
+    auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
+      ParallelComputeFunction compute =
+          createParallelComputeFunction(op, staticBounds, 0, rewriter);
+
       ImplicitLocOpBuilder b(loc, nestedBuilder);
-      doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
-                 blockSize, blockCount, tripCounts);
+      doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
       b.create<scf::YieldOp>();
     };
 
-    if (staticShouldUnroll) {
-      Value dynamicShouldUnroll = b.create<arith::CmpIOp>(
-          arith::CmpIPredicate::sge, blockSize,
-          b.create<arith::ConstantIndexOp>(maxIterations));
+    // Dispatch parallel compute function with hints for unrolling inner loops.
+    auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
+      ParallelComputeFunction compute = createParallelComputeFunction(
+          op, staticBounds, numUnrollableLoops, rewriter);
 
-      ParallelComputeFunction unrollableParallelComputeFunction =
-          createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
-                                        rewriter);
+      ImplicitLocOpBuilder b(loc, nestedBuilder);
+      // Align the block size to be a multiple of the statically known
+      // number of iterations in the inner loops.
+      Value numIters = b.create<arith::ConstantIndexOp>(
+          numIterations[op.getNumLoops() - numUnrollableLoops]);
+      Value alignedBlockSize = b.create<arith::MulIOp>(
+          b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+      doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
+                 tripCounts);
+      b.create<scf::YieldOp>();
+    };
 
-      auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
-        ImplicitLocOpBuilder b(loc, nestedBuilder);
-        // Align the block size to be a multiple of the statically known
-        // number of iterations in the inner loops.
-        Value numIters = b.create<arith::ConstantIndexOp>(
-            numIterations[op.getNumLoops() - numUnrollableLoops]);
-        Value alignedBlockSize = b.create<arith::MulIOp>(
-            b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
-        doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
-                   alignedBlockSize, blockCount, tripCounts);
-        b.create<scf::YieldOp>();
-      };
+    // Dispatch to block aligned compute function only if the computed block
+    // size is larger than the number of iterations in the unrollable inner
+    // loops, because otherwise it can reduce the available parallelism.
+    if (numUnrollableLoops > 0) {
+      Value numIters = b.create<arith::ConstantIndexOp>(
+          numIterations[op.getNumLoops() - numUnrollableLoops]);
+      Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
+          arith::CmpIPredicate::sge, blockSize, numIters);
 
-      b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
-                          dispatchNotUnrollable);
+      b.create<scf::IfOp>(TypeRange(), useBlockAlignedComputeFn,
+                          dispatchBlockAligned, dispatchDefault);
       b.create<scf::YieldOp>();
     } else {
-      dispatchNotUnrollable(b, loc);
+      dispatchDefault(b, loc);
    }
  };
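To see why the new `blockSize >= numIters` guard matters, consider a hypothetical example (numbers chosen for illustration, not taken from the patch): with `tripCount = 100` and a statically known inner dimension of `numIters = 10`, a block size of 25 aligns up to 30 and still yields several parallel blocks, while a block size of 4 would align up to 10 and collapse 25 blocks down to 10, losing parallelism.

#include <cassert>
#include <cstdint>

// Mirrors arith::CeilDivSIOp semantics for positive operands.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t tripCount = 100; // total iterations of the parallel op
  const int64_t numIters = 10;   // statically known inner iterations

  // blockSize >= numIters: align 25 up to 30; block count stays healthy.
  assert(ceilDiv(25, numIters) * numIters == 30);
  assert(ceilDiv(tripCount, 30) == 4);

  // blockSize < numIters: aligning 4 up to 10 would shrink the number of
  // parallel blocks from 25 to 10, so the default (non-aligned) compute
  // function is dispatched instead.
  assert(ceilDiv(tripCount, 4) == 25);
  assert(ceilDiv(tripCount, numIters) == 10);
  return 0;
}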
@@ -87,24 +87,6 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
   return
 }
 
-// CHECK-LABEL: func private @parallel_compute_fn(
-// CHECK-SAME:    %[[BLOCK_INDEX:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[BLOCK_SIZE:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[TRIP_COUNT0:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[TRIP_COUNT1:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[LB0:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[LB1:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[UB0:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[UB1:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[STEP0:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[STEP1:arg[0-9]+]]: index,
-// CHECK-SAME:    %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
-// CHECK-SAME:  ) {
-// CHECK:         scf.for %[[I:arg[0-9]+]]
-// CHECK:           arith.select
-// CHECK:           scf.for %[[J:arg[0-9]+]]
-// CHECK:             memref.store
-
 // CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
 // CHECK-SAME:    %[[BLOCK_INDEX:arg[0-9]+]]: index,
 // CHECK-SAME:    %[[BLOCK_SIZE:arg[0-9]+]]: index,
@@ -124,3 +106,21 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
 // CHECK:         scf.for %[[I:arg[0-9]+]]
 // CHECK-NOT:       arith.select
 // CHECK:         scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
+
+// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-SAME:    %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[TRIP_COUNT0:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[TRIP_COUNT1:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[LB0:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[LB1:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[UB0:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[UB1:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[STEP0:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[STEP1:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
+// CHECK-SAME:  ) {
+// CHECK:         scf.for %[[I:arg[0-9]+]]
+// CHECK:           arith.select
+// CHECK:           scf.for %[[J:arg[0-9]+]]
+// CHECK:             memref.store