diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index 507bcbb6c1e..48725ce7f6d 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -362,13 +362,13 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector outp } auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - bool need_convert = context_ptr->get_param(MS_CTX_ENABLE_LOOP_SINK); - if (need_convert) { - need_convert = - trans::kTransFormatMapOfHostToDevice.find(output_format[0]) != trans::kTransFormatMapOfHostToDevice.end(); + bool disable_convert = real_input_node->isa() || real_input_node->isa(); + if (disable_convert && context_ptr->get_param(MS_CTX_ENABLE_LOOP_SINK)) { + disable_convert = + trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end(); } // if not find in host convert format map means the host has not registered the convert function of this format - if (real_input_node->isa() && output_format[0] != kOpFormat_DEFAULT && !need_convert) { + if (output_format[0] != kOpFormat_DEFAULT && disable_convert) { output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)}; } auto builder = std::make_shared();