!19855 [bugfix] Reinitialize for cpu kernels which do not support multi-thread

Merge pull request !19855 from zyli2020/r1.3_mindrt
This commit is contained in:
i-robot 2021-07-09 21:42:58 +00:00 committed by Gitee
commit d141ee0c72
2 changed files with 13 additions and 2 deletions

View File

@ -203,6 +203,16 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
bool) const {
MS_EXCEPTION_IF_NULL(kernel);
MS_LOG(DEBUG) << "Launch kernel: " << kernel->fullname_with_scope();
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto cpu_kernel_mod = dynamic_cast<kernel::CPUKernel *>(kernel_mod);
MS_EXCEPTION_IF_NULL(cpu_kernel_mod);
// Some CPU kernels can't initialize kernel and launch kernel in different thread, so reinitialize the kernels before
// launch.
if (kOpNotSupportMultiThreadExecList.find(AnfAlgo::GetCNodeName(kernel)) != kOpNotSupportMultiThreadExecList.end()) {
cpu_kernel_mod->InitKernel(kernel);
}
const auto &profiler_inst = profiler::cpu::CPUProfiler::GetInstance();
MS_EXCEPTION_IF_NULL(profiler_inst);
@ -210,8 +220,6 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
return LaunchKernelWithProfiling(kernel, inputs, workspace, outputs);
}
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
return DoLaunchKernel(kernel_mod, inputs, workspace, outputs);
}

View File

@ -233,6 +233,7 @@ constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradOpName = "AvgPoolGrad";
constexpr auto kAvgPoolGradVmOpName = "AvgPoolGradVm";
constexpr auto kMaxPoolOpName = "MaxPool";
constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax";
constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax";
@ -596,6 +597,8 @@ const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};
const std::set<std::string> kOpCacheBlackList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName,
kGetNextOpName};
const std::set<std::string> kOpNotSupportMultiThreadExecList = {kAvgPoolOpName, kAvgPoolGradOpName, kMaxPoolOpName};
const std::set<std::string> kHWSpecialFormatSet = {
kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};