diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 5c987d3291f..d512aa00857 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -225,13 +225,6 @@ std::vector GetAllOutputWithIndexInner(const AnfNodePtr &node) continue; } - // Fetch outputs by control nodes. - if (AnfAlgo::IsCallNode(node)) { - const auto &control_node_output = AnfAlgo::GetAllOutputByCallNode(output_with_index); - (void)std::copy(control_node_output.begin(), control_node_output.end(), std::back_inserter(ret)); - continue; - } - // The InitDataSetQueue node has no output. if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimInitDataSetQueue)) { return ret_empty; @@ -398,6 +391,7 @@ size_t AnfRuntimeAlgorithm::GetOutputNumByAbstract(const AbstractBasePtr &node_a MS_EXCEPTION_IF_NULL(tuple_abstract); const auto &sub_abstracts = tuple_abstract->elements(); for (const auto &sub_abstract : sub_abstracts) { + MS_EXCEPTION_IF_NULL(sub_abstract); result += GetOutputNumByAbstract(sub_abstract); } return result; @@ -414,7 +408,7 @@ std::vector AnfRuntimeAlgorithm::GetAllOutputByCallNode(const K auto tuple_abstract = node_abstract->cast(); MS_EXCEPTION_IF_NULL(tuple_abstract); const auto &sub_abstracts = tuple_abstract->elements(); - if (sub_abstracts.size() <= output_with_index.second) { + if (GetOutputNumByAbstract(tuple_abstract) <= output_with_index.second) { MS_LOG(EXCEPTION) << "Invalid index:" << output_with_index.second << "for node:" << output_with_index.first->DebugString(); } @@ -423,9 +417,11 @@ std::vector AnfRuntimeAlgorithm::GetAllOutputByCallNode(const K // to count the number of outputs before the target in order to accurately obtain its number. size_t pre_output_num = 0; for (size_t i = 0; i < output_with_index.second; ++i) { + MS_EXCEPTION_IF_NULL(sub_abstracts[i]); pre_output_num += GetOutputNumByAbstract(sub_abstracts[i]); } + MS_EXCEPTION_IF_NULL(sub_abstracts[output_with_index.second]); size_t output_num = GetOutputNumByAbstract(sub_abstracts[output_with_index.second]); std::vector results; for (size_t i = 0; i < output_num; ++i) { diff --git a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc index 23ad4b983b3..44f1b6d5725 100644 --- a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc @@ -43,6 +43,9 @@ void ExitActor::FetchInput(OpContext *const context) { if (data_iter != output_branch_data_.end()) { for (auto &output_data : data_iter->second) { MS_EXCEPTION_IF_NULL(output_data.second); + if (output_data.first >= input_device_tensors_.size()) { + MS_LOG(EXCEPTION) << "Invalid from index:" << output_data.first << " for actor:" << GetAID(); + } MS_EXCEPTION_IF_NULL(input_device_tensors_[output_data.first]); output_data.second->data_ = input_device_tensors_[output_data.first]; } diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index 0dde2a4cba1..0f35b93e5d7 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -528,6 +528,11 @@ void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::set FetchInputNodeByNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (HasAbstractMonad(node)) { + return {}; + } + // The node is divided into the following types: // 1. depend and load. const auto &node_with_index = @@ -586,18 +591,43 @@ std::vector FetchInputNodeByNode(const AnfNodePtr &node) { // 4. One output node. const auto &abstract = real_node->abstract(); - if (abstract == nullptr || - ((!abstract->isa()) && (!abstract->isa()))) { - if (abstract == nullptr) { - MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString(); - } - return {AnfAlgo::VisitKernelWithReturnType(real_node, real_index)}; + if (abstract == nullptr) { + MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString(); + results.emplace_back(AnfAlgo::VisitKernelWithReturnType(real_node, real_index)); + return results; } - // 4. Abstract is Tuple. + // 5 Other. size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); - for (size_t i = 0; i < output_num; ++i) { - results.emplace_back(real_node, i); + if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) { + const auto &get_item_cnode = real_node->cast(); + MS_EXCEPTION_IF_NULL(get_item_cnode); + const auto &get_item_src_node = AnfAlgo::GetTupleGetItemRealInput(get_item_cnode); + size_t get_item_src_index = AnfAlgo::GetTupleGetItemOutIndex(get_item_cnode); + + // Input node of getitm is a make tuple. + if (AnfAlgo::CheckPrimitiveType(get_item_src_node, prim::kPrimMakeTuple)) { + const auto &make_tuple_cnode = get_item_src_node->cast(); + const auto &makt_tuple_inputs = make_tuple_cnode->inputs(); + if (makt_tuple_inputs.size() <= get_item_src_index) { + MS_LOG(EXCEPTION) << "Invalid index:" << get_item_src_index + << " for make tuple node : " << get_item_src_node->DebugString(); + } + const auto &sub_results = FetchInputNodeByNode(makt_tuple_inputs[get_item_src_index + kMakeTupleInputStartPos]); + results.insert(results.end(), sub_results.begin(), sub_results.end()); + } else { + // Input node of getitm is a parameter or make tuple. + auto get_item_src_abstract = get_item_src_node->abstract(); + MS_EXCEPTION_IF_NULL(get_item_src_abstract); + auto real_indexs = FetchRealIndexByAbstract(get_item_src_abstract, get_item_src_index); + (void)std::transform( + real_indexs.begin(), real_indexs.end(), std::back_inserter(results), + [&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); }); + } + } else { + for (size_t i = 0; i < output_num; ++i) { + results.emplace_back(real_node, i); + } } return results; } @@ -1705,7 +1735,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte MS_EXCEPTION_IF_NULL(front_node_with_index.first); // If input come from the output of kernel graph belong the same group, it should not be collected in // the group inputs. - if (HasAbstractMonad(front_node_with_index.first) || + if (HasAbstractMonad(front_node_with_index.first) || HasAbstractMonad(parameter) || kernel_graph_group_info->front_output_nodes_.find(front_node_with_index) != kernel_graph_group_info->front_output_nodes_.end()) { continue;