Remove unnecessary async group creates and awaits.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D110605
This commit is contained in:
bakhtiyar 2021-09-28 14:35:15 -07:00 committed by Eugene Zhulenev
parent 55dfab39a2
commit bdde959533
2 changed files with 13 additions and 12 deletions

View File

@ -508,12 +508,6 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
Value c0 = b.create<ConstantIndexOp>(0);
Value c1 = b.create<ConstantIndexOp>(1);
// Create an async.group to wait on all async tokens from the concurrent
// execution of multiple parallel compute function. First block will be
// executed synchronously in the caller thread.
Value groupSize = b.create<SubIOp>(blockCount, c1);
Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
// Appends operands shared by async dispatch and parallel compute functions to
// the given operands vector.
auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
@ -543,6 +537,12 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
};
auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
// Create an async.group to wait on all async tokens from the concurrent
// execution of multiple parallel compute function. First block will be
// executed synchronously in the caller thread.
Value groupSize = b.create<SubIOp>(blockCount, c1);
Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
ImplicitLocOpBuilder nb(loc, nestedBuilder);
// Launch async dispatch function for [0, blockCount) range.
@ -551,14 +551,15 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
nb.create<CallOp>(asyncDispatchFunction.sym_name(),
asyncDispatchFunction.getCallableResults(), operands);
// Wait for the completion of all parallel compute operations.
b.create<AwaitAllOp>(group);
nb.create<scf::YieldOp>();
};
// Dispatch either single block compute function, or launch async dispatch.
b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
// Wait for the completion of all parallel compute operations.
b.create<AwaitAllOp>(group);
}
// Dispatch parallel compute functions by submitting all async compute tasks

View File

@ -12,13 +12,13 @@ func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
// CHECK: scf.if %[[IS_NOOP]] {
// CHECK-NEXT: } else {
// CHECK: %[[GROUP:.*]] = async.create_group
// CHECK: scf.if {{.*}} {
// CHECK: scf.if {{.*}} {
// CHECK: call @parallel_compute_fn(%[[C0]]
// CHECK: } else {
// CHECK: %[[GROUP:.*]] = async.create_group
// CHECK: call @async_dispatch_fn
// CHECK: async.await_all %[[GROUP]]
// CHECK: }
// CHECK: async.await_all %[[GROUP]]
// CHECK: }
scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
%one = constant 1.0 : f32