diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 41a737f61ca..db986f03f74 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -153,7 +153,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false); - if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) { kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0); } make_tuple_inputs.emplace_back(trans_op); @@ -265,7 +265,7 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP // Single output if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) { auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select); - if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, 0)) { kernel_graph->ReplaceInternalOutput(node, new_node); } return new_node; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc index bc68511bb2b..fd4ed02eb6d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc @@ -65,7 +65,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo MS_EXCEPTION_IF_NULL(replace_node); replace_node->set_scope(cnode->scope()); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); - if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) { + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) { kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0); } } else { @@ -114,7 +114,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c MS_EXCEPTION_IF_NULL(replace_node); replace_node->set_scope(cnode->scope()); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); - if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) { + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, 0)) { kernel_graph->ReplaceInternalOutput(cnode, replace_node); } } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc index e9238fe0066..530e180ed2e 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc @@ -58,7 +58,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c if (kernel_graph == nullptr) { return nullptr; } - if (!kernel_graph->IsInternalOutput(node)) { + if (!kernel_graph->IsInternalOutput(node, 0)) { return nullptr; } if (!UsedForOutputOnly(func_graph, node)) { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index a89b1579eab..b9e4e003bf2 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -1021,26 +1021,16 @@ AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_nod return nullptr; } -bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const { - if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) { - return true; - } - return false; -} - -void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) { - if (node == nullptr) { - return; - } - (void)final_output_kernels_.insert(node); -} - -bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const { - if (node == nullptr) { - return false; - } - if (final_output_kernels_.find(node) != final_output_kernels_.end()) { - return true; +bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, int output_idx) const { + auto front_nodes_iter = internal_outputs_to_front_map_.find(node); + if (front_nodes_iter != internal_outputs_to_front_map_.end()) { + if (output_idx == -1) { + return true; + } + auto &front_nodes = front_nodes_iter->second; + if (front_nodes.find(output_idx) != front_nodes.end()) { + return true; + } } return false; } diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 05a12aebf37..7c170a37af2 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -153,9 +153,7 @@ class KernelGraph : public FuncGraph { void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1, int dst_output_idx = -1); AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const; - bool IsInternalOutput(const AnfNodePtr &node) const; - void AddFinalOutputKernel(const AnfNodePtr &node); - bool IsFinalOutputKernel(const AnfNodePtr &node) const; + bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const; uint32_t current_epoch() const { return current_epoch_; } void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; } void UpdateChildGraphOrder(); @@ -230,7 +228,6 @@ class KernelGraph : public FuncGraph { bool null_output_; std::unordered_map front_to_internal_outputs_map_; std::unordered_map> internal_outputs_to_front_map_; - std::set final_output_kernels_; uint32_t current_epoch_; }; } // namespace session diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index ff72716c607..b551040bbda 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -89,7 +89,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne TypeId type_id = kNumberTypeFloat32; type_id = AnfAlgo::GetOutputInferDataType(node, output_index); std::vector temp_shape; - if (graph.IsInternalOutput(node)) { + if (graph.IsInternalOutput(node, output_index)) { temp_shape.emplace_back(1); tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); tensor->set_device_address(address); @@ -307,8 +307,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx); auto ref_real_node = real_kernel.first; auto ref_real_node_index = real_kernel.second; - if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node) && - node_graph->IsFinalOutputKernel(ref_real_node)) { + if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node, ref_real_node_index)) { auto kernel_info = ref_real_node->kernel_info(); if (kernel_info == nullptr || !kernel_info->has_build_info()) { MS_LOG(INFO) << "No kernel info"; diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index e8fc6c7a985..3932c4fd5aa 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -492,6 +492,10 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { continue; } + if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown && + AnfAlgo::OutputAddrExist(real_input_node, 0)) { + continue; + } if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; builder->SetOutputsFormat(output_format); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 6571bc22d69..b51e9127c50 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -355,7 +355,6 @@ void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) { if (!item_with_index.first->isa() || !AnfAlgo::IsRealKernel(item_with_index.first)) { continue; } - graph->AddFinalOutputKernel(item_with_index.first); if (AnfAlgo::IsCommunicationOp(item_with_index.first)) { AssignCommunicationNodeMem(kStaticMem, item_with_index.first); } else { diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc index 72b7c6e3614..8c3ed29a0cb 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc @@ -122,7 +122,7 @@ TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_sing auto make_tuple = GetMakeTuple(kg); auto trans_data = make_tuple->cast()->input(1); - EXPECT_TRUE(kg->IsInternalOutput(trans_data)); + EXPECT_TRUE(kg->IsInternalOutput(trans_data, 0)); // remove trans op for internal output auto graph_optimizer1 = std::make_shared(); @@ -156,8 +156,8 @@ TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_mult auto make_tuple1 = tuple_getitem->cast()->input(1); auto trans_data1 = make_tuple1->cast()->input(1); auto trans_data2 = make_tuple1->cast()->input(2); - EXPECT_TRUE(kg->IsInternalOutput(trans_data1)); - EXPECT_TRUE(kg->IsInternalOutput(trans_data2)); + EXPECT_TRUE(kg->IsInternalOutput(trans_data1, 0)); + EXPECT_TRUE(kg->IsInternalOutput(trans_data2, 0)); // remove trans op for internal output auto graph_optimizer1 = std::make_shared();