!19403 [bugfix] CPU PyNative hook occur select op error
Merge pull request !19403 from zyli2020/fix_mindtest_bug
This commit is contained in:
commit
83bcf936b6
|
@ -19,6 +19,7 @@
|
|||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "runtime/device/cpu/cpu_memory_manager.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "runtime/device/cpu/kernel_select_cpu.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
@ -34,6 +35,8 @@
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace cpu {
|
||||
using mindspore::kernel::KernelBuildInfo;
|
||||
|
||||
bool CPUDeviceContext::Initialize() {
|
||||
if (initialized_) {
|
||||
return true;
|
||||
|
@ -115,15 +118,51 @@ void CPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr>
|
|||
graph->UpdateGraphDynamicAttr();
|
||||
}
|
||||
|
||||
namespace {
|
||||
void SetControlOpInfo(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
|
||||
}
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
}
|
||||
|
||||
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
builder->SetInputsFormat(inputs_format);
|
||||
builder->SetInputsDeviceType(inputs_type);
|
||||
builder->SetOutputsFormat(outputs_format);
|
||||
builder->SetOutputsDeviceType(outputs_type);
|
||||
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void CPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
|
||||
for (const auto &node : nodes) {
|
||||
SetKernelInfo(node);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!AnfAlgo::IsControlOpExecInBackend(node)) {
|
||||
SetKernelInfo(node);
|
||||
} else {
|
||||
SetControlOpInfo(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
|
||||
for (const auto &node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (AnfAlgo::IsControlOpExecInBackend(node)) {
|
||||
continue;
|
||||
}
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(node);
|
||||
std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::GetInstance().Create(kernel_name, node);
|
||||
if (!cpu_kernel) {
|
||||
|
|
Loading…
Reference in New Issue