From ac11ceeee4c961e0a115b4e2b18e6d8ddf105401 Mon Sep 17 00:00:00 2001 From: fangzehua Date: Tue, 28 Dec 2021 11:46:48 +0800 Subject: [PATCH] change omp to ms threadpool --- cmake/external_libs/mkl_dnn.cmake | 7 ++- cmake/options.cmake | 9 +++ mindspore/ccsrc/CMakeLists.txt | 8 +++ .../kernel_compiler/cpu/adam_cpu_kernel.cc | 5 +- .../cpu/apply_momentum_cpu_kernel.cc | 1 - .../backend/kernel_compiler/cpu/cpu_kernel.h | 3 - .../cpu/mkldnn/addn_cpu_kernel.cc | 3 +- .../cpu/mkldnn/assignadd_cpu_kernel.cc | 3 +- .../cpu/mkldnn/batch_norm_cpu_kernel.cc | 3 +- .../cpu/mkldnn/batch_norm_grad_cpu_kernel.cc | 7 +-- .../mkldnn/conv2d_grad_filter_cpu_kernel.cc | 7 +-- .../mkldnn/conv2d_grad_input_cpu_kernel.cc | 6 +- .../cpu/mkldnn/conv_cpu_kernel.cc | 3 +- .../cpu/mkldnn/eltwise_cpu_kernel.cc | 3 +- .../cpu/mkldnn/log_softmax_cpu_kernel.cc | 3 +- .../cpu/mkldnn/log_softmax_grad_cpu_kernel.cc | 6 +- .../cpu/mkldnn/lstm_cpu_kernel.cc | 13 +---- .../cpu/mkldnn/lstm_grad_cpu_kernel.cc | 5 +- .../cpu/mkldnn/matmul_cpu_kernel.cc | 3 +- .../cpu/mkldnn/mkl_cpu_kernel.cc | 49 ++++++++++++++-- .../cpu/mkldnn/mkl_cpu_kernel.h | 44 +++++++++++++- .../cpu/mkldnn/mkl_kernel_engine.cc | 42 -------------- .../cpu/mkldnn/mkl_kernel_engine.h | 58 ------------------- .../cpu/mkldnn/pooling_avg_grad_cpu_kernel.cc | 6 +- .../cpu/mkldnn/pooling_cpu_kernel.cc | 3 +- .../cpu/mkldnn/pooling_max_grad_cpu_kernel.cc | 1 - .../cpu/mkldnn/softmax_cpu_kernel.cc | 3 +- ...ax_cross_entropy_with_logits_cpu_kernel.cc | 3 +- ...ax_cross_entropy_with_logits_cpu_kernel.cc | 3 +- .../kernel_compiler/cpu/random_cpu_kernel.cc | 7 ++- .../ccsrc/backend/session/kernel_graph.cc | 12 ++++ .../ccsrc/backend/session/kernel_graph.h | 1 + .../ccsrc/pybind_api/utils/ms_context_py.cc | 1 + .../runtime/framework/actor/actor_common.cc | 30 +++++----- .../ccsrc/runtime/framework/graph_compiler.cc | 2 +- .../runtime/framework/graph_scheduler.cc | 4 +- .../hardware/cpu/cpu_device_context.cc | 24 -------- 
.../ccsrc/runtime/hardware/device_context.h | 13 ----- mindspore/ccsrc/utils/utils.h | 2 - .../mindrt/src/thread/actor_threadpool.cc | 5 ++ .../core/mindrt/src/thread/threadpool.cc | 5 ++ mindspore/core/mindrt/src/thread/threadpool.h | 5 +- mindspore/core/utils/ms_context.cc | 6 ++ mindspore/core/utils/ms_context.h | 1 + mindspore/python/mindspore/context.py | 16 ++++- .../mindspore/ops/operations/array_ops.py | 6 ++ 46 files changed, 219 insertions(+), 231 deletions(-) delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h diff --git a/cmake/external_libs/mkl_dnn.cmake b/cmake/external_libs/mkl_dnn.cmake index be69f0296c0..785462d0f79 100644 --- a/cmake/external_libs/mkl_dnn.cmake +++ b/cmake/external_libs/mkl_dnn.cmake @@ -1,5 +1,10 @@ set(onednn_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") set(onednn_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") +if(USE_MS_THREADPOOL_FOR_DNNL) + set(USE_MS_THREADPOOL "-DDNNL_CPU_RUNTIME=THREADPOOL") +else() + set(USE_MS_THREADPOOL "") +endif() if(CMAKE_SYSTEM_NAME MATCHES "Windows") mindspore_add_pkg(onednn VER 2.2 @@ -22,7 +27,7 @@ else() URL ${REQ_URL} MD5 ${MD5} CMAKE_OPTION -DDNNL_ARCH_OPT_FLAGS='' -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF - -DDNNL_ENABLE_CONCURRENT_EXEC=ON) + ${USE_MS_THREADPOOL} -DDNNL_ENABLE_CONCURRENT_EXEC=ON) endif() include_directories(${onednn_INC}) diff --git a/cmake/options.cmake b/cmake/options.cmake index c3cc35928f9..01126e90b9d 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -28,6 +28,11 @@ option(ENABLE_SYM_FILE "enable sym file" OFF) option(BUILD_DEV_MODE "MindSpore build nightly dev mode" OFF) option(ENABLE_FAST_HASH_TABLE "Enable use fast hash table instead of std ones" ON) option(USE_LLVM "use llvm" OFF) +option(USE_MS_THREADPOOL_FOR_DNNL "use ms threadpool for onednn ops" ON) + +if(CMAKE_SYSTEM_NAME MATCHES "Windows") + set(USE_MS_THREADPOOL_FOR_DNNL OFF) +endif() 
if(NOT ENABLE_D AND NOT ENABLE_TESTCASES AND NOT ENABLE_ACL) set(ENABLE_GLIBCXX ON) @@ -161,3 +166,7 @@ endif() if(USE_LLVM) add_compile_definitions(USE_LLVM) endif() + +if(USE_MS_THREADPOOL_FOR_DNNL) + add_compile_definitions(USE_MS_THREADPOOL_FOR_DNNL) +endif() diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 0a1b3cf5897..05b8cccf8ba 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -18,6 +18,14 @@ if(ENABLE_D) add_subdirectory(backend/kernel_compiler/aicpu/aicpu_ops) endif() +# add openmp +find_package(OpenMP) +if(OPENMP_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +else() + message(FATAL_ERROR "OpenMP not found") +endif() + if(ENABLE_CPU) if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64") set(PLATFORM_ARM64 "on") diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc index 38bd285d061..4fdf2e94dc2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc @@ -26,6 +26,7 @@ namespace { constexpr size_t kAdamInputsNum = 10; constexpr size_t kAdamOutputsNum = 3; constexpr size_t kScalarIndex = 0; +constexpr float kAdamBlock = 1000; } // namespace template @@ -60,7 +61,7 @@ void AdamCPUKernel::LaunchAdam(const std::vector &inputs, co } } }; - ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_); + ParallelLaunch(task, lens, kAdamBlock, this); } void AdamCPUKernel::LaunchAdamNnacl(const std::vector &inputs, @@ -89,7 +90,7 @@ void AdamCPUKernel::LaunchAdamNnacl(const std::vector &input MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', AdamFp32 failed. 
Error no: " << ret; } }; - ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_); + ParallelLaunch(task, lens, kAdamBlock, this); } void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc index a0ccd89eb94..c69be189234 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc @@ -15,7 +15,6 @@ */ #include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index a096e14072c..2f58dd153c1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -31,9 +31,6 @@ #include "runtime/framework/graph_scheduler.h" #include "actor/actormgr.h" #include "common/thread_pool.h" -#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64) -#define PLATFORM_86 -#endif using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc index 7c035d6e12f..6476dda4a05 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc @@ -15,7 +15,6 @@ */ #include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include 
"backend/kernel_compiler/cpu/nnacl/fp32/add_fp32.h" #include "backend/kernel_compiler/cpu/nnacl/errorcode.h" @@ -65,7 +64,7 @@ void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape); dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape); dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc); - auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::binary::primitive_desc(desc, engine_); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC_0, src0_mem_desc); AddArgument(DNNL_ARG_SRC_1, src1_mem_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc index d20cc201c04..57050aba9b1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.cc @@ -15,7 +15,6 @@ */ #include "backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -47,7 +46,7 @@ void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc src0_desc = GetDefaultMemDesc(src0_shape); dnnl::memory::desc src1_desc = GetDefaultMemDesc(src1_shape); dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_desc, src1_desc, src0_desc); - auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::binary::primitive_desc(desc, engine_); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC_0, src0_desc); AddArgument(DNNL_ARG_SRC_1, src1_desc); diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.cc index d8ff1276fba..374ef049630 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.cc @@ -15,7 +15,6 @@ */ #include "backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -62,7 +61,7 @@ void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) { } dnnl::batch_normalization_forward::desc desc = dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags); - auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, engine_); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC, x_desc); AddArgument(DNNL_ARG_MEAN, prim_desc.mean_desc()); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.cc index 17cc3778821..0553aba9223 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.cc @@ -15,7 +15,6 @@ */ #include "backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -67,13 +66,13 @@ void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { // fused Batch Normalization forward description dnnl::batch_normalization_forward::desc desc = 
dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags); - auto forward_prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto forward_prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, engine_); // fused Batch Normalization backward description dnnl::batch_normalization_backward::desc backward_desc = dnnl::batch_normalization_backward::desc(dnnl::prop_kind::backward, x_desc, x_desc, epsilon, normalization_flags); - auto backward_prim_desc = dnnl::batch_normalization_backward::primitive_desc( - backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); + auto backward_prim_desc = + dnnl::batch_normalization_backward::primitive_desc(backward_desc, engine_, forward_prim_desc); primitive_ = std::make_shared(backward_prim_desc); AddArgument(DNNL_ARG_SRC, x_desc); AddArgument(DNNL_ARG_MEAN, forward_prim_desc.mean_desc()); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc index f44f6297445..a1273a61bbe 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc @@ -18,7 +18,6 @@ #include #include #include "utils/ms_utils.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { @@ -86,13 +85,13 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); + auto forward_prim_desc = 
dnnl::convolution_forward::primitive_desc(forward_desc, engine_); dnnl::convolution_backward_weights::desc backward_desc = dnnl::convolution_backward_weights::desc( dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - auto backward_prim_desc = dnnl::convolution_backward_weights::primitive_desc( - backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); + auto backward_prim_desc = + dnnl::convolution_backward_weights::primitive_desc(backward_desc, engine_, forward_prim_desc); primitive_ = std::make_shared(backward_prim_desc); AddArgument(DNNL_ARG_SRC, src_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc index fa7a33147d1..167e63677dd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc @@ -18,7 +18,6 @@ #include #include #include -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -99,13 +98,12 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); + auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, engine_); dnnl::convolution_backward_data::desc backward_desc = dnnl::convolution_backward_data::desc( dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - auto backward_prim_desc = - 
dnnl::convolution_backward_data::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); + auto backward_prim_desc = dnnl::convolution_backward_data::primitive_desc(backward_desc, engine_, forward_prim_desc); primitive_ = std::make_shared(backward_prim_desc); AddArgument(DNNL_ARG_DIFF_SRC, src_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.cc index b13dcb36292..4cb696f4a72 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.cc @@ -18,7 +18,6 @@ #include #include #include "utils/ms_utils.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { @@ -107,7 +106,7 @@ void ConvCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, engine_); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC, src_desc); AddArgument(DNNL_ARG_WEIGHTS, weights_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.cc index fab0e74fb4c..9b385c821dd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.cc @@ -17,7 +17,6 @@ #include "backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h" #include #include -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include 
"runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -65,7 +64,7 @@ void EltWiseCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); auto desc = GetForwardEltwiseDesc(src_desc); - auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, engine_); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC, src_desc); AddArgument(DNNL_ARG_DST, src_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/log_softmax_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/log_softmax_cpu_kernel.cc index e520d5c70f9..83944d1b293 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/log_softmax_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/log_softmax_cpu_kernel.cc @@ -16,7 +16,6 @@ #include "backend/kernel_compiler/cpu/mkldnn/log_softmax_cpu_kernel.h" #include -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -41,7 +40,7 @@ void LogSoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); dnnl::logsoftmax_forward::desc desc = dnnl::logsoftmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis); - auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, engine_); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC, src_desc); AddArgument(DNNL_ARG_DST, src_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/log_softmax_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/log_softmax_grad_cpu_kernel.cc index bbbd684810c..f0da4c216c0 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/log_softmax_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/log_softmax_grad_cpu_kernel.cc @@ -16,7 +16,6 @@ #include "backend/kernel_compiler/cpu/mkldnn/log_softmax_grad_cpu_kernel.h" #include -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -41,11 +40,10 @@ void LogSoftmaxGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); dnnl::logsoftmax_forward::desc desc = dnnl::logsoftmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis); - auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::logsoftmax_forward::primitive_desc(desc, engine_); // backward description dnnl::logsoftmax_backward::desc backward_desc = dnnl::logsoftmax_backward::desc(src_desc, src_desc, axis); - auto backward_prim_desc = - dnnl::logsoftmax_backward::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), prim_desc); + auto backward_prim_desc = dnnl::logsoftmax_backward::primitive_desc(backward_desc, engine_, prim_desc); primitive_ = std::make_shared(backward_prim_desc); AddArgument(DNNL_ARG_DST, src_desc); AddArgument(DNNL_ARG_DIFF_SRC, src_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc index b24e10274da..9d352f2e140 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc @@ -17,12 +17,7 @@ #include "backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h" #include #include "utils/ms_utils.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" -#ifdef PLATFORM_86 -#include -#endif - 
namespace mindspore { namespace kernel { namespace { @@ -54,14 +49,10 @@ void LstmCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { } void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { -#ifdef PLATFORM_86 - _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); - _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); -#endif MS_EXCEPTION_IF_NULL(kernel_node); kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); CheckParam(kernel_node); - auto eng = MKLKernelEngine::Get().engine(); + auto eng = engine_; dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; if (bidirectional_) { direction = dnnl::rnn_direction::bidirectional_concat; @@ -153,7 +144,7 @@ bool LstmCPUKernel::Launch(const std::vector &inputs, const const std::vector &outputs) { using dt = dnnl::memory::data_type; using tag = dnnl::memory::format_tag; - auto eng = MKLKernelEngine::Get().engine(); + auto eng = engine_; auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc index 61fde843987..2b5b0eec2ef 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc @@ -18,7 +18,6 @@ #include #include #include "utils/ms_utils.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { @@ -43,7 +42,7 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); CheckParam(kernel_node); - auto eng = 
MKLKernelEngine::Get().engine(); + auto eng = engine_; dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; if (bidirectional_) { direction = dnnl::rnn_direction::bidirectional_concat; @@ -178,7 +177,7 @@ bool LSTMGradCPUKernel::Launch(const std::vector &inputs, co const std::vector &outputs) { CHECK_KERNEL_INPUTS_NUM(inputs.size(), kLstmGradInputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kLstmGradOutputsNum, kernel_name_); - auto eng = MKLKernelEngine::Get().engine(); + auto eng = engine_; // construct fw memory auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc index 634ba4a3e3c..b6600eaaf10 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc @@ -17,7 +17,6 @@ #include "backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h" #include #include "common/thread_pool.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "backend/kernel_compiler/cpu/nnacl/op_base.h" #include "backend/kernel_compiler/cpu/nnacl/matmul_parameter.h" #include "backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h" @@ -81,7 +80,7 @@ void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc weights_md(weights_dims, dnnl::memory::data_type::f32, b_strides); dnnl::memory::desc dst_md(dst_dims, dnnl::memory::data_type::f32, o_strides); dnnl::matmul::desc matmul_desc(src_md, weights_md, dst_md); - dnnl::matmul::primitive_desc prim_desc(matmul_desc, MKLKernelEngine::Get().engine()); + dnnl::matmul::primitive_desc prim_desc(matmul_desc, engine_); primitive_ = std::make_shared(prim_desc); 
AddArgument(DNNL_ARG_SRC, src_md); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc index 88be1110c49..0705dd4d65c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc @@ -19,7 +19,7 @@ #include #include #include "utils/ms_utils.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "utils/profile.h" namespace mindspore { namespace kernel { @@ -139,7 +139,11 @@ dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector &sh } void MKLCPUKernel::AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc) { - arguments_[arg_key] = MKLKernelEngine::Get().CreateMemory(mem_desc, alloc); + if (alloc) { + arguments_[arg_key] = dnnl::memory(mem_desc, engine_); + } else { + arguments_[arg_key] = dnnl::memory(mem_desc, engine_, nullptr); + } } void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) { @@ -149,10 +153,47 @@ void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) { } } -void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); } +void MKLCPUKernel::ExecutePrimitive() { + MS_EXCEPTION_IF_NULL(primitive_); +#ifdef USE_MS_THREADPOOL_FOR_DNNL + // add auto search + const size_t MAX_POW = 6; + const size_t AVG_COUNT = 5; + const size_t DIFF = 2; + size_t current_pow = parallel_search_info_.search_count / AVG_COUNT; + int current_thread_nums = static_cast(std::pow(2.0f, current_pow)); + auto mkl_pool = dynamic_cast(mkl_threadpool_.get()); + if (current_pow < MAX_POW) { + if (parallel_search_info_.search_count % AVG_COUNT == 0) { + parallel_search_info_.tmp_sum_cost_time = 0; + } + double start_time = GetTime(); + mkl_pool->set_num_threads(current_thread_nums); + primitive_->execute(stream_, arguments_); + double cost_time = GetTime() - start_time; + 
parallel_search_info_.tmp_sum_cost_time += cost_time; + parallel_search_info_.search_count++; + if (parallel_search_info_.search_count % AVG_COUNT == 0) { + if (parallel_search_info_.min_cost_time > parallel_search_info_.tmp_sum_cost_time) { + parallel_search_info_.min_cost_time = parallel_search_info_.tmp_sum_cost_time; + parallel_search_info_.best_pow = current_pow; + } else if (current_pow - parallel_search_info_.best_pow >= DIFF) { + parallel_search_info_.search_count = AVG_COUNT * MAX_POW; + } + } + } else { + int best_thread_nums = static_cast(std::pow(2.0f, parallel_search_info_.best_pow)); + mkl_pool->set_num_threads(best_thread_nums); + primitive_->execute(stream_, arguments_); + } +#else + primitive_->execute(stream_, arguments_); +#endif + (void)stream_.wait(); +} void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { - MKLKernelEngine::Get().Reorder(src_mem, dst_mem); + dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h index 8990e33a5c8..a7ab16492a7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h @@ -19,17 +19,54 @@ #include #include +#include #include #include #include "dnnl.hpp" #include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#ifdef USE_MS_THREADPOOL_FOR_DNNL +#include "dnnl_threadpool.hpp" +#include "dnnl_threadpool_iface.hpp" +#endif namespace mindspore { namespace kernel { +#ifdef USE_MS_THREADPOOL_FOR_DNNL +class mkl_threadpool : public dnnl::threadpool_interop::threadpool_iface { + private: + ActorThreadPool *tp_; + int thread_num_{8}; + + public: + explicit mkl_threadpool(ActorThreadPool *tp) { tp_ = tp; } + void set_num_threads(int num) { 
thread_num_ = num; } + int get_num_threads() const override { return std::min(SizeToInt(tp_->GetKernelThreadNum()), thread_num_); } + bool get_in_parallel() const override { return false; } + uint64_t get_flags() const override { return 0; } + void parallel_for(int n, const std::function &fn) override { + int nthr = get_num_threads(); + int n_jobs = std::min(n, nthr); + auto func = [&, n_jobs](void *, int i, float, float) { + fn(i, n_jobs); + return 0; + }; + tp_->ParallelLaunch(func, nullptr, n_jobs); + } +}; +#endif + class MKLCPUKernel : public CPUKernel { public: - MKLCPUKernel() = default; +#ifdef USE_MS_THREADPOOL_FOR_DNNL + MKLCPUKernel() : engine_(dnnl::engine::kind::cpu, 0) { + auto thread_pool = GetActorMgrInnerThreadPool(); + mkl_threadpool_ = std::make_shared(thread_pool); + stream_ = dnnl::threadpool_interop::make_stream(engine_, mkl_threadpool_.get()); + } +#else + MKLCPUKernel() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {} +#endif ~MKLCPUKernel() override = default; protected: @@ -50,6 +87,11 @@ class MKLCPUKernel : public CPUKernel { std::unordered_map arguments_; std::shared_ptr primitive_{nullptr}; + dnnl::engine engine_; + dnnl::stream stream_; +#ifdef USE_MS_THREADPOOL_FOR_DNNL + std::shared_ptr mkl_threadpool_{nullptr}; +#endif }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc deleted file mode 100644 index ee42d3a995e..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" -#include "utils/log_adapter.h" -#include "dnnl.hpp" - -namespace mindspore { -namespace kernel { -void MKLKernelEngine::Execute(const std::shared_ptr &primitive, - const std::unordered_map &arguments) { - MS_EXCEPTION_IF_NULL(primitive); - primitive->execute(stream_, arguments); - (void)stream_.wait(); -} - -dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, bool alloc) { - if (alloc) { - return dnnl::memory(mem_desc, engine_); - } else { - return dnnl::memory(mem_desc, engine_, nullptr); - } -} - -void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { - dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h deleted file mode 100644 index 6d33172af51..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MKL_KERNEL_ENGINE_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MKL_KERNEL_ENGINE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "dnnl.hpp" -#include "utils/ms_utils.h" - -namespace mindspore { -namespace kernel { -class MKLKernelEngine { - public: - static MKLKernelEngine &Get() { - static MKLKernelEngine instance; - return instance; - } - DISABLE_COPY_AND_ASSIGN(MKLKernelEngine) - - const dnnl::engine &engine() const { return engine_; } - - dnnl::memory CreateMemory(const dnnl::memory::desc &mem_desc, bool alloc = false); - - void Execute(const std::shared_ptr &primitive, - const std::unordered_map &arguments); - void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); - - private: - MKLKernelEngine() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {} - ~MKLKernelEngine() = default; - - dnnl::engine engine_; - dnnl::stream stream_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MKL_KERNEL_ENGINE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_avg_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_avg_grad_cpu_kernel.cc index b32c555516c..07a772de8a8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_avg_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_avg_grad_cpu_kernel.cc @@ -19,7 +19,6 @@ #include #include #include "utils/ms_utils.h" -#include 
"backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { @@ -69,13 +68,12 @@ void AvgPoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::pooling_forward::desc desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc, dst_desc, strides_dims, kernels_dims, padding_l, padding_r); - auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, engine_); // pooling_avg backward description dnnl::pooling_backward::desc backward_desc = dnnl::pooling_backward::desc( dnnl::algorithm::pooling_avg, src_desc, dst_desc, strides_dims, kernels_dims, padding_l, padding_r); - auto backward_prim_desc = - dnnl::pooling_backward::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), prim_desc); + auto backward_prim_desc = dnnl::pooling_backward::primitive_desc(backward_desc, engine_, prim_desc); primitive_ = std::make_shared(backward_prim_desc); AddArgument(DNNL_ARG_SRC, src_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc index 4fa0d713653..aa17ff01283 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc @@ -17,7 +17,6 @@ #include #include #include "utils/ms_utils.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { @@ -84,7 +83,7 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc, dst_desc, strides_dims, kernels_dims, padding_l, padding_r); } - auto prim_desc = 
dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, engine_); workspace_size_ = prim_desc.workspace_desc().get_size(); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC, src_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_max_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_max_grad_cpu_kernel.cc index 820873c032c..c9989ca237b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_max_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_max_grad_cpu_kernel.cc @@ -19,7 +19,6 @@ #include #include #include "utils/ms_utils.h" -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc index 7e599033ad5..f93660478c3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc @@ -16,7 +16,6 @@ #include "backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h" #include -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -47,7 +46,7 @@ void SoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { } dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis); - auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC, src_desc); 
AddArgument(DNNL_ARG_DST, src_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc index b78085155b2..d1ab5806e70 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -18,7 +18,6 @@ #include #include #include -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -56,7 +55,7 @@ void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_n dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); - auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC, mem_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc index 44a562a7d98..198f0602e98 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -19,7 +19,6 @@ #include #include #include -#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -62,7 +61,7 @@ void 
SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &ke dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); - auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, engine_); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC, mem_desc); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc index da817ebab8e..b8d650e2241 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc @@ -26,6 +26,7 @@ constexpr size_t kUniformRealInputsNum = 1; constexpr size_t kUniformIntOutputsNum = 1; constexpr size_t kUniformRealOutputsNum = 1; constexpr size_t kStandardNormalOutputsNum = 1; +constexpr float kRandomBlockSize = 128.0; constexpr char kKernelName[] = "Random"; } // namespace void StandardNormal(float *output, std::normal_distribution distribution, @@ -39,12 +40,12 @@ void LaunchStandardNormal(RandomCPUKernel *content, unsigned int seed, const std auto output = reinterpret_cast(outputs[0]->addr); // multithreading size_t lens = outputs[0]->size / sizeof(float); - auto task = [&seed, &output](size_t start, size_t end) { + std::default_random_engine random_generator(++seed); + auto task = [&seed, &output, &random_generator](size_t start, size_t end) { std::normal_distribution distribution; - std::default_random_engine random_generator(++seed); StandardNormal(output, distribution, random_generator, start, end); }; - ParallelLaunchAutoSearch(task, lens, content, &content->parallel_search_info_); + ParallelLaunch(task, lens, kRandomBlockSize, content); } void LaunchUniformInt(unsigned int seed, const 
std::vector &inputs, diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index dcda24c1942..3b90d470cbe 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -1135,6 +1135,18 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr SetInternalOutputAttr(new_node); } +void KernelGraph::EnableRuntimeCache() { + auto node_list = TopoSort(get_return()); + for (auto &node : node_list) { + auto kernel_info = node->kernel_info(); + if (!kernel_info) { + continue; + } + auto runtime_cache = kernel_info->runtime_cache(); + runtime_cache.runtime_cache().set_valid(); + } +} + void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx, size_t dst_output_idx) { if (new_node == nullptr || node == nullptr) { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 10dd5738355..bdba3bd341a 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -286,6 +286,7 @@ class KernelGraph : public FuncGraph { } } void RemoveNodeFromGraph(const AnfNodePtr &node); + void EnableRuntimeCache(); void UpdateGraphDynamicAttr(); void SetGraphDynamicAttr(bool is_dynamic_shape) { is_dynamic_shape_ = is_dynamic_shape; } bool is_dynamic_shape() const { return is_dynamic_shape_; } diff --git a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc index abffad0f7a7..d9e4b8fca17 100644 --- a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc +++ b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc @@ -87,6 +87,7 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) { .value("mempool_block_size", MsCtxParam::MS_CTX_MEMPOOL_BLOCK_SIZE) .value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE) .value("device_target", 
MsCtxParam::MS_CTX_DEVICE_TARGET) + .value("runtime_num_threads", MsCtxParam::MS_CTX_RUNTIME_NUM_THREADS) .value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH) .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS) diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc index 61b9be70a46..e7adb049145 100644 --- a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc +++ b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc @@ -25,23 +25,27 @@ bool ActorDispatcher::is_multi_thread_execution_ = true; void ComputeThreadNums(size_t *actor_thread_num, size_t *actor_and_kernel_thread_num) { MS_EXCEPTION_IF_NULL(actor_thread_num); MS_EXCEPTION_IF_NULL(actor_and_kernel_thread_num); - const size_t cpu_core_num = std::thread::hardware_concurrency() - 1; - // Compute the actor thread num. - const size_t kActorThreadMaxNum = 5; - // The MemoryManagerActor binds single thread, and the other actors share one thread at least, so the min num is 2. - const size_t kActorThreadMinNum = 2; auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - *actor_thread_num = cpu_core_num < kActorThreadMinNum ? kActorThreadMinNum : cpu_core_num; - *actor_thread_num = *actor_thread_num > kActorThreadMaxNum ? kActorThreadMaxNum : *actor_thread_num; + const size_t cpu_core_num = std::thread::hardware_concurrency() - 1; + auto runtime_num_threads = static_cast(context_ptr->get_param(MS_CTX_RUNTIME_NUM_THREADS)); + size_t runtime_num_threads_min = std::min(runtime_num_threads, cpu_core_num); + const float kActorUsage = 0.2; + const size_t kActorThreadMinNum = 2; + size_t actor_thread_max_num = + std::max(static_cast(std::floor(runtime_num_threads_min * kActorUsage)), kActorThreadMinNum); + // Compute the actor thread num. 
+ // The MemoryManagerActor binds single thread, and the other actors share one thread at least, so the min num is 2. + *actor_thread_num = runtime_num_threads_min < kActorThreadMinNum ? kActorThreadMinNum : runtime_num_threads_min; + *actor_thread_num = *actor_thread_num > actor_thread_max_num ? actor_thread_max_num : *actor_thread_num; // Compute the actor and kernel thread num. - const size_t kKernelThreadMaxNum = 23; - const size_t kActorAndKernelThreadMaxNum = kKernelThreadMaxNum + *actor_thread_num; - *actor_and_kernel_thread_num = cpu_core_num > *actor_thread_num ? cpu_core_num : (*actor_thread_num + 1); - *actor_and_kernel_thread_num = *actor_and_kernel_thread_num > kActorAndKernelThreadMaxNum - ? kActorAndKernelThreadMaxNum - : *actor_and_kernel_thread_num; + *actor_and_kernel_thread_num = + runtime_num_threads_min > *actor_thread_num ? runtime_num_threads_min : (*actor_thread_num + 1); + if (runtime_num_threads != *actor_and_kernel_thread_num) { + MS_LOG(WARNING) << "The runtime_num_threads is " << runtime_num_threads + << ", but actually the num of threads in threadpool is " << *actor_and_kernel_thread_num; + } } bool IsDeviceQueueDSActor(const AnfNodePtr &node, GraphExecutionStrategy strategy) { diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.cc b/mindspore/ccsrc/runtime/framework/graph_compiler.cc index bb125c02fc7..bcf1b8f1150 100644 --- a/mindspore/ccsrc/runtime/framework/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_compiler.cc @@ -492,7 +492,7 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic } #endif - device_context->EnableRuntimeCache(graph); + graph->EnableRuntimeCache(); return graph->graph_id(); } diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index 5e2aef6624b..2f9bb92fb2e 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ 
-265,10 +265,8 @@ void GraphScheduler::Initialize() { MS_LOG(EXCEPTION) << "Actor manager init failed."; } common::SetOMPThreadNum(); - auto OMP_thread_num_used = common::GetEnv("OMP_NUM_THREADS"); MS_LOG(INFO) << "The actor thread number: " << actor_thread_num - << ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num) - << ", the used OMP thread number: " << OMP_thread_num_used; + << ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num); BuildAndScheduleGlobalActor(); } diff --git a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc index 62c56bd5892..701cc5dc7df 100644 --- a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc +++ b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc @@ -40,20 +40,12 @@ #ifndef ENABLE_SECURITY #include "debug/data_dump/dump_json_parser.h" #endif -#ifdef PLATFORM_86 -#include -#endif namespace mindspore { namespace device { namespace cpu { using mindspore::kernel::KernelBuildInfo; -#ifdef PLATFORM_86 -// Whether need set the flush zero mode in the kernel launch. -static bool flush_zero_mode_enable{false}; -#endif - void CPUDeviceContext::Initialize() { if (initialized_) { return; @@ -229,14 +221,6 @@ void CPUDeviceContext::CreateKernel(const std::vector &nodes) const { MS_LOG(EXCEPTION) << "Build cpu operator[" << node->fullname_with_scope() << "] failed"; } -#ifdef PLATFORM_86 - // Some CPU kernels need set the flush zero mode to improve performance. 
- if (!flush_zero_mode_enable && - (kOpNeedSetFlushZeroModeList.find(kernel_name) != kOpNeedSetFlushZeroModeList.end())) { - flush_zero_mode_enable = true; - } -#endif - cpu_kernel->Init(node); cpu_kernel->InitDynamicKernel(node); auto cpu_dynamic_kernel = cpu_kernel->DynamicKernel(); @@ -285,14 +269,6 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vectorTopoSort(graph->get_return()); - for (auto &node : node_list) { - auto kernel_info = node->kernel_info(); - if (!kernel_info) { - continue; - } - MS_EXCEPTION_IF_NULL(kernel_info); - auto runtime_cache = kernel_info->runtime_cache(); - runtime_cache.runtime_cache().set_valid(); - } - } - protected: DeviceContextKey device_context_key_; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 7c6264a3ecf..a7bed8cee93 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -716,8 +716,6 @@ const std::set kPosteriorOperatorSet = {kPullOpName}; const std::set kOpCacheBlackList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName, kGetNextOpName}; -const std::set kOpNeedSetFlushZeroModeList = {kLSTMOpName, kLSTMGradOpName}; - const std::set kOpNotSupportMultiThreadExecList = {kAvgPoolOpName, kAvgPoolGradOpName, kMaxPoolOpName, kBatchNorm, kBatchNormGradOpName}; diff --git a/mindspore/core/mindrt/src/thread/actor_threadpool.cc b/mindspore/core/mindrt/src/thread/actor_threadpool.cc index ed94f117899..d91bcbf86ed 100644 --- a/mindspore/core/mindrt/src/thread/actor_threadpool.cc +++ b/mindspore/core/mindrt/src/thread/actor_threadpool.cc @@ -33,6 +33,11 @@ void ActorWorker::RunWithSpin() { #if !defined(__APPLE__) && !defined(SUPPORT_MSVC) static std::atomic_int index = {0}; (void)pthread_setname_np(pthread_self(), ("ActorThread_" + std::to_string(index++)).c_str()); +#endif +#ifdef PLATFORM_86 + // Some CPU kernels need set the flush zero mode to improve performance. 
+ _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); #endif while (alive_) { // only run either local KernelTask or PoolQueue ActorTask diff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc index 731717aa8df..db7b2f0d1d4 100644 --- a/mindspore/core/mindrt/src/thread/threadpool.cc +++ b/mindspore/core/mindrt/src/thread/threadpool.cc @@ -76,6 +76,11 @@ void Worker::Run() { #if !defined(__APPLE__) && !defined(SUPPORT_MSVC) static std::atomic_int index = {0}; (void)pthread_setname_np(pthread_self(), ("KernelThread_" + std::to_string(index++)).c_str()); +#endif +#ifdef PLATFORM_86 + // Some CPU kernels need set the flush zero mode to improve performance. + _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); #endif while (alive_) { if (RunLocalKernelTask()) { diff --git a/mindspore/core/mindrt/src/thread/threadpool.h b/mindspore/core/mindrt/src/thread/threadpool.h index 1c07bf37b7d..b0c3ed27535 100644 --- a/mindspore/core/mindrt/src/thread/threadpool.h +++ b/mindspore/core/mindrt/src/thread/threadpool.h @@ -27,7 +27,10 @@ #include #include "thread/threadlog.h" #include "thread/core_affinity.h" - +#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64) +#define PLATFORM_86 +#include +#endif namespace mindspore { constexpr int kDefaultSpinCount = 300000; constexpr int kMaxCount = 30000; diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc index 871fe50c897..7280573a58d 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "ir/tensor.h" #include "utils/ms_utils.h" @@ -97,6 +98,11 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { set_param(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false); set_param(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true); + 
size_t cpu_core_num = std::thread::hardware_concurrency() - 1; + constexpr float kCpuUsage = 0.6; + uint32_t runtime_num_threads = std::max(static_cast(std::floor(cpu_core_num * kCpuUsage)), 1); + set_param(MS_CTX_RUNTIME_NUM_THREADS, runtime_num_threads); + backend_policy_ = policy_map_[policy]; } diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h index c932fd291b5..7017ab70d56 100644 --- a/mindspore/core/utils/ms_context.h +++ b/mindspore/core/utils/ms_context.h @@ -102,6 +102,7 @@ enum MsCtxParam : unsigned { // parameter of type uint32 MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END, MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN, + MS_CTX_RUNTIME_NUM_THREADS, MS_CTX_GE_REF, MS_CTX_MAX_CALL_DEPTH, MS_CTX_TSD_REF, diff --git a/mindspore/python/mindspore/context.py b/mindspore/python/mindspore/context.py index 71dd629047b..7ff3d037ea4 100644 --- a/mindspore/python/mindspore/context.py +++ b/mindspore/python/mindspore/context.py @@ -318,6 +318,12 @@ class _Context: "have permission to read it.".format(env_config_path)) self.set_param(ms_ctx_param.env_config_path, env_config_path) + def set_runtime_num_threads(self, runtime_num_threads): + """Check and set runtime_num_threads.""" + if runtime_num_threads <= 0: + raise ValueError("The num of cpu thread must bigger than 0.") + self.set_param(ms_ctx_param.runtime_num_threads, runtime_num_threads) + setters = { 'mode': set_mode, 'save_graphs_path': set_save_graphs_path, @@ -331,7 +337,8 @@ class _Context: 'mempool_block_size': set_mempool_block_size, 'print_file_path': set_print_file_path, 'env_config_path': set_env_config_path, - 'backend_policy': set_backend_policy + 'backend_policy': set_backend_policy, + 'runtime_num_threads': set_runtime_num_threads } @property @@ -607,7 +614,7 @@ def _check_target_specific_cfgs(device, arg_key): enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool, 
max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int, - env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool, + env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int, load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str, backend_policy=str) def set_context(**kwargs): @@ -679,6 +686,8 @@ def set_context(**kwargs): | +------------------------------+----------------------------+ | | enable_compile_cache | CPU/GPU/Ascend | | +------------------------------+----------------------------+ + | | runtime_num_threads | CPU/GPU/Ascend | + | +------------------------------+----------------------------+ | | compile_cache_path | CPU/GPU/Ascend | | +------------------------------+----------------------------+ | | backend_policy | Ascend | @@ -822,6 +831,8 @@ def set_context(**kwargs): backend_policy (str): Used to choose a backend. ("ge", "vm" or "ms"). Through context.set_context(backend_policy="ms") Default: The value must be in ['ge', 'vm', 'ms']. + runtime_num_threads(int): The thread pool number of cpu kernel and actor used in runtime, + which must bigger than 0. Raises: ValueError: If input key is not an attribute in context. 
@@ -851,6 +862,7 @@ def set_context(**kwargs): >>> context.set_context(grad_for_scalar=True) >>> context.set_context(enable_compile_cache=True, compile_cache_path="./cache.ms") >>> context.set_context(pynative_synchronize=True) + >>> context.set_context(runtime_num_threads=10) """ ctx = _context() # set device target first diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 58902274c11..9a023e237d9 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -26,6 +26,7 @@ from collections import Counter import numpy as np from mindspore import log as logger +from mindspore import context from mindspore.common.initializer import Zero from .. import signature as sig from .._utils import get_broadcast_shape, is_shape_unknown @@ -3304,6 +3305,11 @@ class StridedSlice(PrimitiveWithInfer): shrink_axis_mask=0): """Initialize StridedSlice""" self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output']) + + # auto parallel haven't support begin_mask and end_mask + if context.get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]: + begin_mask = 0 + end_mask = 0 validator.check_non_negative_int(begin_mask, 'begin_mask', self.name) validator.check_non_negative_int(end_mask, 'end_mask', self.name) validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name)