diff --git a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
index 8af9b65addf..ab530fdc8c6 100644
--- a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
+++ b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
@@ -19,6 +19,7 @@
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "runtime/device/cpu/cpu_memory_manager.h"
 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
 #include "runtime/device/cpu/kernel_select_cpu.h"
 #include "utils/trace_base.h"
 #include "backend/optimizer/common/optimizer.h"
@@ -34,6 +35,8 @@
 namespace mindspore {
 namespace device {
 namespace cpu {
+using mindspore::kernel::KernelBuildInfo;
+
 bool CPUDeviceContext::Initialize() {
   if (initialized_) {
     return true;
@@ -115,15 +118,51 @@ void CPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr>
   graph->UpdateGraphDynamicAttr();
 }
 
+namespace {
+void SetControlOpInfo(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<std::string> inputs_format;
+  std::vector<TypeId> inputs_type;
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t input_index = 0; input_index < input_num; ++input_index) {
+    inputs_format.emplace_back(kOpFormat_DEFAULT);
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
+  }
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> outputs_type;
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  for (size_t output_index = 0; output_index < output_num; ++output_index) {
+    outputs_format.emplace_back(kOpFormat_DEFAULT);
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
+  }
+
+  auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
+  builder->SetInputsFormat(inputs_format);
+  builder->SetInputsDeviceType(inputs_type);
+  builder->SetOutputsFormat(outputs_format);
+  builder->SetOutputsDeviceType(outputs_type);
+
+  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
+}
+}  // namespace
+
 void CPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
   for (const auto &node : nodes) {
-    SetKernelInfo(node);
+    MS_EXCEPTION_IF_NULL(node);
+    if (!AnfAlgo::IsControlOpExecInBackend(node)) {
+      SetKernelInfo(node);
+    } else {
+      SetControlOpInfo(node);
+    }
   }
 }
 
 void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
   for (const auto &node : nodes) {
     MS_EXCEPTION_IF_NULL(node);
+    if (AnfAlgo::IsControlOpExecInBackend(node)) {
+      continue;
+    }
     std::string kernel_name = AnfAlgo::GetCNodeName(node);
     std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::GetInstance().Create(kernel_name, node);
     if (!cpu_kernel) {