[mlir] Async: lower SCF operations into CFG inside coroutines

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D106747
This commit is contained in:
Eugene Zhulenev 2021-07-24 06:51:15 -07:00
parent c63dbd8501
commit de7a4e53a2
4 changed files with 59 additions and 1 deletions

View File

@ -12,8 +12,10 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
@ -571,10 +573,22 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
<< " functions built from async.execute operations\n";
});
// Returns true if operation is inside the coroutine.
auto isInCoroutine = [&](Operation *op) -> bool {
auto parentFunc = op->getParentOfType<FuncOp>();
return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
};
// Lower async operations to async.runtime operations.
MLIRContext *ctx = module->getContext();
RewritePatternSet asyncPatterns(ctx);
// Conversion to async runtime augments original CFG with the coroutine CFG,
// and we have to make sure that structured control flow operations with async
// operations in nested regions will be converted to branch-based control flow
// before we add the coroutine basic blocks.
populateLoopToStdConversionPatterns(asyncPatterns);
// Async lowering does not use type converter because it must preserve all
// types for async.runtime operations.
asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
@ -591,12 +605,22 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
// Decide if structured control flow has to be lowered to branch-based CFG.
runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
auto walkResult = op->walk([&](Operation *nested) {
bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
: WalkResult::advance();
});
return !walkResult.wasInterrupted();
});
runtimeTarget.addLegalOp<BranchOp, CondBranchOp>();
// Assertions must be converted to runtime errors inside async functions.
runtimeTarget.addDynamicallyLegalOp<AssertOp>([&](AssertOp op) -> bool {
auto func = op->getParentOfType<FuncOp>();
return outlinedFunctions.find(func) == outlinedFunctions.end();
});
runtimeTarget.addLegalOp<CondBranchOp>();
if (failed(applyPartialConversion(module, runtimeTarget,
std::move(asyncPatterns)))) {

View File

@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
MLIRAsync
MLIRPass
MLIRSCF
MLIRSCFToStandard
MLIRStandard
MLIRTransforms
MLIRTransformUtils

View File

@ -374,3 +374,35 @@ func @execute_asserttion(%arg0: i1) {
// CHECK: ^[[SUSPEND]]:
// CHECK: async.coro.end %[[HDL]]
// CHECK: return %[[TOKEN]]
// -----
// Structured control flow operations with async operations in the body must be
// lowered to branch-based control flow to enable coroutine CFG rewrite.
// CHECK-LABEL: @lower_scf_to_cfg
func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
%token0 = async.execute { async.yield }
%token1 = async.execute {
scf.if %arg2 {
async.await %token0 : !async.token
} else {
async.await %token0 : !async.token
}
async.yield
}
return
}
// Function outlined from the first async.execute operation.
// CHECK-LABEL: func private @async_execute_fn(
// CHECK-SAME: -> !async.token
// Function outlined from the second async.execute operation.
// CHECK-LABEL: func private @async_execute_fn_0(
// CHECK: %[[TOKEN:.*]]: !async.token
// CHECK: %[[FLAG:.*]]: i1
// CHECK-SAME: -> !async.token
// Check that structured control flow lowered to CFG.
// CHECK-NOT: scf.if
// CHECK: cond_br %[[FLAG]]

View File

@ -1791,6 +1791,7 @@ cc_library(
":IR",
":Pass",
":SCFDialect",
":SCFToStandard",
":StandardOps",
":Support",
":TransformUtils",