[mlir] Async: clone constants into async.execute functions and parallel compute functions

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D107007
This commit is contained in:
Eugene Zhulenev 2021-08-02 08:40:17 -07:00
parent 84602f98c6
commit b537c5b414
8 changed files with 132 additions and 2 deletions

View File

@ -190,6 +190,10 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
ModuleOp module = op->getParentOfType<ModuleOp>();
// Make sure that all constants will be inside the parallel operation body to
// reduce the number of parallel compute function arguments.
cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
ParallelComputeFunctionType computeFuncType =
getParallelComputeFunctionType(op, rewriter);

View File

@ -235,6 +235,10 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
MLIRContext *ctx = module.getContext();
Location loc = execute.getLoc();
// Make sure that all constants will be inside the outlined async function to
// reduce the number of function arguments.
cloneConstantsIntoTheRegion(execute.body());
// Collect all outlined function inputs.
SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
execute.dependencies().end());

View File

@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
AsyncRuntimeRefCounting.cpp
AsyncRuntimeRefCountingOpt.cpp
AsyncToAsyncRuntime.cpp
PassDetail.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async

View File

@ -0,0 +1,43 @@
//===- PassDetail.cpp - Async Pass class details ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/RegionUtils.h"
using namespace mlir;
/// Convenience overload for callers that do not already have a builder: a
/// local OpBuilder is constructed from `region` and the actual work is
/// delegated to the (Region &, OpBuilder &) overload.
void mlir::async::cloneConstantsIntoTheRegion(Region &region) {
  OpBuilder regionBuilder(&region);
  async::cloneConstantsIntoTheRegion(region, regionBuilder);
}
/// Clones ConstantLike operations that are defined above `region` but used
/// inside of it into the entry block of `region`, and redirects the uses that
/// live inside the region to the clones. The original operations are left in
/// place.
///
/// \param region  Region whose implicitly captured constants are cloned.
///                A region without any blocks is left untouched.
/// \param builder Builder used to create the clones. Its insertion point is
///                saved on entry and restored before returning.
void mlir::async::cloneConstantsIntoTheRegion(Region &region,
                                              OpBuilder &builder) {
  // A region with no blocks has no entry block to clone into, and calling
  // `region.front()` on it is invalid, so bail out before touching it.
  if (region.empty())
    return;

  // Values implicitly captured by the region.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, region, captures);

  // Nothing is captured: leave the builder completely untouched.
  if (captures.empty())
    return;

  // The guard restores the caller's insertion point when this function exits.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&region.front());

  // Clone ConstantLike operations into the region.
  for (Value capture : captures) {
    // Captures without a defining operation (e.g. block arguments) and
    // captures produced by non-constant operations stay implicit captures.
    Operation *op = capture.getDefiningOp();
    if (!op || !op->hasTrait<OpTrait::ConstantLike>())
      continue;

    Operation *cloned = builder.clone(*op);

    // Only uses nested inside `region` are redirected to the clone; users
    // outside of the region keep referring to the original operation.
    for (auto tuple : llvm::zip(op->getResults(), cloned->getResults())) {
      Value orig = std::get<0>(tuple);
      Value replacement = std::get<1>(tuple);
      replaceAllUsesInRegionWith(orig, replacement, region);
    }
  }
}

View File

@ -25,6 +25,24 @@ class SCFDialect;
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Async/Passes.h.inc"
// -------------------------------------------------------------------------- //
// Utility functions shared by Async Transformations.
// -------------------------------------------------------------------------- //
// Forward declarations.
class OpBuilder;
namespace async {
/// Clone ConstantLike operations that are defined above the given region and
/// have users in the region into the region entry block. We do that to reduce
/// the number of function arguments when we outline the bodies of
/// `async.execute` and `scf.parallel` operations into functions.
///
/// This overload creates its own OpBuilder for `region` and forwards to the
/// overload below.
void cloneConstantsIntoTheRegion(Region &region);
/// Same as above, but the clones are created with the caller-provided
/// `builder` (for example a pattern rewriter). The insertion point of
/// `builder` is restored before the function returns.
void cloneConstantsIntoTheRegion(Region &region, OpBuilder &builder);
} // namespace async
} // namespace mlir
#endif // DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_

View File

@ -89,13 +89,14 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
}
// Function outlined from the inner async.execute operation.
// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
// CHECK-SAME: -> !llvm.ptr<i8>
// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
// CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
// CHECK: call @mlirAsyncRuntimeExecute
// CHECK: llvm.intr.coro.suspend
// CHECK: memref.store %arg0, %arg1[%arg2] : memref<1xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: memref.store %arg0, %arg1[%[[C0]]] : memref<1xf32>
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
// Function outlined from the outer async.execute operation.

View File

@ -0,0 +1,36 @@
// RUN: mlir-opt %s \
// RUN: -async-parallel-for=async-dispatch=true \
// RUN: | FileCheck %s
// RUN: mlir-opt %s \
// RUN: -async-parallel-for=async-dispatch=false \
// RUN: -canonicalize -inline -symbol-dce \
// RUN: | FileCheck %s
// Check that constants defined outside of the `scf.parallel` body will be
// sunk into the parallel compute function to avoid blowing up the number
// of parallel compute function arguments.
// CHECK-LABEL: func @clone_constant(
func @clone_constant(%arg0: memref<?xf32>, %lb: index, %ub: index, %st: index) {
// The constant is defined in the function body, outside of the scf.parallel
// region that uses it.
%one = constant 1.0 : f32
scf.parallel (%i) = (%lb) to (%ub) step (%st) {
// %one is an implicit capture of the loop body; the test expects it to be
// re-materialized inside the parallel compute function instead of being
// passed as an extra function argument.
memref.store %one, %arg0[%i] : memref<?xf32>
}
return
}
// CHECK-LABEL: func private @parallel_compute_fn(
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME: %[[TRIP_COUNT:arg[0-9]+]]: index,
// CHECK-SAME: %[[LB:arg[0-9]+]]: index,
// CHECK-SAME: %[[UB:arg[0-9]+]]: index,
// CHECK-SAME: %[[STEP:arg[0-9]+]]: index,
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?xf32>
// CHECK-SAME: ) {
// CHECK: %[[CST:.*]] = constant 1.0{{.*}} : f32
// CHECK: scf.for
// CHECK: memref.store %[[CST]], %[[MEMREF]]

View File

@ -406,3 +406,26 @@ func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
// Check that structured control flow lowered to CFG.
// CHECK-NOT: scf.if
// CHECK: cond_br %[[FLAG]]
// -----
// Constants captured by the async.execute region should be cloned into the
// outlined async execute function.
// CHECK-LABEL: @clone_constants
func @clone_constants(%arg0: f32, %arg1: memref<1xf32>) {
// The index constant is defined outside of the async.execute region.
%c0 = constant 0 : index
%token = async.execute {
// %c0 is captured from the enclosing function; the test expects the outlined
// function to take only the f32 value and the memref as arguments and to
// define the constant locally.
memref.store %arg0, %arg1[%c0] : memref<1xf32>
async.yield
}
async.await %token : !async.token
return
}
// Function outlined from the async.execute operation.
// CHECK-LABEL: func private @async_execute_fn(
// CHECK-SAME: %[[VALUE:arg[0-9]+]]: f32,
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<1xf32>
// CHECK-SAME: ) -> !async.token
// CHECK: %[[CST:.*]] = constant 0 : index
// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]