change omp to ms threadpool

fangzehua 2021-12-28 11:46:48 +08:00
parent 214271459d
commit ac11ceeee4
46 changed files with 219 additions and 231 deletions

View File

@ -1,5 +1,10 @@
set(onednn_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(onednn_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
if(USE_MS_THREADPOOL_FOR_DNNL)
set(USE_MS_THREADPOOL "-DDNNL_CPU_RUNTIME=THREADPOOL")
else()
set(USE_MS_THREADPOOL "")
endif()
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
mindspore_add_pkg(onednn
VER 2.2
@ -22,7 +27,7 @@ else()
URL ${REQ_URL}
MD5 ${MD5}
CMAKE_OPTION -DDNNL_ARCH_OPT_FLAGS='' -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF
-DDNNL_ENABLE_CONCURRENT_EXEC=ON)
${USE_MS_THREADPOOL} -DDNNL_ENABLE_CONCURRENT_EXEC=ON)
endif()
include_directories(${onednn_INC})

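For context: building oneDNN with -DDNNL_CPU_RUNTIME=THREADPOOL drops the library's built-in OpenMP scheduling, and the application must then hand oneDNN a thread pool through the threadpool-interop API whenever it creates a stream. A minimal sketch of that contract follows; InlinePool is an illustrative stand-in that runs jobs serially, not MindSpore code.

    #include <functional>
    #include "dnnl.hpp"
    #include "dnnl_threadpool.hpp"
    #include "dnnl_threadpool_iface.hpp"

    // A trivial "pool" that satisfies the interface by running jobs inline.
    class InlinePool : public dnnl::threadpool_interop::threadpool_iface {
     public:
      int get_num_threads() const override { return 1; }
      bool get_in_parallel() const override { return false; }
      uint64_t get_flags() const override { return 0; }
      void parallel_for(int n, const std::function<void(int, int)> &fn) override {
        for (int i = 0; i < n; ++i) fn(i, n);  // job i of n, executed serially
      }
    };

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      InlinePool pool;
      // With the THREADPOOL runtime, streams come from the interop API.
      dnnl::stream strm = dnnl::threadpool_interop::make_stream(eng, &pool);
      return 0;
    }
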
View File

@ -28,6 +28,11 @@ option(ENABLE_SYM_FILE "enable sym file" OFF)
option(BUILD_DEV_MODE "MindSpore build nightly dev mode" OFF)
option(ENABLE_FAST_HASH_TABLE "Enable use fast hash table instead of std ones" ON)
option(USE_LLVM "use llvm" OFF)
option(USE_MS_THREADPOOL_FOR_DNNL "use ms threadpool for onednn ops" ON)
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
set(USE_MS_THREADPOOL_FOR_DNNL OFF)
endif()
if(NOT ENABLE_D AND NOT ENABLE_TESTCASES AND NOT ENABLE_ACL)
set(ENABLE_GLIBCXX ON)
@ -161,3 +166,7 @@ endif()
if(USE_LLVM)
add_compile_definitions(USE_LLVM)
endif()
if(USE_MS_THREADPOOL_FOR_DNNL)
add_compile_definitions(USE_MS_THREADPOOL_FOR_DNNL)
endif()

View File

@ -18,6 +18,14 @@ if(ENABLE_D)
add_subdirectory(backend/kernel_compiler/aicpu/aicpu_ops)
endif()
# add openmp
find_package(OpenMP)
if(OPENMP_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
else()
message(FATAL_ERROR "OpenMP not found")
endif()
if(ENABLE_CPU)
if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64")
set(PLATFORM_ARM64 "on")

View File

@ -26,6 +26,7 @@ namespace {
constexpr size_t kAdamInputsNum = 10;
constexpr size_t kAdamOutputsNum = 3;
constexpr size_t kScalarIndex = 0;
constexpr float kAdamBlock = 1000;
} // namespace
template <typename T>
@ -60,7 +61,7 @@ void AdamCPUKernel::LaunchAdam(const std::vector<kernel::AddressPtr> &inputs, co
}
}
};
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
ParallelLaunch(task, lens, kAdamBlock, this);
}
void AdamCPUKernel::LaunchAdamNnacl(const std::vector<kernel::AddressPtr> &inputs,
@ -89,7 +90,7 @@ void AdamCPUKernel::LaunchAdamNnacl(const std::vector<kernel::AddressPtr> &input
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', AdamFp32 failed. Error no: " << ret;
}
};
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
ParallelLaunch(task, lens, kAdamBlock, this);
}
void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {

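Replacing ParallelLaunchAutoSearch with ParallelLaunch trades runtime auto-tuning for a per-operator block size (kAdamBlock = 1000 elements), from which the thread count is derived directly. A hypothetical sketch of such a helper, using plain std::thread in place of the actor pool:

    #include <algorithm>
    #include <cmath>
    #include <functional>
    #include <thread>
    #include <vector>

    // Sketch of a ParallelLaunch-style helper: the caller supplies a block
    // size (e.g. 1000 elements for Adam) and the helper derives the thread
    // count from the total element count instead of searching at runtime.
    void ParallelLaunchSketch(const std::function<void(size_t, size_t)> &task,
                              size_t lens, float block_size, size_t max_threads) {
      size_t thread_num =
          std::min(max_threads, static_cast<size_t>(std::ceil(lens / block_size)));
      thread_num = std::max<size_t>(thread_num, 1);
      size_t stride = (lens + thread_num - 1) / thread_num;
      std::vector<std::thread> workers;
      for (size_t i = 0; i < thread_num; ++i) {
        size_t start = i * stride;
        size_t end = std::min(start + stride, lens);
        if (start >= end) break;
        workers.emplace_back(task, start, end);  // each runs one [start, end) slice
      }
      for (auto &w : workers) w.join();
    }
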
View File

@ -15,7 +15,6 @@
*/
#include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"

View File

@ -31,9 +31,6 @@
#include "runtime/framework/graph_scheduler.h"
#include "actor/actormgr.h"
#include "common/thread_pool.h"
#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64)
#define PLATFORM_86
#endif
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;

View File

@ -15,7 +15,6 @@
*/
#include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/nnacl/fp32/add_fp32.h"
#include "backend/kernel_compiler/cpu/nnacl/errorcode.h"
@ -65,7 +64,7 @@ void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape);
dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape);
dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc);
auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::binary::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::binary>(prim_desc);
AddArgument(DNNL_ARG_SRC_0, src0_mem_desc);
AddArgument(DNNL_ARG_SRC_1, src1_mem_desc);

View File

@ -15,7 +15,6 @@
*/
#include "backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -47,7 +46,7 @@ void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc src0_desc = GetDefaultMemDesc(src0_shape);
dnnl::memory::desc src1_desc = GetDefaultMemDesc(src1_shape);
dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_desc, src1_desc, src0_desc);
auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::binary::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::binary>(prim_desc);
AddArgument(DNNL_ARG_SRC_0, src0_desc);
AddArgument(DNNL_ARG_SRC_1, src1_desc);

View File

@ -15,7 +15,6 @@
*/
#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -62,7 +61,7 @@ void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
dnnl::batch_normalization_forward::desc desc =
dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags);
auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::batch_normalization_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, x_desc);
AddArgument(DNNL_ARG_MEAN, prim_desc.mean_desc());

View File

@ -15,7 +15,6 @@
*/
#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -67,13 +66,13 @@ void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
// fused Batch Normalization forward description
dnnl::batch_normalization_forward::desc desc =
dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags);
auto forward_prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto forward_prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, engine_);
// fused Batch Normalization backward description
dnnl::batch_normalization_backward::desc backward_desc =
dnnl::batch_normalization_backward::desc(dnnl::prop_kind::backward, x_desc, x_desc, epsilon, normalization_flags);
auto backward_prim_desc = dnnl::batch_normalization_backward::primitive_desc(
backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc);
auto backward_prim_desc =
dnnl::batch_normalization_backward::primitive_desc(backward_desc, engine_, forward_prim_desc);
primitive_ = std::make_shared<dnnl::batch_normalization_backward>(backward_prim_desc);
AddArgument(DNNL_ARG_SRC, x_desc);
AddArgument(DNNL_ARG_MEAN, forward_prim_desc.mean_desc());

View File

@ -18,7 +18,6 @@
#include <string>
#include <algorithm>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
@ -86,13 +85,13 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine());
auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, engine_);
dnnl::convolution_backward_weights::desc backward_desc = dnnl::convolution_backward_weights::desc(
dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto backward_prim_desc = dnnl::convolution_backward_weights::primitive_desc(
backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc);
auto backward_prim_desc =
dnnl::convolution_backward_weights::primitive_desc(backward_desc, engine_, forward_prim_desc);
primitive_ = std::make_shared<dnnl::convolution_backward_weights>(backward_prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);

View File

@ -18,7 +18,6 @@
#include <string>
#include <map>
#include <algorithm>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -99,13 +98,12 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine());
auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, engine_);
dnnl::convolution_backward_data::desc backward_desc = dnnl::convolution_backward_data::desc(
dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto backward_prim_desc =
dnnl::convolution_backward_data::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc);
auto backward_prim_desc = dnnl::convolution_backward_data::primitive_desc(backward_desc, engine_, forward_prim_desc);
primitive_ = std::make_shared<dnnl::convolution_backward_data>(backward_prim_desc);
AddArgument(DNNL_ARG_DIFF_SRC, src_desc);

View File

@ -18,7 +18,6 @@
#include <string>
#include <algorithm>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
@ -107,7 +106,7 @@ void ConvCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::convolution_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_WEIGHTS, weights_desc);

View File

@ -17,7 +17,6 @@
#include "backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h"
#include <string>
#include <unordered_map>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -65,7 +64,7 @@ void EltWiseCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
auto desc = GetForwardEltwiseDesc(src_desc);
auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::eltwise_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, src_desc);

View File

@ -16,7 +16,6 @@
#include "backend/kernel_compiler/cpu/mkldnn/log_softmax_cpu_kernel.h"
#include <algorithm>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -41,7 +40,7 @@ void LogSoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::logsoftmax_forward::desc desc =
dnnl::logsoftmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::logsoftmax_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, src_desc);

View File

@ -16,7 +16,6 @@
#include "backend/kernel_compiler/cpu/mkldnn/log_softmax_grad_cpu_kernel.h"
#include <algorithm>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -41,11 +40,10 @@ void LogSoftmaxGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::logsoftmax_forward::desc desc =
dnnl::logsoftmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, engine_);
// backward description
dnnl::logsoftmax_backward::desc backward_desc = dnnl::logsoftmax_backward::desc(src_desc, src_desc, axis);
auto backward_prim_desc =
dnnl::logsoftmax_backward::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), prim_desc);
auto backward_prim_desc = dnnl::logsoftmax_backward::primitive_desc(backward_desc, engine_, prim_desc);
primitive_ = std::make_shared<dnnl::logsoftmax_backward>(backward_prim_desc);
AddArgument(DNNL_ARG_DST, src_desc);
AddArgument(DNNL_ARG_DIFF_SRC, src_desc);

View File

@ -17,12 +17,7 @@
#include "backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h"
#include <string>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#ifdef PLATFORM_86
#include <pmmintrin.h>
#endif
namespace mindspore {
namespace kernel {
namespace {
@ -54,14 +49,10 @@ void LstmCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
}
void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
#ifdef PLATFORM_86
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
#endif
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
CheckParam(kernel_node);
auto eng = MKLKernelEngine::Get().engine();
auto eng = engine_;
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
if (bidirectional_) {
direction = dnnl::rnn_direction::bidirectional_concat;
@ -153,7 +144,7 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const
const std::vector<kernel::AddressPtr> &outputs) {
using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag;
auto eng = MKLKernelEngine::Get().engine();
auto eng = engine_;
auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng);

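The LSTM kernel keeps user weights in ldgoi order and reorders them into the layout the primitive prefers, now through the kernel's own engine_ and stream_ rather than through the MKLKernelEngine singleton. A standalone sketch of that reorder pattern; the dimensions are made up for illustration:

    #include "dnnl.hpp"

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      dnnl::stream strm(eng);
      using dt = dnnl::memory::data_type;
      using tag = dnnl::memory::format_tag;
      // Logical RNN weight dims: {layers, directions, input, gates, hidden}.
      dnnl::memory::dims dims = {1, 1, 16, 4, 16};
      dnnl::memory user_mem({dims, dt::f32, tag::ldgoi}, eng);  // user layout
      dnnl::memory prim_mem({dims, dt::f32, tag::ldigo}, eng);  // primitive layout
      dnnl::reorder(user_mem, prim_mem).execute(strm, user_mem, prim_mem);
      strm.wait();
      return 0;
    }
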
View File

@ -18,7 +18,6 @@
#include <cstring>
#include <string>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
@ -43,7 +42,7 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
CheckParam(kernel_node);
auto eng = MKLKernelEngine::Get().engine();
auto eng = engine_;
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
if (bidirectional_) {
direction = dnnl::rnn_direction::bidirectional_concat;
@ -178,7 +177,7 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, co
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kLstmGradInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kLstmGradOutputsNum, kernel_name_);
auto eng = MKLKernelEngine::Get().engine();
auto eng = engine_;
// construct fw memory
auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);

View File

@ -17,7 +17,6 @@
#include "backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h"
#include <utility>
#include "common/thread_pool.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
#include "backend/kernel_compiler/cpu/nnacl/matmul_parameter.h"
#include "backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h"
@ -81,7 +80,7 @@ void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc weights_md(weights_dims, dnnl::memory::data_type::f32, b_strides);
dnnl::memory::desc dst_md(dst_dims, dnnl::memory::data_type::f32, o_strides);
dnnl::matmul::desc matmul_desc(src_md, weights_md, dst_md);
dnnl::matmul::primitive_desc prim_desc(matmul_desc, MKLKernelEngine::Get().engine());
dnnl::matmul::primitive_desc prim_desc(matmul_desc, engine_);
primitive_ = std::make_shared<dnnl::matmul>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_md);

View File

@ -19,7 +19,7 @@
#include <string>
#include <algorithm>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "utils/profile.h"
namespace mindspore {
namespace kernel {
@ -139,7 +139,11 @@ dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &sh
}
void MKLCPUKernel::AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc) {
-arguments_[arg_key] = MKLKernelEngine::Get().CreateMemory(mem_desc, alloc);
+if (alloc) {
+arguments_[arg_key] = dnnl::memory(mem_desc, engine_);
+} else {
+arguments_[arg_key] = dnnl::memory(mem_desc, engine_, nullptr);
+}
}
void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) {
@ -149,10 +153,47 @@ void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) {
}
}
-void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); }
+void MKLCPUKernel::ExecutePrimitive() {
+MS_EXCEPTION_IF_NULL(primitive_);
+#ifdef USE_MS_THREADPOOL_FOR_DNNL
+// Auto-search the thread number that gives the lowest execution time.
+const size_t MAX_POW = 6;
+const size_t AVG_COUNT = 5;
+const size_t DIFF = 2;
+size_t current_pow = parallel_search_info_.search_count / AVG_COUNT;
+int current_thread_nums = static_cast<int>(std::pow(2.0f, current_pow));
+auto mkl_pool = dynamic_cast<mkl_threadpool *>(mkl_threadpool_.get());
+if (current_pow < MAX_POW) {
+if (parallel_search_info_.search_count % AVG_COUNT == 0) {
+parallel_search_info_.tmp_sum_cost_time = 0;
+}
+double start_time = GetTime();
+mkl_pool->set_num_threads(current_thread_nums);
+primitive_->execute(stream_, arguments_);
+double cost_time = GetTime() - start_time;
+parallel_search_info_.tmp_sum_cost_time += cost_time;
+parallel_search_info_.search_count++;
+if (parallel_search_info_.search_count % AVG_COUNT == 0) {
+if (parallel_search_info_.min_cost_time > parallel_search_info_.tmp_sum_cost_time) {
+parallel_search_info_.min_cost_time = parallel_search_info_.tmp_sum_cost_time;
+parallel_search_info_.best_pow = current_pow;
+} else if (current_pow - parallel_search_info_.best_pow >= DIFF) {
+parallel_search_info_.search_count = AVG_COUNT * MAX_POW;
+}
+}
+} else {
+int best_thread_nums = static_cast<int>(std::pow(2.0f, parallel_search_info_.best_pow));
+mkl_pool->set_num_threads(best_thread_nums);
+primitive_->execute(stream_, arguments_);
+}
+#else
+primitive_->execute(stream_, arguments_);
+#endif
+(void)stream_.wait();
+}
void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) {
-MKLKernelEngine::Get().Reorder(src_mem, dst_mem);
+dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem);
}
} // namespace kernel
} // namespace mindspore

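The new ExecutePrimitive embeds a small empirical search: each kernel instance tries thread counts 2^0 through 2^5, sums the cost of AVG_COUNT runs per candidate, and abandons the search once a candidate lags DIFF powers of two behind the best seen. A self-contained sketch of that policy, decoupled from dnnl (all names here are illustrative):

    #include <chrono>
    #include <cmath>
    #include <cstddef>
    #include <functional>
    #include <limits>

    // Try 1, 2, 4, ..., 32 threads; time kAvgCount runs per candidate and
    // stop early once a candidate is kDiff powers past the best one.
    struct ThreadNumSearch {
      static constexpr size_t kMaxPow = 6;
      static constexpr size_t kAvgCount = 5;
      static constexpr size_t kDiff = 2;
      size_t search_count = 0;
      size_t best_pow = 0;
      double min_cost_time = std::numeric_limits<double>::max();
      double tmp_sum_cost_time = 0;

      void Step(const std::function<void(int)> &set_threads,
                const std::function<void()> &execute) {
        size_t current_pow = search_count / kAvgCount;
        if (current_pow >= kMaxPow) {  // search finished: reuse the best result
          set_threads(static_cast<int>(std::pow(2.0f, best_pow)));
          execute();
          return;
        }
        if (search_count % kAvgCount == 0) tmp_sum_cost_time = 0;
        set_threads(static_cast<int>(std::pow(2.0f, current_pow)));
        auto begin = std::chrono::steady_clock::now();
        execute();
        tmp_sum_cost_time +=
            std::chrono::duration<double>(std::chrono::steady_clock::now() - begin).count();
        ++search_count;
        if (search_count % kAvgCount == 0) {  // one candidate finished
          if (tmp_sum_cost_time < min_cost_time) {
            min_cost_time = tmp_sum_cost_time;
            best_pow = current_pow;
          } else if (current_pow - best_pow >= kDiff) {
            search_count = kAvgCount * kMaxPow;  // give up early, keep best_pow
          }
        }
      }
    };
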
View File

@ -19,17 +19,54 @@
#include <string>
#include <unordered_map>
#include <algorithm>
#include <memory>
#include <vector>
#include "dnnl.hpp"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#ifdef USE_MS_THREADPOOL_FOR_DNNL
#include "dnnl_threadpool.hpp"
#include "dnnl_threadpool_iface.hpp"
#endif
namespace mindspore {
namespace kernel {
#ifdef USE_MS_THREADPOOL_FOR_DNNL
class mkl_threadpool : public dnnl::threadpool_interop::threadpool_iface {
private:
ActorThreadPool *tp_;
int thread_num_{8};
public:
explicit mkl_threadpool(ActorThreadPool *tp) { tp_ = tp; }
void set_num_threads(int num) { thread_num_ = num; }
int get_num_threads() const override { return std::min(SizeToInt(tp_->GetKernelThreadNum()), thread_num_); }
bool get_in_parallel() const override { return false; }
uint64_t get_flags() const override { return 0; }
void parallel_for(int n, const std::function<void(int, int)> &fn) override {
int nthr = get_num_threads();
int n_jobs = std::min(n, nthr);
auto func = [&, n_jobs](void *, int i, float, float) {
fn(i, n_jobs);
return 0;
};
tp_->ParallelLaunch(func, nullptr, n_jobs);
}
};
#endif
class MKLCPUKernel : public CPUKernel {
public:
-MKLCPUKernel() = default;
+#ifdef USE_MS_THREADPOOL_FOR_DNNL
+MKLCPUKernel() : engine_(dnnl::engine::kind::cpu, 0) {
+auto thread_pool = GetActorMgrInnerThreadPool();
+mkl_threadpool_ = std::make_shared<mkl_threadpool>(thread_pool);
+stream_ = dnnl::threadpool_interop::make_stream(engine_, mkl_threadpool_.get());
+}
+#else
+MKLCPUKernel() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {}
+#endif
~MKLCPUKernel() override = default;
protected:
@ -50,6 +87,11 @@ class MKLCPUKernel : public CPUKernel {
std::unordered_map<int, dnnl::memory> arguments_;
std::shared_ptr<dnnl::primitive> primitive_{nullptr};
dnnl::engine engine_;
dnnl::stream stream_;
#ifdef USE_MS_THREADPOOL_FOR_DNNL
std::shared_ptr<dnnl::threadpool_interop::threadpool_iface> mkl_threadpool_{nullptr};
#endif
};
} // namespace kernel
} // namespace mindspore

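One subtlety in the adapter above: when oneDNN requests more jobs than the pool will grant, parallel_for runs only n_jobs = min(n, nthr) instances and passes n_jobs as the job count, so the kernel body re-partitions the full iteration space across the smaller set. A toy illustration of that contract, with no dnnl dependency:

    #include <algorithm>
    #include <cstdio>
    #include <functional>

    // Mimics the capping in mkl_threadpool::parallel_for: work meant for
    // "n" jobs is re-partitioned across only n_jobs actual workers.
    void CappedParallelFor(int n, int nthr, const std::function<void(int, int)> &fn) {
      int n_jobs = std::min(n, nthr);
      for (int i = 0; i < n_jobs; ++i) fn(i, n_jobs);  // serial here for clarity
    }

    int main() {
      const int total_work = 1000;
      // oneDNN-style body: job i of n covers its contiguous share of the range.
      auto body = [&](int i, int n) {
        int chunk = (total_work + n - 1) / n;
        int start = i * chunk;
        int end = std::min(start + chunk, total_work);
        std::printf("job %d/%d -> [%d, %d)\n", i, n, start, end);
      };
      CappedParallelFor(/*n=*/8, /*nthr=*/4, body);  // only 4 jobs actually run
      return 0;
    }
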
View File

@ -1,42 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "utils/log_adapter.h"
#include "dnnl.hpp"
namespace mindspore {
namespace kernel {
void MKLKernelEngine::Execute(const std::shared_ptr<dnnl::primitive> &primitive,
const std::unordered_map<int, dnnl::memory> &arguments) {
MS_EXCEPTION_IF_NULL(primitive);
primitive->execute(stream_, arguments);
(void)stream_.wait();
}
dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, bool alloc) {
if (alloc) {
return dnnl::memory(mem_desc, engine_);
} else {
return dnnl::memory(mem_desc, engine_, nullptr);
}
}
void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) {
dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem);
}
} // namespace kernel
} // namespace mindspore

View File

@ -1,58 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MKL_KERNEL_ENGINE_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MKL_KERNEL_ENGINE_H_
#include <cstdlib>
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
#include <memory>
#include "dnnl.hpp"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
class MKLKernelEngine {
public:
static MKLKernelEngine &Get() {
static MKLKernelEngine instance;
return instance;
}
DISABLE_COPY_AND_ASSIGN(MKLKernelEngine)
const dnnl::engine &engine() const { return engine_; }
dnnl::memory CreateMemory(const dnnl::memory::desc &mem_desc, bool alloc = false);
void Execute(const std::shared_ptr<dnnl::primitive> &primitive,
const std::unordered_map<int, dnnl::memory> &arguments);
void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem);
private:
MKLKernelEngine() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {}
~MKLKernelEngine() = default;
dnnl::engine engine_;
dnnl::stream stream_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MKL_KERNEL_ENGINE_H_

View File

@ -19,7 +19,6 @@
#include <utility>
#include <algorithm>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
@ -69,13 +68,12 @@ void AvgPoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::pooling_forward::desc desc =
dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc, dst_desc,
strides_dims, kernels_dims, padding_l, padding_r);
auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, engine_);
// pooling_avg backward description
dnnl::pooling_backward::desc backward_desc = dnnl::pooling_backward::desc(
dnnl::algorithm::pooling_avg, src_desc, dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
auto backward_prim_desc =
dnnl::pooling_backward::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), prim_desc);
auto backward_prim_desc = dnnl::pooling_backward::primitive_desc(backward_desc, engine_, prim_desc);
primitive_ = std::make_shared<dnnl::pooling_backward>(backward_prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);

View File

@ -17,7 +17,6 @@
#include <string>
#include <algorithm>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
@ -84,7 +83,7 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) {
desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc,
dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
}
auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, engine_);
workspace_size_ = prim_desc.workspace_desc().get_size();
primitive_ = std::make_shared<dnnl::pooling_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);

View File

@ -19,7 +19,6 @@
#include <utility>
#include <algorithm>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {

View File

@ -16,7 +16,6 @@
#include "backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h"
#include <algorithm>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -47,7 +46,7 @@ void SoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, src_desc);

View File

@ -18,7 +18,6 @@
#include <numeric>
#include <limits>
#include <functional>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -56,7 +55,7 @@ void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_n
dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1);
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, mem_desc);

View File

@ -19,7 +19,6 @@
#include <limits>
#include <functional>
#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@ -62,7 +61,7 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &ke
dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1);
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_);
primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, mem_desc);

View File

@ -26,6 +26,7 @@ constexpr size_t kUniformRealInputsNum = 1;
constexpr size_t kUniformIntOutputsNum = 1;
constexpr size_t kUniformRealOutputsNum = 1;
constexpr size_t kStandardNormalOutputsNum = 1;
constexpr float kRandomBlockSize = 128.0;
constexpr char kKernelName[] = "Random";
} // namespace
void StandardNormal(float *output, std::normal_distribution<float> distribution,
@ -39,12 +40,12 @@ void LaunchStandardNormal(RandomCPUKernel *content, unsigned int seed, const std
auto output = reinterpret_cast<float *>(outputs[0]->addr);
// multithreading
size_t lens = outputs[0]->size / sizeof(float);
-auto task = [&seed, &output](size_t start, size_t end) {
+std::default_random_engine random_generator(++seed);
+auto task = [&seed, &output, &random_generator](size_t start, size_t end) {
std::normal_distribution<float> distribution;
-std::default_random_engine random_generator(++seed);
StandardNormal(output, distribution, random_generator, start, end);
};
-ParallelLaunchAutoSearch(task, lens, content, &content->parallel_search_info_);
+ParallelLaunch(task, lens, kRandomBlockSize, content);
}
void LaunchUniformInt(unsigned int seed, const std::vector<AddressPtr> &inputs,

View File

@ -1135,6 +1135,18 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
SetInternalOutputAttr(new_node);
}
void KernelGraph::EnableRuntimeCache() {
auto node_list = TopoSort(get_return());
for (auto &node : node_list) {
auto kernel_info = node->kernel_info();
if (!kernel_info) {
continue;
}
auto runtime_cache = kernel_info->runtime_cache();
runtime_cache.runtime_cache().set_valid();
}
}
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx,
size_t dst_output_idx) {
if (new_node == nullptr || node == nullptr) {

View File

@ -286,6 +286,7 @@ class KernelGraph : public FuncGraph {
}
}
void RemoveNodeFromGraph(const AnfNodePtr &node);
void EnableRuntimeCache();
void UpdateGraphDynamicAttr();
void SetGraphDynamicAttr(bool is_dynamic_shape) { is_dynamic_shape_ = is_dynamic_shape; }
bool is_dynamic_shape() const { return is_dynamic_shape_; }

View File

@ -87,6 +87,7 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("mempool_block_size", MsCtxParam::MS_CTX_MEMPOOL_BLOCK_SIZE)
.value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
.value("runtime_num_threads", MsCtxParam::MS_CTX_RUNTIME_NUM_THREADS)
.value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)

View File

@ -25,23 +25,27 @@ bool ActorDispatcher::is_multi_thread_execution_ = true;
void ComputeThreadNums(size_t *actor_thread_num, size_t *actor_and_kernel_thread_num) {
MS_EXCEPTION_IF_NULL(actor_thread_num);
MS_EXCEPTION_IF_NULL(actor_and_kernel_thread_num);
const size_t cpu_core_num = std::thread::hardware_concurrency() - 1;
// Compute the actor thread num.
const size_t kActorThreadMaxNum = 5;
// The MemoryManagerActor binds single thread, and the other actors share one thread at least, so the min num is 2.
const size_t kActorThreadMinNum = 2;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
*actor_thread_num = cpu_core_num < kActorThreadMinNum ? kActorThreadMinNum : cpu_core_num;
*actor_thread_num = *actor_thread_num > kActorThreadMaxNum ? kActorThreadMaxNum : *actor_thread_num;
const size_t cpu_core_num = std::thread::hardware_concurrency() - 1;
auto runtime_num_threads = static_cast<size_t>(context_ptr->get_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS));
size_t runtime_num_threads_min = std::min(runtime_num_threads, cpu_core_num);
const float kActorUsage = 0.2;
const size_t kActorThreadMinNum = 2;
size_t actor_thread_max_num =
std::max(static_cast<size_t>(std::floor(runtime_num_threads_min * kActorUsage)), kActorThreadMinNum);
// Compute the actor thread num.
// The MemoryManagerActor binds single thread, and the other actors share one thread at least, so the min num is 2.
*actor_thread_num = runtime_num_threads_min < kActorThreadMinNum ? kActorThreadMinNum : runtime_num_threads_min;
*actor_thread_num = *actor_thread_num > actor_thread_max_num ? actor_thread_max_num : *actor_thread_num;
// Compute the actor and kernel thread num.
const size_t kKernelThreadMaxNum = 23;
const size_t kActorAndKernelThreadMaxNum = kKernelThreadMaxNum + *actor_thread_num;
*actor_and_kernel_thread_num = cpu_core_num > *actor_thread_num ? cpu_core_num : (*actor_thread_num + 1);
*actor_and_kernel_thread_num = *actor_and_kernel_thread_num > kActorAndKernelThreadMaxNum
? kActorAndKernelThreadMaxNum
: *actor_and_kernel_thread_num;
*actor_and_kernel_thread_num =
runtime_num_threads_min > *actor_thread_num ? runtime_num_threads_min : (*actor_thread_num + 1);
if (runtime_num_threads != *actor_and_kernel_thread_num) {
MS_LOG(WARNING) << "The runtime_num_threads is " << runtime_num_threads
<< ", but actually the num of threads in threadpool is " << *actor_and_kernel_thread_num;
}
}
bool IsDeviceQueueDSActor(const AnfNodePtr &node, GraphExecutionStrategy strategy) {

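Where the old code sized both pools from the raw core count, the new code treats runtime_num_threads as the budget: actors get at most 20% of it (never fewer than 2 threads) and the kernels take the rest. A standalone sketch mirroring the hunk above, with a worked example: for runtime_num_threads = 9 on 15 usable cores, actors get 2 threads and kernels get 7.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    void ComputeThreadNumsSketch(size_t runtime_num_threads, size_t cpu_core_num,
                                 size_t *actor_num, size_t *total_num) {
      size_t budget = std::min(runtime_num_threads, cpu_core_num);
      const float kActorUsage = 0.2;
      const size_t kActorThreadMinNum = 2;
      size_t actor_max = std::max(
          static_cast<size_t>(std::floor(budget * kActorUsage)), kActorThreadMinNum);
      // Clamp the actor share into [kActorThreadMinNum, actor_max].
      *actor_num = std::min(std::max(budget, kActorThreadMinNum), actor_max);
      // Everything else in the budget goes to the kernel threads.
      *total_num = budget > *actor_num ? budget : (*actor_num + 1);
    }

    int main() {
      size_t actors = 0, total = 0;
      ComputeThreadNumsSketch(/*runtime_num_threads=*/9, /*cpu_core_num=*/15, &actors, &total);
      std::printf("actor threads: %zu, kernel threads: %zu\n", actors, total - actors);
      return 0;
    }
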
View File

@ -492,7 +492,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
}
#endif
device_context->EnableRuntimeCache(graph);
graph->EnableRuntimeCache();
return graph->graph_id();
}

View File

@ -265,10 +265,8 @@ void GraphScheduler::Initialize() {
MS_LOG(EXCEPTION) << "Actor manager init failed.";
}
common::SetOMPThreadNum();
-auto OMP_thread_num_used = common::GetEnv("OMP_NUM_THREADS");
MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
-<< ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num)
-<< ", the used OMP thread number: " << OMP_thread_num_used;
+<< ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num);
BuildAndScheduleGlobalActor();
}

View File

@ -40,20 +40,12 @@
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
#ifdef PLATFORM_86
#include <pmmintrin.h>
#endif
namespace mindspore {
namespace device {
namespace cpu {
using mindspore::kernel::KernelBuildInfo;
#ifdef PLATFORM_86
// Whether need set the flush zero mode in the kernel launch.
static bool flush_zero_mode_enable{false};
#endif
void CPUDeviceContext::Initialize() {
if (initialized_) {
return;
@ -229,14 +221,6 @@ void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
MS_LOG(EXCEPTION) << "Build cpu operator[" << node->fullname_with_scope() << "] failed";
}
#ifdef PLATFORM_86
// Some CPU kernels need set the flush zero mode to improve performance.
if (!flush_zero_mode_enable &&
(kOpNeedSetFlushZeroModeList.find(kernel_name) != kOpNeedSetFlushZeroModeList.end())) {
flush_zero_mode_enable = true;
}
#endif
cpu_kernel->Init(node);
cpu_kernel->InitDynamicKernel(node);
auto cpu_dynamic_kernel = cpu_kernel->DynamicKernel();
@ -285,14 +269,6 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
#ifdef PLATFORM_86
// Some CPU kernels need set the flush zero mode to improve performance.
if (flush_zero_mode_enable) {
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
}
#endif
// Some CPU kernels can't initialize kernel and launch kernel in different thread, so reinitialize the kernels before
// launch.
if (kOpNotSupportMultiThreadExecList.find(AnfAlgo::GetCNodeName(kernel)) != kOpNotSupportMultiThreadExecList.end()) {

View File

@ -157,19 +157,6 @@ class DeviceContext {
// Return collective communication object for caller to access
CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }
void EnableRuntimeCache(const KernelGraphPtr &graph) const {
auto node_list = graph->TopoSort(graph->get_return());
for (auto &node : node_list) {
auto kernel_info = node->kernel_info();
if (!kernel_info) {
continue;
}
MS_EXCEPTION_IF_NULL(kernel_info);
auto runtime_cache = kernel_info->runtime_cache();
runtime_cache.runtime_cache().set_valid();
}
}
protected:
DeviceContextKey device_context_key_;

View File

@ -716,8 +716,6 @@ const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};
const std::set<std::string> kOpCacheBlackList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName,
kGetNextOpName};
const std::set<std::string> kOpNeedSetFlushZeroModeList = {kLSTMOpName, kLSTMGradOpName};
const std::set<std::string> kOpNotSupportMultiThreadExecList = {kAvgPoolOpName, kAvgPoolGradOpName, kMaxPoolOpName,
kBatchNorm, kBatchNormGradOpName};

View File

@ -33,6 +33,11 @@ void ActorWorker::RunWithSpin() {
#if !defined(__APPLE__) && !defined(SUPPORT_MSVC)
static std::atomic_int index = {0};
(void)pthread_setname_np(pthread_self(), ("ActorThread_" + std::to_string(index++)).c_str());
#endif
#ifdef PLATFORM_86
// Some CPU kernels need to set the flush zero mode to improve performance.
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
#endif
while (alive_) {
// only run either local KernelTask or PoolQueue ActorTask

View File

@ -76,6 +76,11 @@ void Worker::Run() {
#if !defined(__APPLE__) && !defined(SUPPORT_MSVC)
static std::atomic_int index = {0};
(void)pthread_setname_np(pthread_self(), ("KernelThread_" + std::to_string(index++)).c_str());
#endif
#ifdef PLATFORM_86
// Some CPU kernels need to set the flush zero mode to improve performance.
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
#endif
while (alive_) {
if (RunLocalKernelTask()) {

View File

@ -27,7 +27,10 @@
#include <functional>
#include "thread/threadlog.h"
#include "thread/core_affinity.h"
#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64)
#define PLATFORM_86
#include <pmmintrin.h>
#endif
namespace mindspore {
constexpr int kDefaultSpinCount = 300000;
constexpr int kMaxCount = 30000;

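The flush-zero setup moves out of individual kernels (the old PLATFORM_86 blocks deleted above) and into the pool's worker threads, because the MXCSR flags are per-thread state and every worker that may run a kernel needs them set. A small x86-only demo of the effect; the constants are chosen so the product is a denormal:

    #include <cstdio>
    #include <pmmintrin.h>  // _MM_SET_DENORMALS_ZERO_MODE; pulls in xmmintrin.h

    int main() {
      volatile float tiny = 1e-37f;
      volatile float scale = 1e-4f;
      std::printf("before: %g\n", tiny * scale);  // denormal result, ~1e-41

      // Per-thread MXCSR flags: each pool worker must set these itself.
      _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
      _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);

      std::printf("after: %g\n", tiny * scale);  // flushed to exactly 0
      return 0;
    }
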
View File

@ -18,6 +18,7 @@
#include <thread>
#include <atomic>
#include <fstream>
#include <algorithm>
#include "ir/tensor.h"
#include "utils/ms_utils.h"
@ -97,6 +98,11 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
size_t cpu_core_num = std::thread::hardware_concurrency() - 1;
constexpr float kCpuUsage = 0.6;
uint32_t runtime_num_threads = std::max(static_cast<int>(std::floor(cpu_core_num * kCpuUsage)), 1);
set_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS, runtime_num_threads);
backend_policy_ = policy_map_[policy];
}

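For example, on a machine where std::thread::hardware_concurrency() returns 16, cpu_core_num is 15 and the default runtime_num_threads comes out to floor(15 * 0.6) = 9.
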
View File

@ -102,6 +102,7 @@ enum MsCtxParam : unsigned {
// parameter of type uint32
MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END,
MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN,
MS_CTX_RUNTIME_NUM_THREADS,
MS_CTX_GE_REF,
MS_CTX_MAX_CALL_DEPTH,
MS_CTX_TSD_REF,

View File

@ -318,6 +318,12 @@ class _Context:
"have permission to read it.".format(env_config_path))
self.set_param(ms_ctx_param.env_config_path, env_config_path)
def set_runtime_num_threads(self, runtime_num_threads):
"""Check and set runtime_num_threads."""
if runtime_num_threads <= 0:
raise ValueError("The num of cpu thread must bigger than 0.")
self.set_param(ms_ctx_param.runtime_num_threads, runtime_num_threads)
setters = {
'mode': set_mode,
'save_graphs_path': set_save_graphs_path,
@ -331,7 +337,8 @@ class _Context:
'mempool_block_size': set_mempool_block_size,
'print_file_path': set_print_file_path,
'env_config_path': set_env_config_path,
'backend_policy': set_backend_policy
'backend_policy': set_backend_policy,
'runtime_num_threads': set_runtime_num_threads
}
@property
@ -607,7 +614,7 @@ def _check_target_specific_cfgs(device, arg_key):
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int,
env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool,
env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int,
load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str,
backend_policy=str)
def set_context(**kwargs):
@ -679,6 +686,8 @@ def set_context(**kwargs):
| +------------------------------+----------------------------+
| | enable_compile_cache | CPU/GPU/Ascend |
| +------------------------------+----------------------------+
| | runtime_num_threads | CPU/GPU/Ascend |
| +------------------------------+----------------------------+
| | compile_cache_path | CPU/GPU/Ascend |
| +------------------------------+----------------------------+
| | backend_policy | Ascend |
@ -822,6 +831,8 @@ def set_context(**kwargs):
backend_policy (str): Used to choose a backend. ("ge", "vm" or "ms").
Through context.set_context(backend_policy="ms")
Default: The value must be in ['ge', 'vm', 'ms'].
runtime_num_threads (int): The number of threads in the runtime thread pool, shared by
CPU kernels and actors; it must be greater than 0.
Raises:
ValueError: If input key is not an attribute in context.
@ -851,6 +862,7 @@ def set_context(**kwargs):
>>> context.set_context(grad_for_scalar=True)
>>> context.set_context(enable_compile_cache=True, compile_cache_path="./cache.ms")
>>> context.set_context(pynative_synchronize=True)
>>> context.set_context(runtime_num_threads=10)
"""
ctx = _context()
# set device target first

View File

@ -26,6 +26,7 @@ from collections import Counter
import numpy as np
from mindspore import log as logger
from mindspore import context
from mindspore.common.initializer import Zero
from .. import signature as sig
from .._utils import get_broadcast_shape, is_shape_unknown
@ -3304,6 +3305,11 @@ class StridedSlice(PrimitiveWithInfer):
shrink_axis_mask=0):
"""Initialize StridedSlice"""
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
# Auto parallel does not support begin_mask and end_mask yet.
if context.get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]:
begin_mask = 0
end_mask = 0
validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
validator.check_non_negative_int(end_mask, 'end_mask', self.name)
validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name)