From b6acbd10558d361b47aabb4939a4e6c1353a93e4 Mon Sep 17 00:00:00 2001 From: fanjibin Date: Tue, 2 Nov 2021 00:05:44 +0800 Subject: [PATCH] unified cpu thread pool --- .../kernel_compiler/cpu/adam_cpu_kernel.cc | 2 +- .../cpu/adam_delta_cpu_kernel.cc | 2 +- .../cpu/adam_weight_decay_cpu_kernel.cc | 2 +- .../cpu/arithmetic_cpu_kernel.cc | 42 ++--- .../cpu/arithmetic_cpu_kernel.h | 16 +- .../cpu/arithmetic_logic_cpu_kernel.cc | 32 ++-- .../cpu/arithmetic_logic_cpu_kernel.h | 16 +- .../cpu/arithmetic_self_cpu_kernel.cc | 150 +++++++++--------- .../cpu/arithmetic_self_cpu_kernel.h | 4 +- .../cpu/boundingbox_decode_cpu_kernel.cc | 3 +- .../cpu/boundingbox_encode_cpu_kernel.cc | 2 +- .../kernel_compiler/cpu/cast_cpu_kernel.cc | 6 +- .../cpu/check_valid_cpu_kernel.cc | 3 +- .../backend/kernel_compiler/cpu/cpu_kernel.h | 3 +- .../cpu/crop_and_resize_cpu_kernel.cc | 2 +- .../kernel_compiler/cpu/cumsum_cpu_kernel.cc | 29 +--- .../kernel_compiler/cpu/cumsum_cpu_kernel.h | 2 +- .../cpu/depthtospace_cpu_kernel.cc | 2 +- .../cpu/elu_grad_cpu_kernel.cc | 5 +- .../kernel_compiler/cpu/elu_grad_cpu_kernel.h | 2 +- .../cpu/hsigmoid_cpu_kernel.cc | 2 +- .../cpu/hsigmoid_grad_cpu_kernel.cc | 2 +- .../kernel_compiler/cpu/hswish_cpu_kernel.cc | 2 +- .../cpu/hswish_grad_cpu_kernel.cc | 2 +- .../kernel_compiler/cpu/iou_cpu_kernel.cc | 2 +- .../cpu/l2_normalize_cpu_kernel.cc | 4 +- .../cpu/l2normalize_grad_cpu_kernel.cc | 2 +- .../cpu/mkldnn/addn_cpu_kernel.cc | 4 +- .../cpu/nms_with_mask_cpu_kernel.cc | 14 +- .../kernel_compiler/cpu/pack_cpu_kernel.cc | 24 +-- .../kernel_compiler/cpu/random_cpu_kernel.cc | 33 +--- .../kernel_compiler/cpu/reduce_cpu_kernel.cc | 4 +- .../cpu/rl/buffer_append_cpu_kernel.h | 2 +- .../cpu/rl/buffer_get_cpu_kernel.h | 2 +- .../cpu/rl/buffer_sample_cpu_kernel.h | 2 +- .../kernel_compiler/cpu/rmsprop_cpu_kernel.cc | 4 +- .../cpu/roi_align_cpu_kernel.cc | 3 +- .../cpu/roi_align_grad_cpu_kernel.cc | 4 +- .../cpu/scatter_nd_cpu_kernel.cc | 9 +- 
.../cpu/searchsorted_cpu_kernel.cc | 4 +- .../kernel_compiler/cpu/sgd_cpu_kernel.cc | 2 +- .../cpu/smooth_l1_loss_cpu_kernel.cc | 2 +- .../cpu/spacetodepth_cpu_kernel.cc | 2 +- .../runtime/framework/actor/actor_common.cc | 19 ++- 44 files changed, 207 insertions(+), 268 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc index b9951c49f00..2262c5d1150 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc @@ -60,7 +60,7 @@ void AdamCPUKernel::LaunchAdam(const std::vector &inputs, co } } }; - CPUKernelUtils::ParallelFor(task, lens); + ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_); } void AdamCPUKernel::LaunchAdamNnacl(const std::vector &inputs, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.cc index e9853d3ec84..8b3d722187f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.cc @@ -54,7 +54,7 @@ void AdamDeltaCPUKernel::LaunchAdamDelta(T *delta, T *m, T *v, float lr, float b } }; } - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, this, &parallel_search_info_); } void AdamDeltaCPUKernel::InitKernel(const CNodePtr &kernel_node) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc index e4c12521ab7..e32af36065b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc @@ -61,7 +61,7 @@ void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector &inputs, diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc index 259c393b657..b2cfdd42daf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc @@ -62,18 +62,18 @@ void ElementRealDiv(const T *input1, const T *input2, T *out, size_t size, size_ } // namespace template -void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out) const { +void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out) { auto task = [&input1, &input2, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = input1[i] + input2[i]; input1[i] = out[i]; } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out) const { +void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) { auto iter = base_iter; @@ -83,7 +83,7 @@ void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out) const iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template @@ -118,7 +118,7 @@ void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out) { iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_, kMaxSubSerialSize); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template @@ -152,7 +152,7 @@ void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out) { iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, 
&parallel_search_info_); } template @@ -203,11 +203,11 @@ void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out) { out[i] = dividend / divisor; } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out) const { +void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) { auto iter = base_iter; @@ -232,11 +232,11 @@ void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out) const out[i] = dividend / divisor; } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out) const { +void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) { auto iter = base_iter; @@ -261,11 +261,11 @@ void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out) out[i] = static_cast(floor(static_cast(dividend) / static_cast(divisor))); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out) const { +void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) { auto iter = base_iter; @@ -283,11 +283,11 @@ void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out) const out[i] = 
static_cast(x - data_div_res * y); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticCPUKernel::FloorMod(const T *input1, const T *input2, T *out) const { +void ArithmeticCPUKernel::FloorMod(const T *input1, const T *input2, T *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) { auto iter = base_iter; @@ -300,11 +300,11 @@ void ArithmeticCPUKernel::FloorMod(const T *input1, const T *input2, T *out) out[i] = static_cast((std::abs(res) > 1e-9) && ((res < 0.0) != (y < 0.0)) ? res + y : res); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out) const { +void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out) { if constexpr (std::is_same_v) { auto is_power_single = [this]() { bool is_power_single = false; @@ -324,14 +324,14 @@ void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out) const auto task = [&](size_t start, size_t end) { (void)Power(input1 + start, input2, out + start, end - start, 1, 0, true); }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); return; } if (is_power_single()) { auto task = [&](size_t start, size_t end) { (void)Power(input1 + start, input2 + start, out + start, end - start, 1, 0, false); }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); return; } } @@ -348,7 +348,7 @@ void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out) const iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } 
else { base_iter.SetPos(0); for (size_t i = 0; i < output_size_; i++) { @@ -376,7 +376,7 @@ void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, } template -void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out) const { +void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) { auto iter = base_iter; @@ -387,7 +387,7 @@ void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out) con iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h index 1b74c2e59da..38a47bd0d0d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h @@ -39,16 +39,16 @@ class ArithmeticCPUKernel : public CPUKernel { private: void InitComputeFunc(); void Sub(const T *input1, const T *input2, T *out); - void Add(const T *input1, const T *input2, T *out) const; + void Add(const T *input1, const T *input2, T *out); void Mul(const T *input1, const T *input2, T *out); void RealDiv(const T *input1, const T *input2, T *out); - void Div(const T *input1, const T *input2, T *out) const; - void FloorDiv(const T *input1, const T *input2, T *out) const; - void Mod(const T *input1, const T *input2, T *out) const; - void FloorMod(const T *input1, const T *input2, T *out) const; - void Pow(const T *input1, const T *input2, T *out) const; - void AssignAdd(T *input1, const T *input2, T *out) const; - void Atan2(const T *input1, const T *input2, T *out) const; + void Div(const T *input1, const T *input2, T *out); + void FloorDiv(const T 
*input1, const T *input2, T *out); + void Mod(const T *input1, const T *input2, T *out); + void FloorMod(const T *input1, const T *input2, T *out); + void Pow(const T *input1, const T *input2, T *out); + void AssignAdd(T *input1, const T *input2, T *out); + void Atan2(const T *input1, const T *input2, T *out); void SquaredDifference(const T *input1, const T *input2, T *out); using TypeComputeFunc = std::function; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.cc index 5893da10dad..d09d75dc3f2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.cc @@ -32,7 +32,7 @@ constexpr size_t kOutputsNum = 1; } // namespace template -void ArithmeticLogicCPUKernel::Less(const T *input1, const T *input2, bool *out) const { +void ArithmeticLogicCPUKernel::Less(const T *input1, const T *input2, bool *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); if (output_size_ > kMaxLessSerialSize) { auto task = [&](size_t start, size_t end) { @@ -45,7 +45,7 @@ void ArithmeticLogicCPUKernel::Less(const T *input1, const T *input2, bool *o iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } else { base_iter.SetPos(0); for (size_t i = 0; i < output_size_; i++) { @@ -58,7 +58,7 @@ void ArithmeticLogicCPUKernel::Less(const T *input1, const T *input2, bool *o } template -void ArithmeticLogicCPUKernel::Equal(const T *input1, const T *input2, bool *out) const { +void ArithmeticLogicCPUKernel::Equal(const T *input1, const T *input2, bool *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&](size_t start, size_t end) { auto iter = base_iter; @@ -70,11 +70,11 @@ void ArithmeticLogicCPUKernel::Equal(const T 
*input1, const T *input2, bool * iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticLogicCPUKernel::NotEqual(const T *input1, const T *input2, bool *out) const { +void ArithmeticLogicCPUKernel::NotEqual(const T *input1, const T *input2, bool *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&](size_t start, size_t end) { auto iter = base_iter; @@ -86,11 +86,11 @@ void ArithmeticLogicCPUKernel::NotEqual(const T *input1, const T *input2, boo iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticLogicCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out) const { +void ArithmeticLogicCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&](size_t start, size_t end) { auto iter = base_iter; @@ -100,11 +100,11 @@ void ArithmeticLogicCPUKernel::LogicalAnd(const T *input1, const T *input2, b iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticLogicCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out) const { +void ArithmeticLogicCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&](size_t start, size_t end) { auto iter = base_iter; @@ -114,11 +114,11 @@ void ArithmeticLogicCPUKernel::LogicalOr(const T *input1, const T *input2, bo iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticLogicCPUKernel::Greater(const T 
*input1, const T *input2, bool *out) const { +void ArithmeticLogicCPUKernel::Greater(const T *input1, const T *input2, bool *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&](size_t start, size_t end) { auto iter = base_iter; @@ -130,11 +130,11 @@ void ArithmeticLogicCPUKernel::Greater(const T *input1, const T *input2, bool iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticLogicCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out) const { +void ArithmeticLogicCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&](size_t start, size_t end) { auto iter = base_iter; @@ -146,11 +146,11 @@ void ArithmeticLogicCPUKernel::GreaterEqual(const T *input1, const T *input2, iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template -void ArithmeticLogicCPUKernel::LessEqual(const T *input1, const T *input2, bool *out) const { +void ArithmeticLogicCPUKernel::LessEqual(const T *input1, const T *input2, bool *out) { BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); auto task = [&](size_t start, size_t end) { auto iter = base_iter; @@ -162,7 +162,7 @@ void ArithmeticLogicCPUKernel::LessEqual(const T *input1, const T *input2, bo iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size_); + ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_); } template diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h index 3ecac4214ef..cdebb3492b1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h +++ 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h @@ -39,14 +39,14 @@ class ArithmeticLogicCPUKernel : public CPUKernel { private: void InitComputeFunc(); - void Less(const T *input1, const T *input2, bool *out) const; - void Equal(const T *input1, const T *input2, bool *out) const; - void NotEqual(const T *input1, const T *input2, bool *out) const; - void Greater(const T *input1, const T *input2, bool *out) const; - void GreaterEqual(const T *input1, const T *input2, bool *out) const; - void LessEqual(const T *input1, const T *input2, bool *out) const; - void LogicalAnd(const T *input1, const T *input2, bool *out) const; - void LogicalOr(const T *input1, const T *input2, bool *out) const; + void Less(const T *input1, const T *input2, bool *out); + void Equal(const T *input1, const T *input2, bool *out); + void NotEqual(const T *input1, const T *input2, bool *out); + void Greater(const T *input1, const T *input2, bool *out); + void GreaterEqual(const T *input1, const T *input2, bool *out); + void LessEqual(const T *input1, const T *input2, bool *out); + void LogicalAnd(const T *input1, const T *input2, bool *out); + void LogicalOr(const T *input1, const T *input2, bool *out); using TypeComputeFunc = std::function; TypeComputeFunc compute_func_{nullptr}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc index 60e336d0aef..123a417f689 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc @@ -33,17 +33,17 @@ constexpr size_t kInputsNum = 1; constexpr size_t kOutputsNum = 1; template -void Square(const T *in, T *out, size_t size) { +void Square(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = in[i] * 
in[i]; } }; - ParallelLaunch(task, size, kMaxSquareSerialSize); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Sign(const T *in, T *out, size_t size) { +void Sign(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { if (in[i] < 0) { @@ -55,11 +55,11 @@ void Sign(const T *in, T *out, size_t size) { } } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Neg(const T *in, T *out, size_t size) { +void Neg(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = -in[i]; @@ -68,77 +68,77 @@ void Neg(const T *in, T *out, size_t size) { ParallelLaunch(task, size, kMaxNegSerialSize); } -void LogicalNot(const bool *in, bool *out, size_t size) { +void LogicalNot(ArithmeticSelfCPUKernel *content, const bool *in, bool *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = !in[i]; } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void OnesLike(const T *, T *out, size_t size) { +void OnesLike(ArithmeticSelfCPUKernel *content, const T *, T *out, size_t size) { auto task = [&out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(1); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void ZerosLike(const T *, T *out, size_t size) { +void ZerosLike(ArithmeticSelfCPUKernel *content, const T *, T *out, size_t size) { auto task = [&out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(0); } }; - 
CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Floor(const T *in, T *out, size_t size) { +void Floor(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(floor(in[i])); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Rint(const T *in, T *out, size_t size) { +void Rint(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(rint(in[i])); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Round(const T *in, T *out, size_t size) { +void Round(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(nearbyint(in[i])); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Reciprocal(const T *in, T *out, size_t size) { +void Reciprocal(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(1.0 / in[i]); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Gelu(const T *in, T *out, size_t size) { +void Gelu(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { auto factor_a = static_cast(0.7978845608); auto factor_b = static_cast(0.044715); @@ -149,137 
+149,137 @@ void Gelu(const T *in, T *out, size_t size) { out[i] = x * (static_cast(1.0) + tanh_res) / static_cast(2.0); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Asin(const T *in, T *out, size_t size) { +void Asin(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(asin(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void ACos(const T *in, T *out, size_t size) { +void ACos(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(acos(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Atan(const T *in, T *out, size_t size) { +void Atan(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(atan(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Sin(const T *in, T *out, size_t size) { +void Sin(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(sin(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Cos(const T *in, T *out, size_t size) { +void Cos(ArithmeticSelfCPUKernel *content, const T *in, T 
*out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(cos(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Tan(const T *in, T *out, size_t size) { +void Tan(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(tan(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Sinh(const T *in, T *out, size_t size) { +void Sinh(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(sinh(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Cosh(const T *in, T *out, size_t size) { +void Cosh(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(cosh(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Asinh(const T *in, T *out, size_t size) { +void Asinh(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(asinh(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Acosh(const T *in, T *out, size_t size) { +void 
Acosh(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(acosh(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Atanh(const T *in, T *out, size_t size) { +void Atanh(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(atanh(static_cast(in[i]))); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Abs(const T *in, T *out, size_t size) { +void Abs(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = abs(in[i]); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template -void Sqrt(const T *in, T *out, size_t size) { +void Sqrt(ArithmeticSelfCPUKernel *content, const T *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = sqrt(in[i]); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template @@ -316,49 +316,49 @@ bool ArithmeticSelfCPUKernel::Launch(const std::vector &inpu } void ArithmeticSelfCPUKernel::LaunchLogicalNot(const std::vector &inputs, - const std::vector &outputs) const { + const std::vector &outputs) { auto *input = reinterpret_cast(inputs[0]->addr); auto *output = reinterpret_cast(outputs[0]->addr); size_t lens = outputs[0]->size / sizeof(bool); - LogicalNot(input, output, lens); + LogicalNot(this, input, output, lens); } template void 
ArithmeticSelfCPUKernel::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) const { + const std::vector &outputs) { const auto *input = reinterpret_cast(inputs[0]->addr); auto *output = reinterpret_cast(outputs[0]->addr); const size_t lens = outputs[0]->size / sizeof(T); - static const std::unordered_map> arithmeticSelfFuncMap{ - {prim::kPrimSquare->name(), Square}, - {prim::kPrimSign->name(), Sign}, - {prim::kPrimNeg->name(), Neg}, - {prim::kPrimAtanh->name(), Atanh}, - {prim::kPrimAcosh->name(), Acosh}, - {prim::kPrimFloor->name(), Floor}, - {prim::kPrimSin->name(), Sin}, - {prim::kPrimGeLU->name(), Gelu}, - {prim::kPrimCos->name(), Cos}, - {prim::kPrimTan->name(), Tan}, - {prim::kPrimAsin->name(), Asin}, - {prim::kPrimACos->name(), ACos}, - {prim::kPrimAtan->name(), Atan}, - {prim::kPrimSinh->name(), Sinh}, - {prim::kPrimCosh->name(), Cosh}, - {prim::kPrimAsinh->name(), Asinh}, - {prim::kPrimZerosLike->name(), ZerosLike}, - {prim::kPrimOnesLike->name(), OnesLike}, - {prim::kPrimReciprocal->name(), Reciprocal}, - {prim::kPrimRint->name(), Rint}, - {prim::kPrimRound->name(), Round}, - {prim::kPrimAbs->name(), Abs}, - {prim::kPrimSqrt->name(), Sqrt}}; + static const std::unordered_map> + arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square}, + {prim::kPrimSign->name(), Sign}, + {prim::kPrimNeg->name(), Neg}, + {prim::kPrimAtanh->name(), Atanh}, + {prim::kPrimAcosh->name(), Acosh}, + {prim::kPrimFloor->name(), Floor}, + {prim::kPrimSin->name(), Sin}, + {prim::kPrimGeLU->name(), Gelu}, + {prim::kPrimCos->name(), Cos}, + {prim::kPrimTan->name(), Tan}, + {prim::kPrimAsin->name(), Asin}, + {prim::kPrimACos->name(), ACos}, + {prim::kPrimAtan->name(), Atan}, + {prim::kPrimSinh->name(), Sinh}, + {prim::kPrimCosh->name(), Cosh}, + {prim::kPrimAsinh->name(), Asinh}, + {prim::kPrimZerosLike->name(), ZerosLike}, + {prim::kPrimOnesLike->name(), OnesLike}, + {prim::kPrimReciprocal->name(), Reciprocal}, + {prim::kPrimRint->name(), Rint}, + 
{prim::kPrimRound->name(), Round}, + {prim::kPrimAbs->name(), Abs}, + {prim::kPrimSqrt->name(), Sqrt}}; const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_); if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) { MS_LOG(EXCEPTION) << "ArithmeticSelfCPUKernel does not support " << kernel_name_; } - func_pair->second(input, output, lens); + func_pair->second(this, input, output, lens); } template diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h index 0f7f932c37c..d58889b81c9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h @@ -37,9 +37,9 @@ class ArithmeticSelfCPUKernel : public CPUKernel { private: template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs) const; + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - void LaunchLogicalNot(const std::vector &inputs, const std::vector &outputs) const; + void LaunchLogicalNot(const std::vector &inputs, const std::vector &outputs); TypeId dtype_{kTypeUnknown}; }; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_decode_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_decode_cpu_kernel.cc index bf0350867af..5c57a48fddc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_decode_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_decode_cpu_kernel.cc @@ -141,8 +141,7 @@ bool BoundingBoxDecodeCPUKernel::Launch(const std::vector bboxes[right_y] = y2; } }; - CPUKernelUtils::ParallelFor(task, elem_num); - + ParallelLaunchAutoSearch(task, elem_num, this, &parallel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_encode_cpu_kernel.cc 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_encode_cpu_kernel.cc index c2b2a60aa36..d1a72d41b17 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_encode_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_encode_cpu_kernel.cc @@ -113,7 +113,7 @@ bool BoundingBoxEncodeCPUKernel::Launch(const std::vector deltas[right_y] = (dh - static_cast(means_[H_INDEX])) / static_cast(stds_[H_INDEX]); } }; - CPUKernelUtils::ParallelFor(task, elem_num); + ParallelLaunchAutoSearch(task, elem_num, this, ¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc index 8770de6618b..7748544e5d4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc @@ -30,13 +30,13 @@ constexpr size_t kCastOutputsNum = 1; } // namespace template -void Cast(const S *in, T *out, size_t size) { +void Cast(CastCPUKernel *content, const S *in, T *out, size_t size) { auto task = [&in, &out](size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = static_cast(in[i]); } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_); } template @@ -59,7 +59,7 @@ bool CastCPUKernel::Launch(const std::vector &inputs, const auto *input = reinterpret_cast(inputs[0]->addr); auto *output = reinterpret_cast(outputs[0]->addr); MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name(); - Cast(input, output, outputs[0]->size / sizeof(T)); + Cast(this, input, output, outputs[0]->size / sizeof(T)); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/check_valid_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/check_valid_cpu_kernel.cc index 76816b8ceda..ec780d13aa4 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/check_valid_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/check_valid_cpu_kernel.cc @@ -64,8 +64,7 @@ bool CheckValidCPUKernel::Launch(const std::vector &input output[i] = !valid_flag; } }; - CPUKernelUtils::ParallelFor(task, elem_num); - + ParallelLaunchAutoSearch(task, elem_num, this, ¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index f53c7c64fa9..0fbdcba623b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -141,12 +141,13 @@ class CPUKernel : public kernel::KernelMod { void InitDynamicKernel(const CNodePtr &cnode_ptr) { dynamic_kernel_ = std::make_shared(cnode_ptr); } device::DynamicKernelPtr DynamicKernel() const { return dynamic_kernel_; } + ParallelSearchInfo parallel_search_info_; + protected: virtual void InitInputOutputSize(const CNodePtr &kernel_node); std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; - ParallelSearchInfo parallel_search_info_; CNodeWeakPtr cnode_ptr_; device::DynamicKernelPtr dynamic_kernel_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/crop_and_resize_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/crop_and_resize_cpu_kernel.cc index dc34e42df30..add1650657f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/crop_and_resize_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/crop_and_resize_cpu_kernel.cc @@ -214,7 +214,7 @@ bool CropAndResizeCPUKernel::Launch(const std::vector &in } } }; - CPUKernelUtils::ParallelFor(task, IntToSize(output_size_)); + ParallelLaunchAutoSearch(task, IntToSize(output_size_), this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cumsum_cpu_kernel.cc 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/cumsum_cpu_kernel.cc index 96870e33593..763816a524e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cumsum_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cumsum_cpu_kernel.cc @@ -225,35 +225,16 @@ void CumSumCPUKernel::LaunchCumSum(const T *input, T *output, T *workspace, size template void CumSumCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) const { + const std::vector &outputs) { const auto *input = reinterpret_cast(inputs[0]->addr); auto *ws = reinterpret_cast(workspace[0]->addr); auto output = reinterpret_cast(outputs[0]->addr); // multithreading size_t lens = inputs[0]->size > 0 ? static_cast(inputs[0]->size / sizeof(T)) : 1; - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; - MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; - std::vector threads; - threads.reserve(thread_num); - size_t start = 0; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return; - } - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return; - } - while (start < lens) { - size_t end = (start + once_compute_size) > lens ? 
lens : (start + once_compute_size); - (void)threads.emplace_back(std::thread(&CumSumCPUKernel::LaunchCumSum, this, input, output, ws, start, end)); - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); - } + auto task = [this, &input, &output, &ws](size_t start, size_t end) { + LaunchCumSum(input, output, ws, start, end); + }; + ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cumsum_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cumsum_cpu_kernel.h index 404de2d6a56..f7ba73a5f27 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cumsum_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cumsum_cpu_kernel.h @@ -65,7 +65,7 @@ class CumSumCPUKernel : public CPUKernel { template void LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) const; + const std::vector &outputs); template void LaunchCumSum(const T *input_addr, T *output_addr, T *ws_addr, size_t start, size_t end) const; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.cc index 23da6333d1d..640cce0a933 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.cc @@ -73,7 +73,7 @@ bool DepthToSpaceCPUKernel::Launch(const std::vector &inp } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.cc index fc57b224010..3af210a98d6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.cc +++ 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.cc @@ -51,8 +51,7 @@ bool EluGradCPUKernel::Launch(const std::vector &inputs, con } template -void EluGradCPUKernel::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) const { +void EluGradCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { const auto *input0 = reinterpret_cast(inputs[0]->addr); const auto *input1 = reinterpret_cast(inputs[1]->addr); auto *output = reinterpret_cast(outputs[0]->addr); @@ -64,7 +63,7 @@ void EluGradCPUKernel::LaunchKernel(const std::vector &inputs, output[i] = (input1[i] < static_cast(0)) ? input0[i] * (input1[i] + alpha) : input0[i]; } }; - CPUKernelUtils::ParallelFor(task, lens); + ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h index 5fdc08dcdbb..e25e2a08cb7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h @@ -34,7 +34,7 @@ class EluGradCPUKernel : public CPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs) const; + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); private: TypeId dtype_{kTypeUnknown}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.cc index 404fbe7133f..2aac5422f78 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_cpu_kernel.cc @@ -59,7 +59,7 @@ bool HSigmoidCPUKernel::Launch(const std::vector &inputs, } } }; - 
CPUKernelUtils::ParallelFor(task, tensor_size_); + ParallelLaunchAutoSearch(task, tensor_size_, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.cc index 1aeecd529e3..88a0585a278 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hsigmoid_grad_cpu_kernel.cc @@ -58,7 +58,7 @@ bool HSigmoidGradCPUKernel::Launch(const std::vector &inp } } }; - CPUKernelUtils::ParallelFor(task, tensor_size_); + ParallelLaunchAutoSearch(task, tensor_size_, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.cc index 97ebb05023f..54f7f3ea3a2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_cpu_kernel.cc @@ -57,7 +57,7 @@ bool HSwishCPUKernel::Launch(const std::vector &inputs, c } } }; - CPUKernelUtils::ParallelFor(task, tensor_size_); + ParallelLaunchAutoSearch(task, tensor_size_, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.cc index 16c4a2ac434..09571b721d7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/hswish_grad_cpu_kernel.cc @@ -61,7 +61,7 @@ bool HSwishGradCPUKernel::Launch(const std::vector &input } } }; - CPUKernelUtils::ParallelFor(task, tensor_size_); + ParallelLaunchAutoSearch(task, tensor_size_, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/iou_cpu_kernel.cc 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/iou_cpu_kernel.cc index cf5874b12e7..89a147bf1da 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/iou_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/iou_cpu_kernel.cc @@ -93,7 +93,7 @@ bool IOUCPUKernel::Launch(const std::vector &inputs, cons } } }; - CPUKernelUtils::ParallelFor(task, iou_size_); + ParallelLaunchAutoSearch(task, iou_size_, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.cc index c203093b52e..78e38f1898d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2_normalize_cpu_kernel.cc @@ -90,7 +90,7 @@ void L2NormalizeCPUKernel::CalcDenominator(const T *input_addr, const size_t (*denominator_addr)[i] = sqrt(denominator); } }; - CPUKernelUtils::ParallelFor(task, reduce_size); + ParallelLaunchAutoSearch(task, reduce_size, this, ¶llel_search_info_); } template @@ -120,7 +120,7 @@ void L2NormalizeCPUKernel::CalcOutput(const T *input_addr, const std::vector< iter.GenNextPos(); } }; - CPUKernelUtils::ParallelFor(task, output_size); + ParallelLaunchAutoSearch(task, output_size, this, ¶llel_search_info_); } template diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.cc index b35ad9eb527..1db4e146a0c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.cc @@ -65,7 +65,7 @@ bool L2NormalizeGradCPUKernel::Launch(const std::vector &inputs, GetOutput(input_x_vector, y_vector, dout_vector, high_dim_index, &output[i]); } }; - CPUKernelUtils::ParallelFor(task, output_size); + ParallelLaunchAutoSearch(task, output_size, this, 
¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc index 923f9c87721..e93e5fd8cff 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc @@ -86,11 +86,11 @@ bool AddNCPUKernel::Launch(const std::vector &inputs, const const auto input_1 = reinterpret_cast(inputs[1]->addr); auto output = reinterpret_cast(outputs[0]->addr); auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2); - CPUKernelUtils::ParallelFor(task_0, elements_num); + ParallelLaunchAutoSearch(task_0, elements_num, this, ¶llel_search_info_); for (size_t index = 2; index < input_num_; ++index) { const auto input = reinterpret_cast(inputs[index]->addr); auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2); - CPUKernelUtils::ParallelFor(task, elements_num); + ParallelLaunchAutoSearch(task, elements_num, this, ¶llel_search_info_); } } else if (dtype_ == kNumberTypeFloat64) { size_t elements_num = outputs[0]->size / sizeof(double); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nms_with_mask_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/nms_with_mask_cpu_kernel.cc index f91f7447c1c..ce5a63fa5c4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nms_with_mask_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nms_with_mask_cpu_kernel.cc @@ -49,7 +49,7 @@ void NMSWithMaskCPUKernel::NmsBitonicSortByKeyKernel(const int inner, const s index_buff[i] = i; } }; - CPUKernelUtils::ParallelFor(task1, ceil_power2); + ParallelLaunchAutoSearch(task1, ceil_power2, this, ¶llel_search_info_); for (size_t i = 2; i <= ceil_power2; i <<= 1) { for (size_t j = (i >> 1); j > 0; j >>= 1) { @@ -71,7 +71,7 @@ void 
NMSWithMaskCPUKernel::NmsBitonicSortByKeyKernel(const int inner, const s } } }; - CPUKernelUtils::ParallelFor(task2, ceil_power2); + ParallelLaunchAutoSearch(task2, ceil_power2, this, ¶llel_search_info_); } } } @@ -84,7 +84,7 @@ void NMSWithMaskCPUKernel::MaskInit(size_t numSq, bool *row_mask) { row_mask[mat_pos] = true; } }; - CPUKernelUtils::ParallelFor(task, numSq); + ParallelLaunchAutoSearch(task, numSq, this, ¶llel_search_info_); } // copy data from input to output array sorted by indices returned from bitonic sort @@ -122,7 +122,7 @@ void NMSWithMaskCPUKernel::PopulateOutput(const T *data_in, T *data_out, cons } } }; - CPUKernelUtils::ParallelFor(task, IntToSize(num)); + ParallelLaunchAutoSearch(task, IntToSize(num), this, ¶llel_search_info_); } // populated return mask (init to all true) and return index array @@ -134,7 +134,7 @@ void NMSWithMaskCPUKernel::Preprocess(const int num, int *sel_idx, bool *sel_ sel_boxes[box_num] = true; } }; - CPUKernelUtils::ParallelFor(task, IntToSize(num)); + ParallelLaunchAutoSearch(task, IntToSize(num), this, ¶llel_search_info_); } template @@ -175,7 +175,7 @@ void NMSWithMaskCPUKernel::NmsPass(const int num, const float IOU_value, cons } } }; - CPUKernelUtils::ParallelFor(task, IntToSize(num * num)); + ParallelLaunchAutoSearch(task, IntToSize(num * num), this, ¶llel_search_info_); } // Reduce pass runs on 1 block to allow thread sync @@ -192,7 +192,7 @@ void NMSWithMaskCPUKernel::ReducePass(const int num, bool *sel_boxes, const b sel_boxes[j] = sel_boxes[j] && row_mask[i * num + j]; } }; - CPUKernelUtils::ParallelFor(task, IntToSize(num)); + ParallelLaunchAutoSearch(task, IntToSize(num), this, ¶llel_search_info_); } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.cc index 0c258e339a9..ea987a573d0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.cc +++ 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.cc @@ -61,28 +61,8 @@ bool PackCpuFwdKernel::Launch(const std::vector &inputs, const st // multi-threading size_t input_size = output_size_; - size_t max_thread_num = std::max(std::thread::hardware_concurrency(), static_cast(1)); - size_t use_thread_num = - input_size < 128 * max_thread_num ? std::ceil(static_cast(input_size / 128.0)) : max_thread_num; - std::vector threads; - - if (use_thread_num < 1) { - use_thread_num = 1; - } - - threads.reserve(use_thread_num); - size_t start = 0; - size_t batch_size = (input_size + use_thread_num - 1) / use_thread_num; - - while (start < input_size) { - size_t end = (start + batch_size) > input_size ? input_size : (start + batch_size); - (void)threads.emplace_back(std::thread(&PackCpuFwdKernel::PackTensor, this, output, start, end)); - start += batch_size; - } - - for (auto &it : threads) { - it.join(); - } + auto task = [this, &output](size_t start, size_t end) { PackTensor(output, start, end); }; + ParallelLaunchAutoSearch(task, input_size, this, ¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc index b9f00a6a661..f364cf22750 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc @@ -34,35 +34,16 @@ void StandardNormal(float *output, std::normal_distribution distribution, } } -void LaunchStandardNormal(unsigned int seed, const std::vector &outputs) { +void LaunchStandardNormal(RandomCPUKernel *content, unsigned int seed, const std::vector &outputs) { auto output = reinterpret_cast(outputs[0]->addr); // multithreading size_t lens = outputs[0]->size / sizeof(float); - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? 
std::ceil(lens / 128.0) : max_thread_num; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return; - } - std::vector threads; - threads.reserve(thread_num); - size_t start = 0; - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return; - } - std::normal_distribution distribution; - while (start < lens) { - // avoid different threads using the same seed to generate the same random number + auto task = [&seed, &output](size_t start, size_t end) { + std::normal_distribution distribution; std::default_random_engine random_generator(++seed); - size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); - (void)threads.emplace_back(std::thread(StandardNormal, output, distribution, random_generator, start, end)); - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); - } + StandardNormal(output, distribution, random_generator, start, end); + }; + ParallelLaunchAutoSearch(task, lens, content, &content->parallel_search_info_); } void LaunchUniformInt(unsigned int seed, const std::vector &inputs, @@ -138,7 +119,7 @@ bool RandomCPUKernel::Launch(const std::vector &inputs, cons if (random_op_type_ == RANDOM_OP_NORMAL) { CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStandardNormalOutputsNum, kernel_name_); - LaunchStandardNormal(RNG_seed, outputs); + LaunchStandardNormal(this, RNG_seed, outputs); } else if (random_op_type_ == RANDOM_OP_UNIFORM_INT) { CHECK_KERNEL_INPUTS_NUM(inputs.size(), kUniformIntInputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kUniformIntOutputsNum, kernel_name_); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc index fb2b94ccf0e..7eca5180c00 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc @@ -176,7 +176,7 @@ bool ReduceCPUKernel::Launch(const std::vector &inputs, c } } }; - CPUKernelUtils::ParallelFor(task, output_size); + ParallelLaunchAutoSearch(task, output_size, this, ¶llel_search_info_); } return true; } @@ -204,7 +204,7 @@ void ReduceCPUKernel::AccelerateLongVector(T *input_addr, T *output_addr, siz reduce_func_(&block_output, 0, output_addr); } }; - CPUKernelUtils::ParallelFor(task, input_size); + ParallelLaunchAutoSearch(task, input_size, this, ¶llel_search_info_); if (reduce_type_ == kReduceMean) { *output_addr /= input_size; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.h index 973e4e63345..1c998168f0e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_append_cpu_kernel.h @@ -88,7 +88,7 @@ class BufferCPUAppendKernel : public CPUKernel { } } }; - CPUKernelUtils::ParallelFor(task, element_nums_); + ParallelLaunchAutoSearch(task, element_nums_, this, ¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.h index 7f7137c3602..aec9b7665ff 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_get_cpu_kernel.h @@ -77,7 +77,7 @@ class BufferCPUGetKernel : public CPUKernel { } } }; - CPUKernelUtils::ParallelFor(task, element_nums_); + ParallelLaunchAutoSearch(task, element_nums_, this, ¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h 
index a80bd4be1df..bca2b885d9b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rl/buffer_sample_cpu_kernel.h @@ -104,7 +104,7 @@ class BufferCPUSampleKernel : public CPUKernel { } } }; - CPUKernelUtils::ParallelFor(task, batch_size_); + ParallelLaunchAutoSearch(task, batch_size_, this, ¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.cc index fe72e60db34..fa9ffac1397 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/rmsprop_cpu_kernel.cc @@ -44,7 +44,7 @@ void RMSPropCPUKernel::LaunchRMSPropUnuseCenter(T *variable, T *mean_square, } }; } - CPUKernelUtils::ParallelFor(task, size_); + ParallelLaunchAutoSearch(task, size_, this, ¶llel_search_info_); } template @@ -70,7 +70,7 @@ void RMSPropCPUKernel::LaunchRMSPropUseCenter(T *variable, T *mean_square, T } }; } - CPUKernelUtils::ParallelFor(task, size_); + ParallelLaunchAutoSearch(task, size_, this, ¶llel_search_info_); } template diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_cpu_kernel.cc index 41a9cbf4a34..7b98678f9bb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_cpu_kernel.cc @@ -109,8 +109,7 @@ bool ROIAlignCPUKernel::Launch(const std::vector &inputs, out_data[thread_idx] = accumulate_val; } }; - CPUKernelUtils::ParallelFor(task, elem_num); - + ParallelLaunchAutoSearch(task, elem_num, this, ¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_grad_cpu_kernel.cc index 8d3384101a7..801ee75d213 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_grad_cpu_kernel.cc @@ -117,7 +117,7 @@ bool ROIAlignGradCPUKernel::Launch(const std::vector &inp dx[thread_idx] = ZERO; } }; - CPUKernelUtils::ParallelFor(task1, IntToSize(size_init)); + ParallelLaunchAutoSearch(task1, IntToSize(size_init), this, ¶llel_search_info_); int elem_num = roi_rows_ * channels_ * pooled_height_ * pooled_width_; auto task2 = [this, &dy, &rois, &dx](size_t start, size_t end) { @@ -176,7 +176,7 @@ bool ROIAlignGradCPUKernel::Launch(const std::vector &inp } } }; - CPUKernelUtils::ParallelFor(task2, IntToSize(elem_num)); + ParallelLaunchAutoSearch(task2, IntToSize(elem_num), this, ¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_cpu_kernel.cc index cf306c401b0..d1c55e2bb38 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_cpu_kernel.cc @@ -27,7 +27,8 @@ constexpr size_t kScatterNdOutputSize = 1; constexpr size_t kMinIndiceRank = 2; template -void Compute(const ComputeParams *params, const size_t start, const size_t end) { +void Compute(ScatterNdCPUKernel *content, const ComputeParams *params, const size_t start, + const size_t end) { T *target = params->target_; S *indices = params->indices_; T *updates = params->updates_; @@ -47,7 +48,7 @@ void Compute(const ComputeParams *params, const size_t start, const size_t target[IntToSize(offset) + idx] += updates[IntToSize(params->unit_size_) * i + idx]; } }; - CPUKernelUtils::ParallelFor(task, IntToSize(params->unit_size_)); + ParallelLaunchAutoSearch(task, IntToSize(params->unit_size_), content, &content->parallel_search_info_); } } } // namespace @@ -113,10 +114,10 @@ bool ScatterNdCPUKernel::Launch(const std::vector &inp auto task = [this, 
¶ms](size_t start, size_t end) { for (size_t idx = start; idx < end; idx++) { - Compute(¶ms, idx, idx + 1); + Compute(this, ¶ms, idx, idx + 1); } }; - CPUKernelUtils::ParallelFor(task, num_units_); + ParallelLaunchAutoSearch(task, num_units_, this, ¶llel_search_info_); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/searchsorted_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/searchsorted_cpu_kernel.cc index 8ecd168a0d6..52b6669170b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/searchsorted_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/searchsorted_cpu_kernel.cc @@ -70,7 +70,7 @@ bool SearchSortedCPUKernel::Launch(const std::vector & output[i] = static_cast(result); } }; - CPUKernelUtils::ParallelFor(task, elem_num); + ParallelLaunchAutoSearch(task, elem_num, this, ¶llel_search_info_); return true; } @@ -96,7 +96,7 @@ void SearchSortedCPUKernel::CheckParam(const std::vector &inpu } } }; - CPUKernelUtils::ParallelFor(task, IntToSize(list_count)); + ParallelLaunchAutoSearch(task, IntToSize(list_count), this, ¶llel_search_info_); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc index 2c108c45785..319cf750d21 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sgd_cpu_kernel.cc @@ -72,7 +72,7 @@ bool SGDCPUKernel::Launch(const std::vector &inputs, const std::v output_param[i] = param[i]; } }; - CPUKernelUtils::ParallelFor(task, elem_num); + ParallelLaunchAutoSearch(task, elem_num, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc index 504206707a9..8bae38b89a9 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/smooth_l1_loss_cpu_kernel.cc @@ -63,7 +63,7 @@ bool SmoothL1LossCPUKernel::Launch(const std::vector &inp } } }; - CPUKernelUtils::ParallelFor(task, tensor_size_); + ParallelLaunchAutoSearch(task, tensor_size_, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.cc index fc88bff4e9f..1525aca0b25 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.cc @@ -84,7 +84,7 @@ bool SpaceToDepthCPUKernel::Launch(const std::vector &inp } }; - CPUKernelUtils::ParallelFor(task, size); + ParallelLaunchAutoSearch(task, size, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc index 6f4cc06fd29..53a5b169647 100644 --- a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc +++ b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc @@ -23,23 +23,22 @@ namespace runtime { void ComputeThreadNums(size_t *actor_thread_num, size_t *actor_and_kernel_thread_num) { MS_EXCEPTION_IF_NULL(actor_thread_num); MS_EXCEPTION_IF_NULL(actor_and_kernel_thread_num); - size_t cpu_core_num = std::thread::hardware_concurrency() - 1; - - // Compute the actor thread num. - const size_t kActorThreadMaxNum = 5; + const size_t cpu_core_num = std::thread::hardware_concurrency() - 1; // The MemoryManagerActor binds single thread, and the other actors share one thread at least, so the min num is 2. const size_t kActorThreadMinNum = 2; + // Compute the actor thread num. + const size_t kActorThreadMaxNum = 5; + // a machine may run multiple process, a process should not use all CPUs. 
default run 5 process in the same time. + const size_t kParallelNum = 5; + const size_t kThreadMaxNum = cpu_core_num / kParallelNum; + auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); *actor_thread_num = cpu_core_num < kActorThreadMinNum ? kActorThreadMinNum : cpu_core_num; *actor_thread_num = *actor_thread_num > kActorThreadMaxNum ? kActorThreadMaxNum : *actor_thread_num; - // Compute the actor and kernel thread num. - const size_t kActorAndKernelThreadMaxNum = 23; - *actor_and_kernel_thread_num = cpu_core_num > *actor_thread_num ? cpu_core_num : (*actor_thread_num + 1); - *actor_and_kernel_thread_num = *actor_and_kernel_thread_num > kActorAndKernelThreadMaxNum - ? kActorAndKernelThreadMaxNum - : *actor_and_kernel_thread_num; + // Compute the actor and kernel thread num. 1 thread is useless for kernel, so kernel thread num should at least 2. + *actor_and_kernel_thread_num = kThreadMaxNum > (*actor_thread_num + 2) ? kThreadMaxNum : (*actor_thread_num + 2); } bool IsDeviceQueueDSActor(const AnfNodePtr &node, GraphExecutionStrategy strategy) {