diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index c9aebd27971..1b244d49886 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -307,7 +307,8 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index, bool skip_nop_node, - const std::vector &return_types) { + const std::vector &return_types, + abstract::AbstractBasePtr *abstract) { MS_EXCEPTION_IF_NULL(anf_node); if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool { return CheckPrimitiveType(anf_node, prim_type); @@ -320,8 +321,9 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr auto cnode = anf_node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) { - auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode), - GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types); + abstract::AbstractBasePtr abs = nullptr; + auto item_with_index_tmp = VisitKernelWithReturnType( + GetTupleGetItemRealInput(cnode), GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types, &abs); if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) { MS_EXCEPTION_IF_NULL(item_with_index_tmp.first); auto make_tuple = item_with_index_tmp.first->cast(); @@ -334,6 +336,32 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr } return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, skip_nop_node, return_types); } + if (IsCallNode(item_with_index_tmp.first)) { + size_t real_index = item_with_index_tmp.second; + if (abs == nullptr) { + abs = item_with_index_tmp.first->abstract(); + real_index = 0; + } + MS_EXCEPTION_IF_NULL(abs); + if (abs->isa()) { + auto tuple_abstract = abs->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + auto sub_abstracts = tuple_abstract->elements(); + if (sub_abstracts.size() <= GetTupleGetItemOutIndex(cnode)) { + MS_LOG(EXCEPTION) << "Invalid index:" << GetTupleGetItemOutIndex(cnode) + << " for abstract:" << abs->ToString(); + } + for (size_t i = 0; i < GetTupleGetItemOutIndex(cnode); ++i) { + MS_EXCEPTION_IF_NULL(sub_abstracts[i]); + real_index += AnfAlgo::GetOutputNumByAbstract(sub_abstracts[i]); + } + if (abstract != nullptr) { + (*abstract) = sub_abstracts[GetTupleGetItemOutIndex(cnode)]; + MS_EXCEPTION_IF_NULL((*abstract)); + } + return {item_with_index_tmp.first, real_index}; + } + } return item_with_index_tmp; } if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) { diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index bc051ed8f36..f39f7aaabed 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -77,10 +77,10 @@ class AnfRuntimeAlgorithm { static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); // get input_anf_node's real kernel by recurse static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); - static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index, - bool skip_nop_node = false, - const std::vector &return_types = { - prim::kPrimMakeTuple}); + static KernelWithIndex VisitKernelWithReturnType( + const AnfNodePtr &input_anf_node, size_t output_index, bool skip_nop_node = false, + const std::vector &return_types = {prim::kPrimMakeTuple}, + abstract::AbstractBasePtr *abstract = nullptr); static std::vector GetAllOutput(const AnfNodePtr &node, const std::vector &return_types = {}); static std::vector GetAllOutputWithIndex(const AnfNodePtr &node); diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc index 0573595552d..abd4d17cdcf 100644 --- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc @@ -737,14 +737,13 @@ void DataPrepareActor::PrepareDataForControlNode(const ControlNodeParserPtr &con continue; } - const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters(); - const auto &iter = front_to_backend_parameters.find({front_node, 0}); - if (iter == front_to_backend_parameters.end() || iter->second.empty()) { + const auto &backend_parameter_with_context = + control_node_parser->FetchBackendParameterWithContextByFrontParameter({front_node, 0}); + if (backend_parameter_with_context.first == nullptr) { MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << AnfAlgo::GetNodeDebugString(front_node); } - const auto &node_with_context = iter->second.begin(); - const auto &backend_node = node_with_context->first; - const auto &device_context = node_with_context->second; + const auto &backend_node = backend_parameter_with_context.first; + const auto &device_context = backend_parameter_with_context.second; MS_EXCEPTION_IF_NULL(backend_node); MS_EXCEPTION_IF_NULL(device_context); diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index 415fb321bf3..d1d37f1bdf3 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -369,7 +369,8 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id); MS_EXCEPTION_IF_NULL(address); - MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address; + MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address + << " size:" << tensor_size; AnfAlgo::SetOutputAddr(address, front_node_with_index.second, front_node.get()); UpdateRefCount(address.get(), true); } @@ -1231,6 +1232,25 @@ FuncGraphPtr ControlNodeParser::FetchFuncGraphByKernelGraph(const KernelGraph *c return nullptr; } +NodeWithContext ControlNodeParser::FetchBackendParameterWithContextByFrontParameter( + const KernelWithIndex &front_parameter_with_index) { + const auto &iter = front_to_backend_parameters_.find(front_parameter_with_index); + if (iter == front_to_backend_parameters_.end()) { + return {}; + } + + for (const auto &node_with_context : iter->second) { + MS_EXCEPTION_IF_NULL(node_with_context.first); + if (AnfAlgo::GetOutputTensorMemSize(node_with_context.first, 0) != 0) { + return node_with_context; + } + MS_LOG(WARNING) << "Backend node:" << node_with_context.first->DebugString() + << " for front node:" << front_parameter_with_index.first->DebugString() + << " index:" << front_parameter_with_index.second << " output size is 0."; + } + return {}; +} + void ControlNodeParser::FetchFrontValueNode(const std::vector &control_nodes, const DeviceContext *const default_context) { MS_EXCEPTION_IF_NULL(default_context); @@ -1241,11 +1261,12 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr continue; } - const auto &iter = front_to_backend_parameters_.find(real_parameter_with_index); - if (iter != front_to_backend_parameters_.end() && (!iter->second.empty())) { - (void)front_value_nodes_.emplace(real_parameter_with_index, iter->second.begin()->second); - CreateDeviceTensorForValueNode(real_parameter_with_index, iter->second.begin()->first, - iter->second.begin()->second); + const auto &backend_node_with_context = + FetchBackendParameterWithContextByFrontParameter(real_parameter_with_index); + if (backend_node_with_context.first != nullptr) { + (void)front_value_nodes_.emplace(real_parameter_with_index, backend_node_with_context.second); + CreateDeviceTensorForValueNode(real_parameter_with_index, backend_node_with_context.first, + backend_node_with_context.second); } else { (void)front_value_nodes_.emplace(real_parameter_with_index, default_context); CreateDeviceTensorForFrontNode(real_parameter_with_index, default_context); @@ -1253,19 +1274,6 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr } } - // If the output of funcgraph is a value node, it will eventually be sent to the kernel as a real parameter. - // These the value nodes also need to create a device address. - for (const auto &front_to_backend_parameters : front_to_backend_parameters_) { - const auto &front_node = front_to_backend_parameters.first.first; - MS_EXCEPTION_IF_NULL(front_node); - if (IsFrontValueNode(front_to_backend_parameters.first) && (!front_to_backend_parameters.second.empty())) { - const auto &backend_parameter = front_to_backend_parameters.second.begin()->first; - const auto &device_context = front_to_backend_parameters.second.begin()->second; - CreateDeviceTensorForValueNode(front_to_backend_parameters.first, backend_parameter, device_context); - (void)front_value_nodes_.emplace(front_to_backend_parameters.first, device_context); - } - } - // Create device tensors for those value nodes which direct return by a return node. for (const auto &control_node : control_nodes) { if ((!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) && (!AnfAlgo::IsCallNode(control_node))) { diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index ee4d03b529c..4b60da987c5 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -66,8 +66,8 @@ constexpr size_t kMakeCSRTensorInputNum = 4; const char kEntranceActorNameSuffix[] = "_EntranceActor"; const char kExitActorNameSuffix[] = "_ExitActor"; const char kStackActorNameSuffix[] = "_StackActor"; - -using FrontToBackendNodeWithContext = std::map>>; +using NodeWithContext = std::pair; +using FrontToBackendNodeWithContext = std::map>; using FrontToBackendKernelWithContext = std::map>; using FuncGraphToKernelGraphGroup = mindspore::HashMap>>; using HostParameterToWeight = std::map>; @@ -145,6 +145,7 @@ class ControlNodeParser { KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index); FuncGraphPtr FetchFuncGraphByKernelGraph(const KernelGraph *const graph); std::string FetchGroupNameByKernelGraph(const KernelGraphPtr &graph); + NodeWithContext FetchBackendParameterWithContextByFrontParameter(const KernelWithIndex &front_parameter_with_index); private: friend class GraphScheduler; diff --git a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc index 47ef8ea3119..7a0ac679379 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc @@ -1418,6 +1418,10 @@ void ControlNodeScheduler::LinkBranchIDArrow(ControlActor *const from_actor, Con bool ControlNodeScheduler::CheckActorValid(const ActorSet *actor_set) const { MS_EXCEPTION_IF_NULL(actor_set); + if (actor_set->control_actors_ == nullptr) { + return true; + } + for (const auto &kernel_actor : actor_set->kernel_actors_) { std::string exit_actor_name = ""; for (const auto arrow : kernel_actor->output_data_arrows_) { diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index 803fb0e520f..2e9edb3f295 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -646,22 +646,16 @@ std::vector GraphScheduler::BuildDataSourceActor(const Graph } } - MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_); - const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_; + const auto parser = graph_compiler_info.control_node_parser_; + MS_EXCEPTION_IF_NULL(parser); // Initialize the parameter in the control node, first get all the front parameters in the control node, then find - // the corresponding backend parameter from the map, and insert it into the host data source actor - const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters(); + // the corresponding backend parameter from the map, and insert it into the host data source actor. + const auto &control_node_parameters = parser->control_node_parameters(); for (const auto ¶meter : control_node_parameters) { if (IsPersistentDeviceTensor(parameter)) { continue; } - - auto backend_iter = front_to_backend_parameter.find({parameter, 0}); - if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) { - MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter); - } - if (host_queue_ds_actor == nullptr) { auto actor_name = graph_compiler_info.name_ + "_HostDSActor"; MS_LOG(INFO) << "Create host queue data source actor: " << actor_name; @@ -675,15 +669,18 @@ std::vector GraphScheduler::BuildDataSourceActor(const Graph if (node_map.find(parameter) != node_map.end()) { continue; } - const auto &backend_node = backend_iter->second.begin()->first; + const auto &backend_parameter_with_context = + parser->FetchBackendParameterWithContextByFrontParameter({parameter, 0}); + const auto &backend_node = backend_parameter_with_context.first; + MS_EXCEPTION_IF_NULL(backend_node); auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node); if (iter != host_queue_ds_actor->data_nodes_.end()) { (void)node_map.emplace(parameter, iter - host_queue_ds_actor->data_nodes_.begin()); } else { (void)node_map.emplace(parameter, host_queue_ds_actor->data_nodes_.size()); - (void)node_map.emplace(backend_iter->second.begin()->first, host_queue_ds_actor->data_nodes_.size()); - (void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.begin()->first); - (void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.begin()->second); + (void)node_map.emplace(backend_node, host_queue_ds_actor->data_nodes_.size()); + (void)host_queue_ds_actor->data_nodes_.emplace_back(backend_node); + (void)host_queue_ds_actor->device_contexts_.emplace_back(backend_parameter_with_context.second); } } @@ -1940,14 +1937,13 @@ void GraphScheduler::PersistDeviceTensorForControlNode(const GraphCompilerInfo & if ((!IsPersistentDeviceTensor(input_node)) || (!parser->IsRootGraphParameter(input_node))) { continue; } - const auto &front_to_backend_parameters = parser->front_to_backend_parameters(); - const auto &iter = front_to_backend_parameters.find({input_node, 0}); - if (iter == front_to_backend_parameters.end() || iter->second.empty()) { + const auto &backend_parameter_with_context = + parser->FetchBackendParameterWithContextByFrontParameter({input_node, 0}); + if (backend_parameter_with_context.first == nullptr) { MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << input_node->DebugString(); } - const auto &node_with_context = iter->second.begin(); - const auto &backend_node = node_with_context->first; - const auto &device_context = node_with_context->second; + const auto &backend_node = backend_parameter_with_context.first; + const auto &device_context = backend_parameter_with_context.second; MS_EXCEPTION_IF_NULL(backend_node); MS_EXCEPTION_IF_NULL(device_context); if (!DeviceTensorStore::GetInstance().Fetch(input_node.get()).empty()) {