forked from OSchip/llvm-project
[async] Get the number of worker threads from the runtime.
Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D117751
This commit is contained in:
parent
2afc8be2fa
commit
149311b405
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
include "mlir/Dialect/Async/IR/AsyncDialect.td"
|
||||
include "mlir/Dialect/Async/IR/AsyncTypes.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -529,4 +530,17 @@ def Async_RuntimeDropRefOp : Async_Op<"runtime.drop_ref"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Async_RuntimeNumWorkerThreadsOp :
|
||||
Async_Op<"runtime.num_worker_threads",
|
||||
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "gets the number of threads in the threadpool from the runtime";
|
||||
let description = [{
|
||||
The `async.runtime.num_worker_threads` operation gets the number of threads
|
||||
in the threadpool from the runtime.
|
||||
}];
|
||||
|
||||
let results = (outs Index:$result);
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
#endif // ASYNC_OPS
|
||||
|
|
|
@ -25,7 +25,8 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
|
|||
|
||||
Option<"numWorkerThreads", "num-workers",
|
||||
"int32_t", /*default=*/"8",
|
||||
"The number of available workers to execute async operations.">,
|
||||
"The number of available workers to execute async operations. If `-1` "
|
||||
"the value will be retrieved from the runtime.">,
|
||||
|
||||
Option<"minTaskSize", "min-task-size",
|
||||
"int32_t", /*default=*/"1000",
|
||||
|
|
|
@ -123,6 +123,9 @@ extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *, CoroHandle,
|
|||
extern "C" void
|
||||
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume);
|
||||
|
||||
// Returns the current number of available worker threads in the threadpool.
|
||||
extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Small async runtime support library for testing.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -59,6 +59,8 @@ static constexpr const char *kAwaitValueAndExecute =
|
|||
"mlirAsyncRuntimeAwaitValueAndExecute";
|
||||
static constexpr const char *kAwaitAllAndExecute =
|
||||
"mlirAsyncRuntimeAwaitAllInGroupAndExecute";
|
||||
static constexpr const char *kGetNumWorkerThreads =
|
||||
"mlirAsyncRuntimGetNumWorkerThreads";
|
||||
|
||||
namespace {
|
||||
/// Async Runtime API function types.
|
||||
|
@ -181,6 +183,10 @@ struct AsyncAPI {
|
|||
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
|
||||
}
|
||||
|
||||
static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
|
||||
return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
|
||||
}
|
||||
|
||||
// Auxiliary coroutine resume intrinsic wrapper.
|
||||
static Type resumeFunctionType(MLIRContext *ctx) {
|
||||
auto voidTy = LLVM::LLVMVoidType::get(ctx);
|
||||
|
@ -226,6 +232,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
|
|||
AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
|
||||
addFuncDecl(kAwaitAllAndExecute,
|
||||
AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
|
||||
addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -879,6 +886,30 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Convert async.runtime.num_worker_threads to the corresponding runtime API
|
||||
// call.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class RuntimeNumWorkerThreadsOpLowering
|
||||
: public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
// Replace with a runtime API function call.
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, kGetNumWorkerThreads,
|
||||
rewriter.getIndexType());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Async reference counting ops lowering (`async.runtime.add_ref` and
|
||||
// `async.runtime.drop_ref` to the corresponding API calls).
|
||||
|
@ -984,8 +1015,9 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
|
|||
patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
|
||||
RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
|
||||
RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
|
||||
RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
|
||||
RuntimeDropRefOpLowering>(converter, ctx);
|
||||
RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
|
||||
RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
|
||||
ctx);
|
||||
|
||||
// Lower async.runtime operations that rely on LLVM type converter to convert
|
||||
// from async value payload type to the LLVM type.
|
||||
|
|
|
@ -9,5 +9,6 @@ add_mlir_dialect_library(MLIRAsync
|
|||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRDialect
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
)
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
|
||||
|
@ -799,19 +800,53 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
|
|||
numUnrollableLoops++;
|
||||
}
|
||||
|
||||
// With large number of threads the value of creating many compute blocks
|
||||
// is reduced because the problem typically becomes memory bound. For small
|
||||
// number of threads it helps with stragglers.
|
||||
float overshardingFactor = numWorkerThreads <= 4 ? 8.0
|
||||
: numWorkerThreads <= 8 ? 4.0
|
||||
: numWorkerThreads <= 16 ? 2.0
|
||||
: numWorkerThreads <= 32 ? 1.0
|
||||
: numWorkerThreads <= 64 ? 0.8
|
||||
: 0.6;
|
||||
Value numWorkerThreadsVal;
|
||||
if (numWorkerThreads >= 0)
|
||||
numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
|
||||
else
|
||||
numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();
|
||||
|
||||
// Do not overload worker threads with too many compute blocks.
|
||||
Value maxComputeBlocks = b.create<arith::ConstantIndexOp>(
|
||||
std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
|
||||
// With large number of threads the value of creating many compute blocks
|
||||
// is reduced because the problem typically becomes memory bound. For this
|
||||
// reason we scale the number of workers using an equivalent to the
|
||||
// following logic:
|
||||
// float overshardingFactor = numWorkerThreads <= 4 ? 8.0
|
||||
// : numWorkerThreads <= 8 ? 4.0
|
||||
// : numWorkerThreads <= 16 ? 2.0
|
||||
// : numWorkerThreads <= 32 ? 1.0
|
||||
// : numWorkerThreads <= 64 ? 0.8
|
||||
// : 0.6;
|
||||
|
||||
// Pairs of non-inclusive lower end of the bracket and factor that the
|
||||
// number of workers needs to be scaled with if it falls in that bucket.
|
||||
const SmallVector<std::pair<int, float>> overshardingBrackets = {
|
||||
{4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
|
||||
const float initialOvershardingFactor = 8.0f;
|
||||
|
||||
Value scalingFactor = b.create<arith::ConstantFloatOp>(
|
||||
llvm::APFloat(initialOvershardingFactor), b.getF32Type());
|
||||
for (const std::pair<int, float> &p : overshardingBrackets) {
|
||||
Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
|
||||
Value inBracket = b.create<arith::CmpIOp>(
|
||||
arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
|
||||
Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
|
||||
llvm::APFloat(p.second), b.getF32Type());
|
||||
scalingFactor =
|
||||
b.create<SelectOp>(inBracket, bracketScalingFactor, scalingFactor);
|
||||
}
|
||||
Value numWorkersIndex =
|
||||
b.create<arith::IndexCastOp>(numWorkerThreadsVal, b.getI32Type());
|
||||
Value numWorkersFloat =
|
||||
b.create<arith::SIToFPOp>(numWorkersIndex, b.getF32Type());
|
||||
Value scaledNumWorkers =
|
||||
b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
|
||||
Value scaledNumInt =
|
||||
b.create<arith::FPToSIOp>(scaledNumWorkers, b.getI32Type());
|
||||
Value scaledWorkers =
|
||||
b.create<arith::IndexCastOp>(scaledNumInt, b.getIndexType());
|
||||
|
||||
Value maxComputeBlocks = b.create<arith::MaxSIOp>(
|
||||
b.create<arith::ConstantIndexOp>(1), scaledWorkers);
|
||||
|
||||
// Compute parallel block size from the parallel problem size:
|
||||
// blockSize = min(tripCount,
|
||||
|
|
|
@ -438,6 +438,10 @@ extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
|
|||
}
|
||||
}
|
||||
|
||||
extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
|
||||
return getDefaultAsyncRuntime()->getThreadPool().getThreadCount();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Small async runtime support library for testing.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -515,6 +519,8 @@ void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
|
|||
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
|
||||
exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
|
||||
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
|
||||
exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
|
||||
&mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
|
||||
exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
|
||||
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
// RUN: mlir-opt %s -split-input-file -async-parallel-for=num-workers=-1 \
|
||||
// RUN: | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: @num_worker_threads(
|
||||
// CHECK: %[[MEMREF:.*]]: memref<?xf32>
|
||||
func @num_worker_threads(%arg0: memref<?xf32>) {
|
||||
|
||||
// CHECK: %[[scalingCstInit:.*]] = arith.constant 8.000000e+00 : f32
|
||||
// CHECK: %[[bracketLowerBound4:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[scalingCst4:.*]] = arith.constant 4.000000e+00 : f32
|
||||
// CHECK: %[[bracketLowerBound8:.*]] = arith.constant 8 : index
|
||||
// CHECK: %[[scalingCst8:.*]] = arith.constant 2.000000e+00 : f32
|
||||
// CHECK: %[[bracketLowerBound16:.*]] = arith.constant 16 : index
|
||||
// CHECK: %[[scalingCst16:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK: %[[bracketLowerBound32:.*]] = arith.constant 32 : index
|
||||
// CHECK: %[[scalingCst32:.*]] = arith.constant 8.000000e-01 : f32
|
||||
// CHECK: %[[bracketLowerBound64:.*]] = arith.constant 64 : index
|
||||
// CHECK: %[[scalingCst64:.*]] = arith.constant 6.000000e-01 : f32
|
||||
// CHECK: %[[workersIndex:.*]] = async.runtime.num_worker_threads : index
|
||||
// CHECK: %[[inBracket4:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound4]] : index
|
||||
// CHECK: %[[scalingFactor4:.*]] = select %[[inBracket4]], %[[scalingCst4]], %[[scalingCstInit]] : f32
|
||||
// CHECK: %[[inBracket8:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound8]] : index
|
||||
// CHECK: %[[scalingFactor8:.*]] = select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32
|
||||
// CHECK: %[[inBracket16:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound16]] : index
|
||||
// CHECK: %[[scalingFactor16:.*]] = select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32
|
||||
// CHECK: %[[inBracket32:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound32]] : index
|
||||
// CHECK: %[[scalingFactor32:.*]] = select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32
|
||||
// CHECK: %[[inBracket64:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound64]] : index
|
||||
// CHECK: %[[scalingFactor64:.*]] = select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32
|
||||
// CHECK: %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32
|
||||
// CHECK: %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32
|
||||
// CHECK: %[[scaledFloat:.*]] = arith.mulf %[[scalingFactor64]], %[[workersFloat]] : f32
|
||||
// CHECK: %[[scaledInt:.*]] = arith.fptosi %[[scaledFloat]] : f32 to i32
|
||||
// CHECK: %[[scaledIndex:.*]] = arith.index_cast %[[scaledInt]] : i32 to index
|
||||
|
||||
%lb = arith.constant 0 : index
|
||||
%ub = arith.constant 100 : index
|
||||
%st = arith.constant 1 : index
|
||||
scf.parallel (%i) = (%lb) to (%ub) step (%st) {
|
||||
%one = arith.constant 1.0 : f32
|
||||
memref.store %one, %arg0[%i] : memref<?xf32>
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -1240,6 +1240,7 @@ td_library(
|
|||
includes = ["include"],
|
||||
deps = [
|
||||
":ControlFlowInterfacesTdFiles",
|
||||
":InferTypeOpInterfaceTdFiles",
|
||||
":OpBaseTdFiles",
|
||||
":SideEffectInterfacesTdFiles",
|
||||
],
|
||||
|
@ -2140,6 +2141,7 @@ cc_library(
|
|||
":ControlFlowInterfaces",
|
||||
":Dialect",
|
||||
":IR",
|
||||
":InferTypeOpInterface",
|
||||
":SideEffectInterfaces",
|
||||
":StandardOps",
|
||||
":Support",
|
||||
|
|
Loading…
Reference in New Issue