forked from OSchip/llvm-project
[mlir] Async: clone constants into async.execute functions and parallel compute functions
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D107007
This commit is contained in:
parent
84602f98c6
commit
b537c5b414
|
@ -190,6 +190,10 @@ createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
|
|||
|
||||
ModuleOp module = op->getParentOfType<ModuleOp>();
|
||||
|
||||
// Make sure that all constants will be inside the parallel operation body to
|
||||
// reduce the number of parallel compute function arguments.
|
||||
cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
|
||||
|
||||
ParallelComputeFunctionType computeFuncType =
|
||||
getParallelComputeFunctionType(op, rewriter);
|
||||
|
||||
|
|
|
@ -235,6 +235,10 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
|||
MLIRContext *ctx = module.getContext();
|
||||
Location loc = execute.getLoc();
|
||||
|
||||
// Make sure that all constants will be inside the outlined async function to
|
||||
// reduce the number of function arguments.
|
||||
cloneConstantsIntoTheRegion(execute.body());
|
||||
|
||||
// Collect all outlined function inputs.
|
||||
SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
|
||||
execute.dependencies().end());
|
||||
|
|
|
@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRAsyncTransforms
|
|||
AsyncRuntimeRefCounting.cpp
|
||||
AsyncRuntimeRefCountingOpt.cpp
|
||||
AsyncToAsyncRuntime.cpp
|
||||
PassDetail.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
//===- PassDetail.cpp - Async Pass class details ----------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
void mlir::async::cloneConstantsIntoTheRegion(Region ®ion) {
|
||||
OpBuilder builder(®ion);
|
||||
cloneConstantsIntoTheRegion(region, builder);
|
||||
}
|
||||
|
||||
void mlir::async::cloneConstantsIntoTheRegion(Region ®ion,
|
||||
OpBuilder &builder) {
|
||||
// Values implicitly captured by the region.
|
||||
llvm::SetVector<Value> captures;
|
||||
getUsedValuesDefinedAbove(region, region, captures);
|
||||
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.setInsertionPointToStart(®ion.front());
|
||||
|
||||
// Clone ConstantLike operations into the region.
|
||||
for (Value capture : captures) {
|
||||
Operation *op = capture.getDefiningOp();
|
||||
if (!op || !op->hasTrait<OpTrait::ConstantLike>())
|
||||
continue;
|
||||
|
||||
Operation *cloned = builder.clone(*op);
|
||||
|
||||
for (auto tuple : llvm::zip(op->getResults(), cloned->getResults())) {
|
||||
Value orig = std::get<0>(tuple);
|
||||
Value replacement = std::get<1>(tuple);
|
||||
replaceAllUsesInRegionWith(orig, replacement, region);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -25,6 +25,24 @@ class SCFDialect;
|
|||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Dialect/Async/Passes.h.inc"
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
// Utility functions shared by Async Transformations.
|
||||
// -------------------------------------------------------------------------- //
|
||||
|
||||
// Forward declarations.
|
||||
class OpBuilder;
|
||||
|
||||
namespace async {
|
||||
|
||||
/// Clone ConstantLike operations that are defined above the given region and
|
||||
/// have users in the region into the region entry block. We do that to reduce
|
||||
/// the number of function arguments when we outline `async.execute` and
|
||||
/// `scf.parallel` operations body into functions.
|
||||
void cloneConstantsIntoTheRegion(Region ®ion);
|
||||
void cloneConstantsIntoTheRegion(Region ®ion, OpBuilder &builder);
|
||||
|
||||
} // namespace async
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // DIALECT_ASYNC_TRANSFORMS_PASSDETAIL_H_
|
||||
|
|
|
@ -89,13 +89,14 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
|
|||
}
|
||||
|
||||
// Function outlined from the inner async.execute operation.
|
||||
// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
|
||||
// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
|
||||
// CHECK-SAME: -> !llvm.ptr<i8>
|
||||
// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
|
||||
// CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
|
||||
// CHECK: call @mlirAsyncRuntimeExecute
|
||||
// CHECK: llvm.intr.coro.suspend
|
||||
// CHECK: memref.store %arg0, %arg1[%arg2] : memref<1xf32>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: memref.store %arg0, %arg1[%[[C0]]] : memref<1xf32>
|
||||
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
|
||||
|
||||
// Function outlined from the outer async.execute operation.
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
// RUN: mlir-opt %s \
|
||||
// RUN: -async-parallel-for=async-dispatch=true \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// RUN: mlir-opt %s \
|
||||
// RUN: -async-parallel-for=async-dispatch=false \
|
||||
// RUN: -canonicalize -inline -symbol-dce \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// Check that constants defined outside of the `scf.parallel` body will be
|
||||
// sunk into the parallel compute function to avoid blowing up the number
|
||||
// of parallel compute function arguments.
|
||||
|
||||
// CHECK-LABEL: func @clone_constant(
|
||||
func @clone_constant(%arg0: memref<?xf32>, %lb: index, %ub: index, %st: index) {
|
||||
%one = constant 1.0 : f32
|
||||
|
||||
scf.parallel (%i) = (%lb) to (%ub) step (%st) {
|
||||
memref.store %one, %arg0[%i] : memref<?xf32>
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func private @parallel_compute_fn(
|
||||
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[TRIP_COUNT:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[LB:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[UB:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[STEP:arg[0-9]+]]: index,
|
||||
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?xf32>
|
||||
// CHECK-SAME: ) {
|
||||
// CHECK: %[[CST:.*]] = constant 1.0{{.*}} : f32
|
||||
// CHECK: scf.for
|
||||
// CHECK: memref.store %[[CST]], %[[MEMREF]]
|
|
@ -406,3 +406,26 @@ func @lower_scf_to_cfg(%arg0: f32, %arg1: memref<1xf32>, %arg2: i1) {
|
|||
// Check that structured control flow lowered to CFG.
|
||||
// CHECK-NOT: scf.if
|
||||
// CHECK: cond_br %[[FLAG]]
|
||||
|
||||
// -----
|
||||
// Constants captured by the async.execute region should be cloned into the
|
||||
// outline async execute function.
|
||||
|
||||
// CHECK-LABEL: @clone_constants
|
||||
func @clone_constants(%arg0: f32, %arg1: memref<1xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%token = async.execute {
|
||||
memref.store %arg0, %arg1[%c0] : memref<1xf32>
|
||||
async.yield
|
||||
}
|
||||
async.await %token : !async.token
|
||||
return
|
||||
}
|
||||
|
||||
// Function outlined from the async.execute operation.
|
||||
// CHECK-LABEL: func private @async_execute_fn(
|
||||
// CHECK-SAME: %[[VALUE:arg[0-9]+]]: f32,
|
||||
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<1xf32>
|
||||
// CHECK-SAME: ) -> !async.token
|
||||
// CHECK: %[[CST:.*]] = constant 0 : index
|
||||
// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]
|
||||
|
|
Loading…
Reference in New Issue