[mlir:Async] Remove async operations if it is statically known that the parallel operation has a single compute block
Depends On D104850

Add a test that verifies that canonicalization removes all async overhead if it is statically known that the scf.parallel operation will be computed using a single block.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D104891
commit a8f819c6d8
parent a37f558682
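To make the intent of this change concrete, the following is a condensed sketch of what the new test at the bottom of this diff verifies (value names are illustrative, not taken verbatim from the pass output):

  // Input: the trip count is statically known, and the whole loop fits in a
  // single compute block.
  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
    memref.store %one, %memref[%i] : memref<?xf32>
  }

  // After -async-parallel-for, -canonicalize, -inline and -symbol-dce no
  // async operations remain, only the equivalent sequential loop.
  scf.for %i = %lb to %ub step %st {
    memref.store %one, %memref[%i] : memref<?xf32>
  }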
@@ -20,6 +20,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
@@ -177,6 +177,8 @@ def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> {
   let arguments = (ins Index:$size);
   let results = (outs Async_GroupType:$result);
 
+  let hasCanonicalizeMethod = 1;
+
   let assemblyFormat = "$size `:` type($result) attr-dict";
 }
 
@@ -245,6 +245,36 @@ static LogicalResult verify(ExecuteOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+/// CreateGroupOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
+                                          PatternRewriter &rewriter) {
+  // Find all `await_all` users of the group.
+  llvm::SmallVector<AwaitAllOp> awaitAllUsers;
+
+  auto isAwaitAll = [&](Operation *op) -> bool {
+    if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
+      awaitAllUsers.push_back(awaitAll);
+      return true;
+    }
+    return false;
+  };
+
+  // Check if all users of the group are `await_all` operations.
+  if (!llvm::all_of(op->getUsers(), isAwaitAll))
+    return failure();
+
+  // If the group is only awaited without anything being added to it, we can
+  // safely erase the create operation and all of its users.
+  for (AwaitAllOp awaitAll : awaitAllUsers)
+    rewriter.eraseOp(awaitAll);
+  rewriter.eraseOp(op);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 /// AwaitOp
 //===----------------------------------------------------------------------===//
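For illustration, a minimal sketch of the IR shape this canonicalization matches: the group's only users are `await_all` operations, so the create operation and all of its users can be erased.

  %group = async.create_group %size : !async.group
  async.await_all %group

As an assumed counter-example (not part of this diff): once a token is added to the group, it has a user that is not an `await_all`, so the pattern fails to match and the group is kept.

  %group = async.create_group %size : !async.group
  %rank  = async.add_to_group %token, %group : !async.token
  async.await_all %group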
@@ -513,18 +513,48 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
   Value groupSize = b.create<SubIOp>(blockCount, c1);
   Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
 
-  // Pack the async dispatch function operands to launch the work splitting.
-  SmallVector<Value> asyncDispatchOperands = {group, c0, blockCount, blockSize};
-  asyncDispatchOperands.append(tripCounts);
-  asyncDispatchOperands.append(op.lowerBound().begin(), op.lowerBound().end());
-  asyncDispatchOperands.append(op.upperBound().begin(), op.upperBound().end());
-  asyncDispatchOperands.append(op.step().begin(), op.step().end());
-  asyncDispatchOperands.append(parallelComputeFunction.captures);
-
-  // Launch async dispatch function for [0, blockCount) range.
-  b.create<CallOp>(asyncDispatchFunction.sym_name(),
-                   asyncDispatchFunction.getCallableResults(),
-                   asyncDispatchOperands);
+  // Appends operands shared by the async dispatch and parallel compute
+  // functions to the given operands vector.
+  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
+    operands.append(tripCounts);
+    operands.append(op.lowerBound().begin(), op.lowerBound().end());
+    operands.append(op.upperBound().begin(), op.upperBound().end());
+    operands.append(op.step().begin(), op.step().end());
+    operands.append(parallelComputeFunction.captures);
+  };
+
+  // Check if the block count is one; in this case we can skip the async
+  // dispatch completely. If this is known statically, then canonicalization
+  // will erase the async group operations.
+  Value isSingleBlock = b.create<CmpIOp>(CmpIPredicate::eq, blockCount, c1);
+
+  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+    ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+    // Call the parallel compute function for the single block.
+    SmallVector<Value> operands = {c0, blockSize};
+    appendBlockComputeOperands(operands);
+
+    nb.create<CallOp>(parallelComputeFunction.func.sym_name(),
+                      parallelComputeFunction.func.getCallableResults(),
+                      operands);
+    nb.create<scf::YieldOp>();
+  };
+
+  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+    ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+    // Launch the async dispatch function for the [0, blockCount) range.
+    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
+    appendBlockComputeOperands(operands);
+
+    nb.create<CallOp>(asyncDispatchFunction.sym_name(),
+                      asyncDispatchFunction.getCallableResults(), operands);
+    nb.create<scf::YieldOp>();
+  };
+
+  // Dispatch either the single block compute function or the async dispatch.
+  b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
 
   // Wait for the completion of all parallel compute operations.
   b.create<AwaitAllOp>(group);
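Roughly, a hand-written sketch of the IR that doAsyncDispatch now emits (value names are assumed and the operand lists are elided):

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %group_size = subi %block_count, %c1 : index
  %group = async.create_group %group_size : !async.group
  %is_single_block = cmpi eq, %block_count, %c1 : index
  scf.if %is_single_block {
    // Synchronous dispatch: compute the single block in the caller thread.
    call @parallel_compute_fn(%c0, %block_size, ...) : (...) -> ()
  } else {
    // Asynchronous dispatch: recursively split the [0, blockCount) range.
    call @async_dispatch_fn(%group, %c0, %block_count, %block_size, ...) : (...) -> ()
  }
  async.await_all %group

When %block_count is a compile-time constant 1, the cmpi folds to true, the scf.if reduces to the synchronous branch, and %group is left with async.await_all as its only user, which is exactly the pattern that the new CreateGroupOp::canonicalize erases.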
@@ -3,8 +3,13 @@
 
 // CHECK-LABEL: @loop_1d
 func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
+  // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[GROUP:.*]] = async.create_group
-  // CHECK: call @async_dispatch_fn
+  // CHECK: scf.if {{.*}} {
+  // CHECK:   call @parallel_compute_fn(%[[C0]]
+  // CHECK: } else {
+  // CHECK:   call @async_dispatch_fn
+  // CHECK: }
   // CHECK: async.await_all %[[GROUP]]
   scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
     %one = constant 1.0 : f32
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s \
+// RUN:   -async-parallel-for=async-dispatch=true \
+// RUN:   -canonicalize -inline -symbol-dce \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s \
+// RUN:   -async-parallel-for=async-dispatch=false \
+// RUN:   -canonicalize -inline -symbol-dce \
+// RUN: | FileCheck %s
+
+// Check that if we statically know that the parallel operation has a single
+// block, then all async operations will be canonicalized away and we will
+// end up with a single synchronous compute function call.
+
+// CHECK-LABEL: @loop_1d(
+// CHECK:       %[[MEMREF:.*]]: memref<?xf32>
+func @loop_1d(%arg0: memref<?xf32>) {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C100:.*]] = constant 100 : index
+  // CHECK-DAG: %[[ONE:.*]] = constant 1.000000e+00 : f32
+  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]]
+  // CHECK:       memref.store %[[ONE]], %[[MEMREF]][%[[I]]]
+  %lb = constant 0 : index
+  %ub = constant 100 : index
+  %st = constant 1 : index
+  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+    %one = constant 1.0 : f32
+    memref.store %one, %arg0[%i] : memref<?xf32>
+  }
+
+  return
+}