!19403 [bugfix] CPU PyNative hook occur select op error

Merge pull request !19403 from zyli2020/fix_mindtest_bug
This commit is contained in:
i-robot 2021-07-06 02:25:12 +00:00 committed by Gitee
commit 83bcf936b6
1 changed files with 40 additions and 1 deletions

View File

@ -19,6 +19,7 @@
#include "runtime/device/cpu/cpu_device_address.h"
#include "runtime/device/cpu/cpu_memory_manager.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/trace_base.h"
#include "backend/optimizer/common/optimizer.h"
@ -34,6 +35,8 @@
namespace mindspore {
namespace device {
namespace cpu {
using mindspore::kernel::KernelBuildInfo;
bool CPUDeviceContext::Initialize() {
if (initialized_) {
return true;
@ -115,15 +118,51 @@ void CPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr>
graph->UpdateGraphDynamicAttr();
}
namespace {
void SetControlOpInfo(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_format.emplace_back(kOpFormat_DEFAULT);
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_format.emplace_back(kOpFormat_DEFAULT);
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetInputsFormat(inputs_format);
builder->SetInputsDeviceType(inputs_type);
builder->SetOutputsFormat(outputs_format);
builder->SetOutputsDeviceType(outputs_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
}
} // namespace
void CPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
for (const auto &node : nodes) {
SetKernelInfo(node);
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsControlOpExecInBackend(node)) {
SetKernelInfo(node);
} else {
SetControlOpInfo(node);
}
}
}
void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
for (const auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsControlOpExecInBackend(node)) {
continue;
}
std::string kernel_name = AnfAlgo::GetCNodeName(node);
std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::GetInstance().Create(kernel_name, node);
if (!cpu_kernel) {