forked from mindspore-Ecosystem/mindspore
!19855 [bugfix] Reinitialize for cpu kernels which do not support multi-thread
Merge pull request !19855 from zyli2020/r1.3_mindrt
This commit is contained in:
commit
d141ee0c72
|
@ -203,6 +203,16 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
|
|||
bool) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
auto cpu_kernel_mod = dynamic_cast<kernel::CPUKernel *>(kernel_mod);
|
||||
MS_EXCEPTION_IF_NULL(cpu_kernel_mod);
|
||||
|
||||
// Some CPU kernels cannot be initialized in one thread and launched in another, so reinitialize them before
|
||||
// launch.
|
||||
if (kOpNotSupportMultiThreadExecList.find(AnfAlgo::GetCNodeName(kernel)) != kOpNotSupportMultiThreadExecList.end()) {
|
||||
cpu_kernel_mod->InitKernel(kernel);
|
||||
}
|
||||
|
||||
const auto &profiler_inst = profiler::cpu::CPUProfiler::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(profiler_inst);
|
||||
|
@ -210,8 +220,6 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
|
|||
return LaunchKernelWithProfiling(kernel, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
return DoLaunchKernel(kernel_mod, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
|
|
|
@ -233,6 +233,7 @@ constexpr auto kPaddingOpName = "Padding";
|
|||
constexpr auto kAvgPoolOpName = "AvgPool";
|
||||
constexpr auto kAvgPoolGradOpName = "AvgPoolGrad";
|
||||
constexpr auto kAvgPoolGradVmOpName = "AvgPoolGradVm";
|
||||
constexpr auto kMaxPoolOpName = "MaxPool";
|
||||
constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
|
||||
constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax";
|
||||
constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax";
|
||||
|
@ -596,6 +597,8 @@ const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};
|
|||
const std::set<std::string> kOpCacheBlackList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName,
|
||||
kGetNextOpName};
|
||||
|
||||
// Names of ops whose CPU kernels cannot be initialized and launched in different threads;
// CPUDeviceContext::LaunchKernel calls InitKernel again on these kernels right before launching them.
const std::set<std::string> kOpNotSupportMultiThreadExecList = {kAvgPoolOpName, kAvgPoolGradOpName, kMaxPoolOpName};
|
||||
|
||||
const std::set<std::string> kHWSpecialFormatSet = {
|
||||
kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,
|
||||
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};
|
||||
|
|
Loading…
Reference in New Issue