diff --git a/mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_device_address.cc b/mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_device_address.cc index c3796ff455c..4c197f0ead0 100644 --- a/mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_device_address.cc +++ b/mindspore/ccsrc/plugin/device/cpu/hal/device/cpu_device_address.cc @@ -119,6 +119,20 @@ bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector &, size_t size, TypeId MS_LOG(WARNING) << "Please check whether need sync data, host size: " << size << ", device size: " << size_; return true; } + + // If the value of host is a scalar type, then the host addr is a temporary address, which will be released after + // the sync ends. Therefore, if the value is less than 16, it needs to be copied. + const size_t kCopySize = 16; + if (size <= kCopySize) { + auto ret = memcpy_s(ptr_, size, host_ptr, size); + if (ret != EOK) { + MS_LOG(ERROR) << "Failed to copy tensor!"; + return false; + } else { + return true; + } + } + // Use the tensor host ptr to set the device ptr. if (from_mem_pool_) { CPUMemoryPool::GetInstance().FreeTensorMem(ptr_); diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc index 991c60a01b9..3ad57887336 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_common.cc @@ -80,7 +80,7 @@ bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph, // In control flow, only the parameters of the root funcgraph are in the host data source. const auto &front_node = graph->GetFrontAnfByBackendAnf(node); - bool is_host = ((front_node == nullptr) || host_parameters.empty() || + bool is_host = ((front_node == nullptr) || find(host_parameters.begin(), host_parameters.end(), front_node) != host_parameters.end()); // Judge whether node is internal parameter. diff --git a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc index f103868b2bf..012135f3444 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc @@ -129,6 +129,28 @@ std::vector CollectActors(const ControlActorSetPtr &control_act return actors; } + +std::vector FetchAllMonadNodeByNode(const AnfNodePtr &node) { + if (!node->isa()) { + return {}; + } + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad)) { + return {node}; + } + + std::vector results; + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { + const auto &cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + for (const auto &input : cnode->inputs()) { + MS_EXCEPTION_IF_NULL(input); + const auto &result = FetchAllMonadNodeByNode(input); + results.insert(results.end(), result.begin(), result.end()); + } + } + return results; +} } // namespace ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info, @@ -557,7 +579,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr if (!parser->IsNeedStackControlNode(switch_actor->node_)) { for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) { LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i}, - parser); + graph_compiler_info); } } else { // If the control actor has a corresponding stack actor, the input should be linked to the stack actor. @@ -566,7 +588,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr MS_EXCEPTION_IF_NULL(actor); auto stack_actor = dynamic_cast(actor); MS_EXCEPTION_IF_NULL(stack_actor); - LinkArrowFromStackActor(stack_actor, switch_actor.get(), parser); + LinkArrowFromStackActor(stack_actor, switch_actor.get(), graph_compiler_info); } } @@ -575,7 +597,7 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr if (!parser->IsNeedStackControlNode(gather_actor->node_)) { for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) { LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i}, - parser); + graph_compiler_info); } } else { // If the control actor has a corresponding stack actor, the input should be linked to the stack actor. @@ -584,13 +606,13 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr MS_EXCEPTION_IF_NULL(actor); auto stack_actor = dynamic_cast(actor); MS_EXCEPTION_IF_NULL(stack_actor); - LinkArrowFromStackActor(stack_actor, gather_actor.get(), parser); + LinkArrowFromStackActor(stack_actor, gather_actor.get(), graph_compiler_info); } } for (auto &entrance_actor : control_actor_set->entrance_actors_) { for (const auto &call_node : entrance_actor->call_nodes_) { - LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, parser); + LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, graph_compiler_info); } } @@ -602,28 +624,31 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr auto actor = FetchActor(stack_actor_name); if (actor == nullptr) { for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) { - LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i}, parser); + LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i}, + graph_compiler_info); } } else { // If the control actor has a corresponding stack actor, the input should be linked to the stack actor. auto stack_actor = dynamic_cast(actor); MS_EXCEPTION_IF_NULL(stack_actor); - LinkArrowFromStackActor(stack_actor, exit_actor.get(), parser); + LinkArrowFromStackActor(stack_actor, exit_actor.get(), graph_compiler_info); } } for (auto &stack_actor : control_actor_set->stack_actors_) { for (size_t i = 0; i < stack_actor->formal_parameters_.size(); ++i) { LinkArrowbyFormalParameter(stack_actor.get(), stack_actor->formal_parameters_[i], {stack_actor->node_, i}, - parser); + graph_compiler_info); } } } void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor, ControlActor *const to_actor, - const ControlNodeParserPtr &parser) { + const GraphCompilerInfo &graph_compiler_info) { MS_EXCEPTION_IF_NULL(stack_actor); MS_EXCEPTION_IF_NULL(to_actor); + MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_); + const auto &parser = graph_compiler_info.control_node_parser_; for (size_t to_index = 0; to_index < to_actor->formal_parameters_.size(); ++to_index) { const auto &formal_parameter = to_actor->formal_parameters_[to_index]; @@ -638,7 +663,7 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor (!AnfAlgo::IsCallNode(from_node)) && (!AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) && to_actor->GetAID().Name().find( parser->FetchGroupNameByKernelGraph(parser->FetchKernelGraphByFrontNode(from_node))) != std::string::npos) { - LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, parser); + LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, graph_compiler_info); continue; } @@ -660,16 +685,18 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_actor, const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index, - const ControlNodeParserPtr &parser) { + const GraphCompilerInfo &graph_compiler_info) { const auto &from_node = from_node_with_index.first; MS_EXCEPTION_IF_NULL(from_node); if (from_node->isa()) { LinkArrowByValueNode(from_node, to_actor, from_node_with_index.second, to_node_with_index.second); } else if (from_node->isa()) { - LinkArrowByParameter(from_node, to_actor, from_node_with_index, to_node_with_index, parser); + LinkArrowByParameter(from_node, to_actor, from_node_with_index, to_node_with_index, + graph_compiler_info.control_node_parser_); } else if (AnfAlgo::IsCallNode(from_node)) { // Link arrow by call node. - LinkArrowByCallNode(from_node, to_actor, from_node_with_index, to_node_with_index, parser); + LinkArrowByCallNode(from_node, to_actor, from_node_with_index, to_node_with_index, + graph_compiler_info.control_node_parser_); } else if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) || AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) { // Link arrow from switch actor. @@ -678,14 +705,7 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act MS_EXCEPTION_IF_NULL(actor); const auto &switch_actor = dynamic_cast(actor); MS_EXCEPTION_IF_NULL(switch_actor); - - const auto &abstract = from_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - if (abstract->isa()) { - LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second); - } else { - LinkDataArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second); - } + LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second); } else if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) { // Link arrow from gather actor const auto &actor_name = GetActorName(from_node); @@ -699,7 +719,7 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act LinkPartialArrow(gather_actor, to_actor, from_node_with_index.second, to_node_with_index.second); } else if (from_node->isa()) { // Link arrow by kernel. - LinkArrowByKernel(from_node, to_actor, from_node_with_index, to_node_with_index, parser); + LinkArrowByKernel(from_node, to_actor, from_node_with_index, to_node_with_index, graph_compiler_info); } } @@ -828,12 +848,13 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlActor *const to_actor, const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index, - const ControlNodeParserPtr &parser) { - MS_EXCEPTION_IF_NULL(parser); + const GraphCompilerInfo &graph_compiler_info) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(to_actor); const auto &from_node = from_node_with_index.first; MS_EXCEPTION_IF_NULL(from_node); + MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_); + const auto &parser = graph_compiler_info.control_node_parser_; const auto &graph = parser->FetchKernelGraphByFrontNode(from_node); MS_EXCEPTION_IF_NULL(graph); const auto &group_name = parser->FetchGroupNameByKernelGraph(graph); @@ -844,7 +865,7 @@ void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlAc const auto &kernel_with_index = parser->FetchBackendNodeByFrontNode(from_node_with_index); MS_EXCEPTION_IF_NULL(kernel_with_index.first); auto type = FetchKernelTransformType(kernel_with_index.first, graph, {}); - auto from_actor = FetchActor(type, "", kernel_with_index.first, graph); + auto from_actor = FetchActor(type, graph_compiler_info.name_, kernel_with_index.first, graph); MS_EXCEPTION_IF_NULL(from_actor); if (!AnfAlgo::OutputAddrExist(kernel_with_index.first, kernel_with_index.second, false)) { MS_LOG(EXCEPTION) << "Invalid output index:" << kernel_with_index.second @@ -926,10 +947,10 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor const auto &inputs = cnode->inputs(); for (const auto &input : inputs) { MS_EXCEPTION_IF_NULL(input); - if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimUpdateState) || - AnfAlgo::CheckPrimitiveType(input, prim::kPrimDepend) || - AnfAlgo::CheckPrimitiveType(input, prim::kPrimLoad)) { - LinkControlArrowByAutoMonad(from_actor, input, parser); + std::vector monad_nodes = FetchAllMonadNodeByNode(input); + for (const auto &monad_node : monad_nodes) { + MS_EXCEPTION_IF_NULL(monad_node); + LinkControlArrowByAutoMonad(from_actor, monad_node, parser); } } } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.h b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.h index d68e9574677..577ab5de429 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.h @@ -63,20 +63,22 @@ class ControlNodeScheduler { // Link all arrows between control actors. void LinkArrowForControlActor(ControlActorSet *const control_actor_set, const GraphCompilerInfo &graph_compiler_info); void LinkArrowbyFormalParameter(ControlActor *const to_actor, const KernelWithIndex &from_node_with_index, - const KernelWithIndex &to_node_with_index, const ControlNodeParserPtr &parser); + const KernelWithIndex &to_node_with_index, + const GraphCompilerInfo &graph_compiler_info); void LinkArrowByCallNode(const AnfNodePtr &call_node, ControlActor *const to_actor, const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index, const ControlNodeParserPtr &parser); void LinkArrowByKernel(const AnfNodePtr &kernel, ControlActor *const to_actor, const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index, - const ControlNodeParserPtr &parser); + const GraphCompilerInfo &graph_compiler_info); void LinkArrowByParameter(const AnfNodePtr ¶meter, ControlActor *const to_actor, const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index, const ControlNodeParserPtr &parser); void LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, size_t from_index, size_t to_index); // Link arrow from stack actor to control actor. - void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor, const ControlNodeParserPtr &parser); + void LinkArrowFromStackActor(StackActor *stack_actor, ControlActor *to_actor, + const GraphCompilerInfo &graph_compiler_info); // Link data arrow between control actor and actor in frame, including kernel actor, output actor, data source actor. void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);