forked from OSchip/llvm-project
[mlir] Async: lower SCF operations into CFG inside coroutines
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D106747
This commit is contained in:
parent
c63dbd8501
commit
de7a4e53a2
|
@ -12,8 +12,10 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/Async/Passes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
|
@ -571,10 +573,22 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
|||
<< " functions built from async.execute operations\n";
|
||||
});
|
||||
|
||||
// Returns true if operation is inside the coroutine.
|
||||
auto isInCoroutine = [&](Operation *op) -> bool {
|
||||
auto parentFunc = op->getParentOfType<FuncOp>();
|
||||
return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
|
||||
};
|
||||
|
||||
// Lower async operations to async.runtime operations.
|
||||
MLIRContext *ctx = module->getContext();
|
||||
RewritePatternSet asyncPatterns(ctx);
|
||||
|
||||
// Conversion to async runtime augments original CFG with the coroutine CFG,
|
||||
// and we have to make sure that structured control flow operations with async
|
||||
// operations in nested regions will be converted to branch-based control flow
|
||||
// before we add the coroutine basic blocks.
|
||||
populateLoopToStdConversionPatterns(asyncPatterns);
|
||||
|
||||
// Async lowering does not use type converter because it must preserve all
|
||||
// types for async.runtime operations.
|
||||
asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
|
||||
|
@ -591,12 +605,22 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
|||
runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
|
||||
runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
|
||||
|
||||
// Decide if structured control flow has to be lowered to branch-based CFG.
|
||||
runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
|
||||
auto walkResult = op->walk([&](Operation *nested) {
|
||||
bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
|
||||
return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
|
||||
: WalkResult::advance();
|
||||
});
|
||||
return !walkResult.wasInterrupted();
|
||||
});
|
||||
runtimeTarget.addLegalOp<BranchOp, CondBranchOp>();
|
||||
|
||||
// Assertions must be converted to runtime errors inside async functions.
|
||||
runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
|
||||
auto func = op->getParentOfType<FuncOp>();
|
||||
return outlinedFunctions.find(func) == outlinedFunctions.end();
|
||||
});
|
||||
runtimeTarget.addLegalOp<CondBranchOp>();
|
||||
|
||||
if (failed(applyPartialConversion(module, runtimeTarget,
|
||||
std::move(asyncPatterns)))) {
|
||||
|
|
|
@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
|
|||
MLIRAsync
|
||||
MLIRPass
|
||||
MLIRSCF
|
||||
MLIRSCFToStandard
|
||||
MLIRStandard
|
||||
MLIRTransforms
|
||||
MLIRTransformUtils
|
||||
|
|
|
@ -374,3 +374,35 @@ func @execute_asserttion(%arg0: i1) {
|
|||
// CHECK: ^[[SUSPEND]]:
|
||||
// CHECK: async.coro.end %[[HDL]]
|
||||
// CHECK: return %[[TOKEN]]
|
||||
|
||||
// -----
|
||||
// Structured control flow operations with async operations in the body must be
|
||||
// lowered to branch-based control flow to enable coroutine CFG rewrite.
|
||||
|
||||
// CHECK-LABEL: @lower_scf_to_cfg
|
||||
func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
|
||||
%token0 = async.execute { async.yield }
|
||||
%token1 = async.execute {
|
||||
scf.if %arg2 {
|
||||
async.await %token0 : !async.token
|
||||
} else {
|
||||
async.await %token0 : !async.token
|
||||
}
|
||||
async.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Function outlined from the first async.execute operation.
|
||||
// CHECK-LABEL: func private @async_execute_fn(
|
||||
// CHECK-SAME: -> !async.token
|
||||
|
||||
// Function outlined from the second async.execute operation.
|
||||
// CHECK-LABEL: func private @async_execute_fn_0(
|
||||
// CHECK: %[[TOKEN:.*]]: !async.token
|
||||
// CHECK: %[[FLAG:.*]]: i1
|
||||
// CHECK-SAME: -> !async.token
|
||||
|
||||
// Check that structured control flow lowered to CFG.
|
||||
// CHECK-NOT: scf.if
|
||||
// CHECK: cond_br %[[FLAG]]
|
||||
|
|
|
@ -1791,6 +1791,7 @@ cc_library(
|
|||
":IR",
|
||||
":Pass",
|
||||
":SCFDialect",
|
||||
":SCFToStandard",
|
||||
":StandardOps",
|
||||
":Support",
|
||||
":TransformUtils",
|
||||
|
|
Loading…
Reference in New Issue