From b5a9588d10d9a0032d4b275c7dddce93962c7309 Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Tue, 9 Nov 2021 12:58:19 +0800 Subject: [PATCH] delete useless interface in control node parser --- .../runtime/framework/actor/actor_common.cc | 10 - .../runtime/framework/actor/actor_common.h | 4 - .../runtime/framework/actor/actor_dump.cc | 170 +- .../runtime/framework/actor/actor_dump.h | 8 +- .../actor/control_flow/exit_actor.cc | 4 + .../actor/control_flow/gather_actor.cc | 1 + .../framework/actor/data_prepare_actor.cc | 14 +- .../framework/actor/data_prepare_actor.h | 8 +- .../runtime/framework/control_node_parser.cc | 1801 +++++------------ .../runtime/framework/control_node_parser.h | 260 +-- .../framework/control_node_scheduler.cc | 5 +- .../runtime/framework/graph_scheduler.cc | 59 +- mindspore/ccsrc/vm/backend.cc | 17 +- mindspore/ccsrc/vm/backend.h | 3 + 14 files changed, 879 insertions(+), 1485 deletions(-) diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc index 09942783166..2230a249d0c 100644 --- a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc +++ b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc @@ -132,16 +132,6 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node) { return false; } -bool IsGatherActor(const AnfNodePtr &front_node, - const std::unordered_map *> &actor_name_to_actor) { - MS_EXCEPTION_IF_NULL(front_node); - if (front_node->isa() && (!AnfAlgo::IsParameterWeight(front_node->cast())) && - (front_node->func_graph() != nullptr) && (actor_name_to_actor.count(front_node->func_graph()->ToString()) > 0)) { - return true; - } - return false; -} - bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) { MS_EXCEPTION_IF_NULL(dst_device_tensor); MS_EXCEPTION_IF_NULL(src_device_tensor); diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_common.h b/mindspore/ccsrc/runtime/framework/actor/actor_common.h index 
d4a1ce868e1..d555e6e8491 100644 --- a/mindspore/ccsrc/runtime/framework/actor/actor_common.h +++ b/mindspore/ccsrc/runtime/framework/actor/actor_common.h @@ -133,10 +133,6 @@ bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph); // Judge whether the device tensor of the node is persistent or not. bool IsPersistentDeviceTensor(const AnfNodePtr &node); -// Judge whether the front node is in a gather actor. -bool IsGatherActor(const AnfNodePtr &front_node, - const std::unordered_map *> &actor_name_to_actor); - // Copy data from src_device_tensor to dst_device_tensor. bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor); diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_dump.cc b/mindspore/ccsrc/runtime/framework/actor/actor_dump.cc index 25740df2a6e..0e78fa1ebfd 100644 --- a/mindspore/ccsrc/runtime/framework/actor/actor_dump.cc +++ b/mindspore/ccsrc/runtime/framework/actor/actor_dump.cc @@ -177,14 +177,157 @@ void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) { ofs << "\n"; } -void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) { - MS_EXCEPTION_IF_NULL(actor); - ofs << "\tactor_name:" << actor->GetAID().Name() << '\n'; +void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) { + const auto &output_data_arrows = actor->output_data_arrows(); + if (output_data_arrows.size() > 0) { + ofs << "\t\t\toutput_data_arrows:" << output_data_arrows.size() << "\n "; + for (const auto &data_arrow : output_data_arrows) { + MS_EXCEPTION_IF_NULL(data_arrow); + ofs << "\t\t\t\tfrom_output_index:" << data_arrow->from_output_index_ + << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_ + << "\n"; + } + } + + const auto &output_control_arrows = actor->output_control_arrows(); + if (output_control_arrows.size() > 0) { + ofs << "\t\t\toutput_control_arrows:" << output_control_arrows.size() << "\n "; + for (const auto &aid : 
output_control_arrows) { + ofs << "\t\t\t\tto_actor_name:" << aid.Name() << "\n"; + } + } + + const auto &output_partial_arrows = actor->output_partial_arrows(); + if (output_partial_arrows.size() > 0) { + ofs << "\t\t\toutput_partial_arrows:" << output_partial_arrows.size() << "\n "; + for (const auto &partial_arrow : output_partial_arrows) { + MS_EXCEPTION_IF_NULL(partial_arrow); + ofs << "\t\t\t\tfrom_output_index:" << partial_arrow->from_output_index_ + << "\tto_actor_name:" << partial_arrow->to_op_id_.Name() + << "\tto_input_index:" << partial_arrow->to_input_index_ << "\n"; + } + } + + const auto &output_branch_id_arrows = actor->output_branch_id_arrows(); + if (output_branch_id_arrows.size() > 0) { + ofs << "\t\t\toutput_branch_id_arrows:" << output_branch_id_arrows.size() << "\n "; + for (const auto &aid : output_branch_id_arrows) { + ofs << "\t\t\t\tto_actor_name:" << aid.Name() << "\n"; + } + } } void DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) { MS_EXCEPTION_IF_NULL(actor); - ofs << "\tactor_name:" << actor->GetAID().Name() << '\n'; + ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n'; + DumpControlActor(actor, ofs); +} + +void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) { + MS_EXCEPTION_IF_NULL(actor); + ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n'; + DumpControlActor(actor, ofs); + + const auto &output_data_with_branch_id_arrows = actor->output_data_with_branch_id_arrows(); + if (output_data_with_branch_id_arrows.size() > 0) { + ofs << "\t\t\toutput_data_with_branch_id_arrows:" << output_data_with_branch_id_arrows.size() << "\n "; + for (const auto &output_data_with_branch_id_arrow : output_data_with_branch_id_arrows) { + ofs << "\t\t\t\tbranch funcgraph:" << output_data_with_branch_id_arrow.first->ToString() << "\n"; + for (const auto &arrow : output_data_with_branch_id_arrow.second) { + ofs << "\t\t\t\t\tto actor:" << arrow << "\n"; + } + } + } +} + +void DumpEntranceActor(const EntranceActor
*actor, std::ofstream &ofs) { + MS_EXCEPTION_IF_NULL(actor); + ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n'; + DumpControlActor(actor, ofs); +} + +void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) { + MS_EXCEPTION_IF_NULL(actor); + ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n'; + DumpControlActor(actor, ofs); + + const auto &output_branch_data_arrows = actor->output_branch_data_arrows(); + if (output_branch_data_arrows.size() > 0) { + ofs << "\t\t\toutput_branch_data_arrows:" << output_branch_data_arrows.size() << "\n "; + for (const auto &output_branch_data_arrow : output_branch_data_arrows) { + ofs << "\t\t\t\tbranch id:" << output_branch_data_arrow.first << "\n"; + for (const auto &arrow : output_branch_data_arrow.second) { + MS_EXCEPTION_IF_NULL(arrow); + ofs << "\t\t\t\t\tfrom_output_index:" << arrow->from_output_index_ + << "\tto_actor_name:" << arrow->to_op_id_.Name() << "\tto_input_index:" << arrow->to_input_index_ << "\n"; + } + } + } + + const auto &output_branch_partial_arrows = actor->output_branch_partial_arrows(); + if (output_branch_partial_arrows.size() > 0) { + ofs << "\t\t\toutput_branch_partial_arrows:" << output_branch_partial_arrows.size() << "\n "; + for (const auto &output_branch_partial_arrow : output_branch_partial_arrows) { + ofs << "\t\t\t\tbranch id:" << output_branch_partial_arrow.first << "\n"; + for (const auto &arrow : output_branch_partial_arrow.second) { + MS_EXCEPTION_IF_NULL(arrow); + ofs << "\t\t\t\t\tfrom_output_index:" << arrow->from_output_index_ + << "\tto_actor_name:" << arrow->to_op_id_.Name() << "\tto_input_index:" << arrow->to_input_index_ << "\n"; + } + } + } + + const auto &output_branch_control_arrows = actor->output_branch_control_arrows(); + if (output_branch_control_arrows.size() > 0) { + ofs << "\t\t\toutput_branch_control_arrows:" << output_branch_control_arrows.size() << "\n "; + for (const auto &output_branch_control_arrow : output_branch_control_arrows) { + ofs << 
"\t\t\t\tbranch id:" << output_branch_control_arrow.first << "\n"; + for (const auto &arrow : output_branch_control_arrow.second) { + ofs << "\t\t\t\t\tto actor:" << arrow << "\n"; + } + } + } +} + +void DumpStackActor(const StackActor *actor, std::ofstream &ofs) { + MS_EXCEPTION_IF_NULL(actor); + ofs << "\t\tactor_name:" << actor->GetAID().Name() << '\n'; + DumpControlActor(actor, ofs); +} + +void DumpSwitchActors(const std::vector &actors, std::ofstream &ofs) { + ofs << "\n\n\t[Switch actors:" << actors.size() << "]\n"; + for (const auto &switch_actor : actors) { + DumpSwitchActor(switch_actor.get(), ofs); + } +} + +void DumpGatherActors(const std::vector &actors, std::ofstream &ofs) { + ofs << "\n\n\t[Gather actors:" << actors.size() << "]\n"; + for (const auto &gather_actor : actors) { + DumpGatherActor(gather_actor.get(), ofs); + } +} + +void DumpEntranceActors(const std::vector &actors, std::ofstream &ofs) { + ofs << "\n\n\t[Entrance actors:" << actors.size() << "]\n"; + for (const auto &entrance_actor : actors) { + DumpEntranceActor(entrance_actor.get(), ofs); + } +} + +void DumpExitActors(const std::vector &actors, std::ofstream &ofs) { + ofs << "\n\n\t[Exit actors:" << actors.size() << "]\n"; + for (const auto &exit_actor : actors) { + DumpExitActor(exit_actor.get(), ofs); + } +} + +void DumpStackActors(const std::vector &actors, std::ofstream &ofs) { + ofs << "\n\n\t[Stack actors:" << actors.size() << "]\n"; + for (const auto &stack_actor : actors) { + DumpStackActor(stack_actor.get(), ofs); + } } } // namespace @@ -281,18 +424,17 @@ void DumpCopyActors(const std::vector &actors, std::ofstream &ofs) } } -void DumpGatherActors(const std::vector &actors, std::ofstream &ofs) { - ofs << "\n\n[Gather actors:" << actors.size() << "]\n"; - for (const auto &gather_actor : actors) { - DumpGatherActor(gather_actor.get(), ofs); +void DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs) { + ofs << "\n\n[Control actors]\n"; + if 
(control_actor_set == nullptr) { + return; } -} -void DumpSwitchActors(const std::vector &actors, std::ofstream &ofs) { - ofs << "\n\n[Switch actors:" << actors.size() << "]\n"; - for (const auto &switch_actor : actors) { - DumpSwitchActor(switch_actor.get(), ofs); - } + DumpSwitchActors(control_actor_set->switch_actors_, ofs); + DumpGatherActors(control_actor_set->gather_actors_, ofs); + DumpEntranceActors(control_actor_set->entrance_actors_, ofs); + DumpExitActors(control_actor_set->exit_actors_, ofs); + DumpStackActors(control_actor_set->stack_actors_, ofs); } } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_dump.h b/mindspore/ccsrc/runtime/framework/actor/actor_dump.h index dd5eda389c9..66c465370bf 100644 --- a/mindspore/ccsrc/runtime/framework/actor/actor_dump.h +++ b/mindspore/ccsrc/runtime/framework/actor/actor_dump.h @@ -30,8 +30,13 @@ #include "runtime/framework/actor/super_kernel_actor.h" #include "runtime/framework/actor/output_actor.h" #include "runtime/framework/actor/copy_actor.h" +#include "runtime/framework/actor/control_flow/control_actor.h" #include "runtime/framework/actor/control_flow/switch_actor.h" #include "runtime/framework/actor/control_flow/gather_actor.h" +#include "runtime/framework/actor/control_flow/entrance_actor.h" +#include "runtime/framework/actor/control_flow/exit_actor.h" +#include "runtime/framework/actor/control_flow/stack_actor.h" +#include "runtime/framework/control_node_scheduler.h" namespace mindspore { namespace runtime { @@ -43,8 +48,7 @@ void DumpKernelActors(const std::vector &actors, std::ofstream & void DumpSuperKernelActors(const std::vector &actors, std::ofstream &ofs); void DumpNoInputKernelActors(const std::vector &actors, std::ofstream &ofs); void DumpCopyActors(const std::vector &actors, std::ofstream &ofs); -void DumpGatherActors(const std::vector &actors, std::ofstream &ofs); -void DumpSwitchActors(const std::vector &actors, std::ofstream &ofs); +void 
DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs); } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc index a6dfdf99d1b..818f803d1d5 100644 --- a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc @@ -75,6 +75,10 @@ void ExitActor::SendOutput(OpContext *const context) { } void ExitActor::CopyDeviceAddress() { + // If node is not empty, it is the exit of funcgraph, no need to create device address. + if (node_ != nullptr) { + return; + } std::vector new_device_tensors; for (size_t i = 0; i < input_device_tensors_.size(); ++i) { auto input_device_tensor = input_device_tensors_[i]; diff --git a/mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.cc b/mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.cc index 99533f1a8cd..90a0b6be99a 100644 --- a/mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.cc @@ -28,6 +28,7 @@ void GatherActor::FetchInput(OpContext *const context) { ControlActor::FetchInput(context); output_partial_ = input_partials_[0]; + MS_EXCEPTION_IF_NULL(output_partial_.first); // Put other real parameter in partial. 
for (const auto &device_tensor : input_device_tensors_) { diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc index 63772a447d6..d971f433f1f 100644 --- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc @@ -459,10 +459,10 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, } // In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor. -void DataPrepareActor::PrepareDataForControlWeightNode( - const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, const DeviceContext *device_context, - const std::unordered_map> &host_parameter_to_weights, - OpContext *const context) { +void DataPrepareActor::PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node, + const TensorPtr &tensor, const DeviceContext *device_context, + const HostParameterToWeight &host_parameter_to_weights, + OpContext *const context) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(front_node); MS_EXCEPTION_IF_NULL(tensor); @@ -516,12 +516,12 @@ void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeP if (IsPersistentDeviceTensor(input_node)) { const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters(); const auto &iter = front_to_backend_parameters.find(input_node); - if (iter == front_to_backend_parameters.end()) { + if (iter == front_to_backend_parameters.end() || iter->second.empty()) { MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << AnfAlgo::GetNodeDebugString(input_node); } - const auto &node_with_context = iter->second; - PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second, + const auto &node_with_context = iter->second.begin(); + 
PrepareDataForControlWeightNode(node_with_context->first, input_node, input_tensor, node_with_context->second, control_node_parser->host_parameter_to_weights(), context); } } diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h index 428089c19a8..f3703daaa8a 100644 --- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h @@ -93,10 +93,10 @@ class DataPrepareActor : public DebugAwareActor { std::vector *const host_tensors, OpContext *const context); // In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor. - void PrepareDataForControlWeightNode( - const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, const DeviceContext *device_context, - const std::unordered_map> &host_parameter_to_weights, - OpContext *const context); + void PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, + const DeviceContext *device_context, + const HostParameterToWeight &host_parameter_to_weights, + OpContext *const context); const GraphCompilerInfo *graph_compiler_info_; GraphExecutionStrategy strategy_; diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index e82adc3c671..cd20f277999 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -15,23 +15,72 @@ */ #include "runtime/framework/control_node_parser.h" -#include "runtime/framework/actor/control_flow/switch_actor.h" -#include "runtime/framework/actor/control_flow/gather_actor.h" #include "abstract/utils.h" #include "ir/tensor.h" namespace mindspore { namespace runtime { namespace { -using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; +// Get all the real 
parameters corresponding to node. +void FetchRealParameterByNode(const KernelWithIndex &node, std::set *real_parameters, + std::set *invalid_call_nodes) { + const auto &node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second); + if (node_with_index.first->isa() || node_with_index.first->isa()) { + // If node is a valuenode or parameter, the real parameter is itself. + real_parameters->emplace(node_with_index); + } else if (AnfAlgo::IsCallNode(node_with_index.first)) { + // If node is a call node, the real parameters are the outputs of funcgraph the node called. + if (invalid_call_nodes->find(node_with_index) != invalid_call_nodes->end()) { + return; + } + invalid_call_nodes->emplace(node_with_index); + const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(node_with_index.first); + for (const auto &func_graph : func_graphs) { + MS_EXCEPTION_IF_NULL(func_graph); + FetchRealParameterByNode({func_graph->output(), node_with_index.second}, real_parameters, invalid_call_nodes); + } + } else if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimMakeTuple)) { + // If node is a maketuple node, the real parameters are its total inputs. + const auto &make_tuple_cnode = node_with_index.first->cast(); + const auto &make_tuple_inputs = make_tuple_cnode->inputs(); + if (make_tuple_inputs.size() <= node_with_index.second) { + MS_LOG(EXCEPTION) << "Invalid index:" << node_with_index.second + << "for tuple node:" << node_with_index.first->DebugString(); + } + } else if (AnfAlgo::CheckPrimitiveType(node.first, prim::kPrimSwitch)) { + // If node is a switch node, the real parameters are its both true and false branches. 
+ const auto cnode = node_with_index.first->cast(); + const auto inputs = cnode->inputs(); + for (size_t i = kSwitchTrueBranchPos; i < inputs.size(); ++i) { + FetchRealParameterByNode({inputs[i], 0}, real_parameters, invalid_call_nodes); + } + } else if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimSwitchLayer)) { + // If node is a switchlayer node, the real parameters are its total branches. + const auto &switch_layer_cnode = node_with_index.first->cast(); + const auto &switch_layer_inputs = switch_layer_cnode->inputs(); + if (switch_layer_inputs.size() != kSwitchLayerInputNum || + (!AnfAlgo::CheckPrimitiveType(switch_layer_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple))) { + MS_LOG(EXCEPTION) << "Invalid switch layer node:" << switch_layer_cnode->DebugString(); + } + const auto &make_tuple_cnode = switch_layer_inputs[kSwitchLayerBranchPos]->cast(); + const auto &make_tuple_inputs = make_tuple_cnode->inputs(); + for (size_t i = kSwitchTrueBranchPos; i < make_tuple_inputs.size(); ++i) { + FetchRealParameterByNode({make_tuple_inputs[i], 0}, real_parameters, invalid_call_nodes); + } + } else { + // If node is a kernel, the real parameter is itself. + real_parameters->emplace(node_with_index); + } +} + // Fetch all the weight parameters related to node. It runs like this: // if we have a map like {{a, {b, c}}, {b, {d, e}}}, final we will get {{a, {b, c, d, e}}, {b, {c, d}}}.
-void FetchWeightbyHostParameter(const AnfNodePtr &node, std::vector *dest_nodes, - const std::unordered_map> &front_to_front_weight) { - if (find((*dest_nodes).begin(), (*dest_nodes).end(), node) != (*dest_nodes).end()) { +void FetchWeightbyHostParameter(const AnfNodePtr &node, std::set *dest_nodes, + const HostParameterToWeight &front_to_front_weight) { + if (dest_nodes->find(node) != dest_nodes->end()) { return; } - (void)((*dest_nodes).emplace_back(node)); + dest_nodes->emplace(node); if (front_to_front_weight.find(node) == front_to_front_weight.end()) { return; } @@ -42,364 +91,6 @@ void FetchWeightbyHostParameter(const AnfNodePtr &node, std::vector } } -// Check whether the input is a valid parameter. -bool CheckValidFuncGraphInput(const AnfNodePtr &node) { - if (HasAbstractMonad(node)) { - return false; - } else if (node->isa()) { - return !HasAbstractRef(node); - } - return true; -} - -// Get the funcgraph in partial node. -FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - const auto &partial_inputs = node->cast()->inputs(); - return GetValueNode(partial_inputs[1]); -} - -// Get the relationship between funcgraph and parameters in the switch node. 
-void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParameter *graph_to_real_parameters) { - const auto &switch_cnode = switch_node->cast(); - const auto &switch_inputs = switch_cnode->inputs(); - if (switch_inputs.size() != kSwitchInputNum) { - MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(switch_node); - } - - for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) { - const auto &partial_node = switch_inputs[i]; - if (IsValueNode(partial_node)) { - continue; - } - const auto &func_graph = GetFuncGraphFromPartial(partial_node); - std::vector parameters; - const auto &partial_inputs = partial_node->cast()->inputs(); - for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) { - if (CheckValidFuncGraphInput(partial_inputs[j])) { - (void)parameters.emplace_back(partial_inputs[j]); - } - } - (void)((*graph_to_real_parameters)[func_graph].emplace_back(parameters)); - } -} - -// Get the corresponding relationship between funcgraph and parameters in the switch layer node. -void FetchParameterBySwitchLayerNode(const AnfNodePtr &switch_layer_node, const std::vector &call_inputs, - FuncGraphToParameter *graph_to_real_parameters) { - const auto &switch_layer_cnode = switch_layer_node->cast(); - const auto &switch_layer_inputs = switch_layer_cnode->inputs(); - - if (switch_layer_inputs.size() != kSwitchLayerInputNum) { - MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(switch_layer_node); - } - - auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast()->inputs(); - - // Get the parameter corresponding to each funcgraph in make tuple. - for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { - if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) { - // Tuple branch is a partial node. 
- const auto &func_graph = GetFuncGraphFromPartial(tuple_inputs[i]); - std::vector parameters; - const auto &partial_inputs = tuple_inputs[i]->cast()->inputs(); - - // Get inputs in partial node. - for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) { - if (CheckValidFuncGraphInput(partial_inputs[j])) { - (void)parameters.emplace_back(partial_inputs[j]); - } - } - - // Get inputs in call node. - for (size_t j = kCallInputStartPos; j < call_inputs.size(); ++j) { - if (CheckValidFuncGraphInput(call_inputs[j])) { - (void)parameters.emplace_back(call_inputs[j]); - } - } - (void)((*graph_to_real_parameters)[func_graph].emplace_back(parameters)); - } else if (tuple_inputs[i]->isa() && IsValueNode(tuple_inputs[i])) { - // Tuple branch is a call node. - const auto &func_graph = GetValueNode(tuple_inputs[i]); - std::vector parameters; - - // Get inputs in call node. - for (size_t j = kCallInputStartPos; j < call_inputs.size(); ++j) { - if (CheckValidFuncGraphInput(call_inputs[j])) { - (void)parameters.emplace_back(call_inputs[j]); - } - } - - (void)(*graph_to_real_parameters)[func_graph].emplace_back(parameters); - } - } -} - -// Create a device tensor for the front node. -// Get the output format and select kernel build info from the backend node corresponding to the front node to -// create the device address. 
-void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodePtr &backend_node, - const DeviceContext *device_context) { - MS_EXCEPTION_IF_NULL(device_context); - - const auto &node_value = front_node->cast()->value(); - if (!node_value->isa()) { - return; - } - - size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(backend_node, 0); - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(backend_node, 0); - if (output_type_id == kTypeUnknown) { - output_type_id = AnfAlgo::GetOutputInferDataType(backend_node, 0); - } - - if (front_node->kernel_info() == nullptr) { - front_node->set_kernel_info(std::make_shared()); - } - - // Get the select kernel build info. - auto kernel_info = dynamic_cast(backend_node->kernel_info()); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->GetMutableSelectKernelBuildInfo(); - MS_EXCEPTION_IF_NULL(build_info); - AnfAlgo::SetSelectKernelBuildInfo(build_info, front_node.get()); - - // Create device tensor. - std::string output_format = AnfAlgo::GetOutputFormat(backend_node, 0); - device::DeviceAddressPtr address = - device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id); - MS_EXCEPTION_IF_NULL(address); - MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(front_node) << " addr:" << address; - AnfAlgo::SetOutputAddr(address, 0, front_node.get()); -} - -// Create a device tensor for front parameter. -// When the condition input of the switch and switchlayer or the output of a subgraph is a parameter, there is no -// corresponding backend node for this parameter, so a device tensor needs to be created for it. 
-void CreateDeviceTensorForFrontParameter(const AnfNodePtr &node, const DeviceContext *device_context) { - MS_EXCEPTION_IF_NULL(device_context); - - TypeId type_id = AnfAlgo::GetOutputInferDataType(node, 0); - - if (node->kernel_info() == nullptr) { - auto kernel_info = std::make_shared(); - std::shared_ptr builder = std::make_shared(); - builder->SetOutputsFormat({kOpFormat_DEFAULT}); - builder->SetOutputsDeviceType({type_id}); - kernel_info->set_select_kernel_build_info(builder->Build()); - node->set_kernel_info(kernel_info); - } - size_t size = AnfAlgo::GetOutputTensorMemSize(node, 0); - - // Create device tensor. - device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id); - MS_EXCEPTION_IF_NULL(address); - MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(node) << " addr:" << address; - AnfAlgo::SetOutputAddr(address, 0, node.get()); -} - -// Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding -// backend parameter, then recursively find the backend parameters of other front parameters corresponding to the -// front_node. -std::pair FetchBackendNodeByFrontNode( - const AnfNodePtr &front_node, - const std::unordered_map> &real_to_formal_front_parameters, - const std::unordered_map> &formal_to_real_front_parameters, - const std::unordered_map> &front_to_backend_parameter, - std::set *invalid_node) { - // Check whether the front_node has been looked for. 
- if ((*invalid_node).find(front_node) != (*invalid_node).end()) { - return std::pair(); - } - (void)(*invalid_node).insert(front_node); - - const auto front_to_backend_iter = front_to_backend_parameter.find(front_node); - if (front_to_backend_iter != front_to_backend_parameter.end()) { - return front_to_backend_iter->second; - } - - const auto &real_to_formal_iter = real_to_formal_front_parameters.find(front_node); - if (real_to_formal_iter == real_to_formal_front_parameters.end()) { - return std::pair(); - } - for (const auto &next_node : real_to_formal_iter->second) { - auto banckend_node = - FetchBackendNodeByFrontNode(next_node, real_to_formal_front_parameters, formal_to_real_front_parameters, - front_to_backend_parameter, invalid_node); - if (banckend_node.first != nullptr) { - return banckend_node; - } - } - - const auto &formal_to_real_iter = formal_to_real_front_parameters.find(front_node); - if (formal_to_real_iter == formal_to_real_front_parameters.end()) { - return std::pair(); - } - for (const auto &next_node : formal_to_real_iter->second) { - auto banckend_node = - FetchBackendNodeByFrontNode(next_node, real_to_formal_front_parameters, formal_to_real_front_parameters, - front_to_backend_parameter, invalid_node); - if (banckend_node.first != nullptr) { - return banckend_node; - } - } - return std::pair(); -} - -// Fetch all backend input nodes by parameter for gather actor. -std::vector FetchInputNodeByParameter(const AnfNodePtr ¶meter, - const std::vector &host_ds_parameters, - std::set *invalid_inputs, - const FuncGraphToParameter &graph_to_real_parameters) { - std::vector input_nodes; - - // If the node has been collected, skip it. - if (find((*invalid_inputs).begin(), (*invalid_inputs).end(), parameter) != (*invalid_inputs).end()) { - return input_nodes; - } - - // Record the node which has been collected. - (void)(*invalid_inputs).insert(parameter); - - // If the parameter node is a parameter of host data source actor, return it. 
- if (find(host_ds_parameters.begin(), host_ds_parameters.end(), parameter) != host_ds_parameters.end()) { - (void)input_nodes.emplace_back(parameter); - return input_nodes; - } - - // Check the parameter which send to its funcgraph. - const auto &func_graph = parameter->func_graph(); - if (graph_to_real_parameters.find(func_graph) == graph_to_real_parameters.end()) { - return input_nodes; - } - - std::vector self_inputs; - for (const auto &input : func_graph->get_inputs()) { - // Monad input need not send to funcgraph. - if (HasAbstractMonad(input) || HasAbstractRef(input)) { - continue; - } - (void)self_inputs.emplace_back(input); - } - - const auto iter = find(self_inputs.begin(), self_inputs.end(), parameter); - if (iter == self_inputs.end()) { - MS_LOG(EXCEPTION) << "Cannot find parameter node:" << AnfAlgo::GetNodeDebugString(parameter); - } - size_t pos = iter - self_inputs.begin(); - - for (const auto parameters : graph_to_real_parameters.at(func_graph)) { - if (parameters.size() != self_inputs.size()) { - MS_LOG(EXCEPTION) << "Invalid input num:" << parameters.size() << " and:" << self_inputs.size() - << " for func_graph:" << func_graph->ToString(); - } - const auto input = parameters[pos]; - if (input->isa()) { - (void)input_nodes.emplace_back(input); - } else if (input->isa()) { - // If input is a parameter, you need to find its input recursively. - auto inputs = FetchInputNodeByParameter(input, host_ds_parameters, invalid_inputs, graph_to_real_parameters); - (void)input_nodes.insert(input_nodes.end(), inputs.begin(), inputs.end()); - } - } - return input_nodes; -} - -// Find the output of the funcgraph, if the output is a call node, return the output of the funcgraph -// called by the call node. 
-std::vector FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std::vector *call_nodes) { - std::vector outputs; - const auto &output = func_graph->output(); - const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0, false, {prim::kPrimTupleGetItem}); - if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) { - return outputs; - } - if (!AnfAlgo::IsCallNode(real_output.first)) { - outputs.push_back(real_output.first); - return outputs; - } - - (*call_nodes).push_back(real_output.first); - std::vector func_graphs = FetchFuncGraphbyCallNode(real_output.first); - for (const auto &graph : func_graphs) { - auto single_outputs = FetchFuncGraphOutput(graph, call_nodes); - (void)outputs.insert(outputs.end(), single_outputs.begin(), single_outputs.end()); - } - return outputs; -} -std::vector FetchOutputBySwitchNode(const AnfNodePtr &switch_node, std::set *call_nodes, - std::set *switch_nodes); - -// Recursive interface, get all possible output nodes of call node. 
-std::vector FetchOutputByCallNode(const AnfNodePtr &call_node, std::set *call_nodes, - std::set *switch_nodes) { - std::vector outputs; - if ((*call_nodes).find(call_node) != (*call_nodes).end()) { - return outputs; - } - (void)((*call_nodes).insert(call_node)); - - const auto func_graphs = FetchFuncGraphbyCallNode(call_node); - - for (const auto func_graph : func_graphs) { - std::vector sub_call_nodes; - const std::vector graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes); - for (const auto &graph_output : graph_outputs) { - if (graph_output->isa()) { - outputs.push_back(graph_output); - } else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) { - const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes); - (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); - } else if (AnfAlgo::IsCallNode(graph_output)) { - const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes); - (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); - } else if (graph_output->isa()) { - (void)outputs.emplace_back(graph_output); - } else if (graph_output->isa()) { - outputs.push_back(graph_output); - } else { - MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output); - } - } - } - - return outputs; -} - -// Recursive interface, get all possible output nodes of switch node. 
-std::vector FetchOutputBySwitchNode(const AnfNodePtr &switch_node, std::set *call_nodes, - std::set *switch_nodes) { - std::vector outputs; - if ((*switch_nodes).find(switch_node) != (*switch_nodes).end()) { - return outputs; - } - (void)((*switch_nodes).insert(switch_node)); - - if (!switch_node->isa()) { - MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(switch_node); - } - const auto &inputs = switch_node->cast()->inputs(); - if (inputs.size() != kSwitchInputNum) { - MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(switch_node); - } - - for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) { - if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimPartial)) { - continue; - } else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) { - const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes); - (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); - } else if (AnfAlgo::IsCallNode(inputs[i])) { - const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes); - (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); - } else { - (void)outputs.emplace_back(inputs[i]); - } - } - - return outputs; -} - // Recursive interface, get the real kernel that UpdateState node depends on. AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) { if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) { @@ -414,98 +105,81 @@ AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) { return node; } -// Fetch all parameters in control node of root funcgraph. -std::vector FetchParameterByControlNode(const std::vector &control_nodes) { - std::vector parameters; +// Topologically sort all funcgraphs according to the function call relationship. 
+std::vector TopoSortForFuncGraph(const FuncGraphPtr &root, FuncGraphCallRelation *edges) { + MS_EXCEPTION_IF_NULL(root->manager()); + std::set nodes; + nodes.emplace(root); - for (const auto &control_node : control_nodes) { - CNodePtr cnode = control_node->cast(); - const auto &inputs = cnode->inputs(); - if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) { - break; - } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) { - for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) { - if (inputs[i]->isa()) { - (void)parameters.emplace_back(inputs[i]); + FuncGraphSet subs = root->manager()->func_graphs(); + for (auto sub : subs) { + if (sub != root && root != nullptr) { + nodes.emplace(sub); + } + } + + std::queue que; + for (const auto &node : nodes) { + if (edges->find(node) == edges->end()) { + que.push(node); + } + } + + std::vector result; + while (!que.empty()) { + const auto node = que.front(); + que.pop(); + result.emplace_back(node); + for (auto iter = edges->begin(); iter != edges->end();) { + auto &sub_edges = iter->second; + for (auto sub_iter = sub_edges.begin(); sub_iter != sub_edges.end();) { + if (sub_iter->find(node) != sub_iter->end()) { + sub_edges.erase(sub_iter); + } else { + ++sub_iter; } } - } else if (cnode->input(0)->isa() || IsValueNode(cnode->input(0))) { - for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { - if (inputs[i]->isa()) { - (void)parameters.emplace_back(inputs[i]); - } - } - } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) { - if (inputs.size() != kSwitchInputNum) { - MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node); - } - if (inputs[kSwitchCondPos]->isa()) { - (void)parameters.emplace_back(inputs[kSwitchCondPos]); - } - } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) { - if (inputs.size() != kSwitchLayerInputNum) { - MS_LOG(EXCEPTION) << "Invalid switch node:" << 
AnfAlgo::GetNodeDebugString(control_node); - } - if (inputs[kSwitchLayerCondPos]->isa()) { - (void)parameters.emplace_back(inputs[kSwitchLayerCondPos]); + if (sub_edges.empty()) { + que.push(iter->first); + edges->erase(iter++); + } else { + ++iter; } } } - return parameters; + + return result; } -// Get funcgraph from node, the interface only accepts partial node and funcgraph value node. -FuncGraphPtr FetchFuncGraphInNode(const auto &node) { - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { - const auto &func_graph = GetFuncGraphFromPartial(node); +// Fetch all output of node, and this function will not parse the call node. +std::vector FetchAllOutputWithIndex(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + std::vector result; - if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) { - return FetchFuncGraphInNode(func_graph->output()); - } else if (IsValueNode(func_graph->output())) { - // When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called. - return FetchFuncGraphInNode(func_graph->output()); + const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); + if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimMakeTuple)) { + const auto &cnode = node_with_index.first->cast(); + const auto &inputs = cnode->inputs(); + + for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { + const auto &tmp_list = FetchAllOutputWithIndex(inputs[i]); + result.insert(result.end(), tmp_list.begin(), tmp_list.end()); } - - return func_graph; - } else if (IsValueNode(node)) { - const auto &func_graph = GetValueNode(node); - - if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) { - // When the output of funcgraph is a funcgraph, it needs to return the funcgraph that is finally called. 
- return FetchFuncGraphInNode(func_graph->output()); - } else if (IsValueNode(func_graph->output())) { - // When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called. - return FetchFuncGraphInNode(func_graph->output()); + } else if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimSwitch) || + AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimSwitchLayer)) { + } else if (AnfAlgo::IsCallNode(node)) { + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t i = 0; i < output_num; ++i) { + result.emplace_back(node, i); } - - return func_graph; + } else { + result.emplace_back(node_with_index); } - return nullptr; + return result; } } // namespace -AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set *call_nodes) { - const auto &real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first; - if (!AnfAlgo::IsCallNode(real_node)) { - return real_node; - } - if ((*call_nodes).find(real_node) != (*call_nodes).end()) { - return nullptr; - } - (void)((*call_nodes).insert(real_node)); - - const auto &func_graphs = FetchFuncGraphbyCallNode(real_node); - for (const auto &func_graph : func_graphs) { - const auto &output = FetchRealOutputByCallNode(func_graph->output(), call_nodes); - if (output != nullptr) { - return output; - } - } - return nullptr; -} - -// Return true if the node has Ref abstract. 
bool HasAbstractRef(const AnfNodePtr &node) { if (node == nullptr) { return false; @@ -514,185 +188,6 @@ bool HasAbstractRef(const AnfNodePtr &node) { return (abs != nullptr) && abs->isa(); } -bool IsSubCallNode(const AnfNodePtr &node) { - if (!node->isa()) { - return false; - } - - const auto inputs = node->cast()->inputs(); - if (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { - return false; - } - - const auto &switch_layer_inputs = inputs[0]->cast()->inputs(); - const auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast()->inputs(); - if (tuple_inputs.size() <= kMakeTupleInputStartPos) { - return false; - } - - // Check whether the funcgraph called by the call node returns funcgraph or partial node. - FuncGraphPtr func_graph = nullptr; - if (AnfAlgo::CheckPrimitiveType(tuple_inputs[kMakeTupleInputStartPos], prim::kPrimPartial)) { - const auto &func_graph_node = tuple_inputs[kMakeTupleInputStartPos]->cast()->input(kPartialFuncGraphPos); - func_graph = GetValueNode(func_graph_node); - } else if (tuple_inputs[kMakeTupleInputStartPos]->isa() && - IsValueNode(tuple_inputs[kMakeTupleInputStartPos])) { - func_graph = GetValueNode(tuple_inputs[kMakeTupleInputStartPos]); - } - - const auto &output = func_graph->output(); - return AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial) || - (output->isa() && IsValueNode(output)); -} - -std::vector FetchAllRealInputNodeByParameter(const KernelWithIndex &node) { - std::vector parameters; - const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second); - const auto &real_node = real_node_with_index.first; - if (real_node->isa()) { - if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) { - (void)parameters.emplace_back(real_node_with_index); - } - } else if (HasAbstractMonad(real_node)) { - return parameters; - } else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) { - const auto &inputs = real_node->cast()->inputs(); - for 
(size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { - const auto &sub_parameters = FetchAllRealInputNodeByParameter({inputs[i], 0}); - (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end()); - } - } else { - (void)parameters.emplace_back(real_node_with_index); - } - return parameters; -} - -std::vector FetchFuncGraphbyCallNode(const AnfNodePtr &node) { - std::vector func_graphs; - if (!node->isa()) { - return func_graphs; - } - - const auto &call_inputs = node->cast()->inputs(); - if (call_inputs[0]->isa()) { - const auto &cnode = call_inputs[0]->cast(); - const auto &cnode_inputs = cnode->inputs(); - if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { - for (size_t i = kSwitchTrueBranchPos; i < cnode_inputs.size(); ++i) { - if (IsPrimitiveCNode(cnode_inputs[i], prim::kPrimPartial)) { - (void)func_graphs.emplace_back(GetFuncGraphFromPartial(cnode_inputs[i])); - } else if (IsValueNode(cnode_inputs[i])) { - (void)func_graphs.emplace_back(GetValueNode(cnode_inputs[i])); - } - } - } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer) && - AnfAlgo::CheckPrimitiveType(cnode_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple)) { - const auto &tuple_inputs = cnode_inputs[kSwitchLayerBranchPos]->cast()->inputs(); - - // Fetch all funcgraphs in make tuple node. 
- for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { - const auto func_graph = FetchFuncGraphInNode(tuple_inputs[i]); - if (func_graph != nullptr) { - func_graphs.emplace_back(func_graph); - } - } - } else if (AnfAlgo::IsCallNode(cnode)) { - return FetchFuncGraphbyCallNode(cnode); - } else { - MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString(); - } - } else if (call_inputs[0]->isa() && IsValueNode(call_inputs[0])) { - (void)func_graphs.emplace_back(GetValueNode(call_inputs[0])); - } else { - MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString(); - } - return func_graphs; -} - -size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector *call_nodes) { - if (!AnfAlgo::IsCallNode(node)) { - MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node); - } - if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) { - return 0; - } - (void)((*call_nodes).emplace_back(node)); - - const auto &func_graphs = FetchFuncGraphbyCallNode(node); - for (const auto &func_graph : func_graphs) { - const auto &output = func_graph->output(); - const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0); - - if (AnfAlgo::IsCallNode(real_output.first)) { - size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes); - if (output_num > 0) { - return output_num; - } - } else if (AnfAlgo::CheckPrimitiveType(real_output.first, prim::kPrimMakeTuple)) { - size_t total_num = 0; - const auto &tuple_cnode = real_output.first->cast(); - const auto &inputs = tuple_cnode->inputs(); - size_t i = 1; - for (; i < inputs.size(); ++i) { - if (AnfAlgo::IsCallNode(inputs[i])) { - size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes); - if (call_output_num == 0) { - break; - } - total_num += call_output_num; - } else if (inputs[i]->isa() && inputs[i]->cast()->value()->isa()) { - auto value_tuple = inputs[i]->cast()->value()->cast(); - 
MS_EXCEPTION_IF_NULL(value_tuple); - auto tuple_value = value_tuple->value(); - total_num += tuple_value.size(); - } else if (!HasAbstractMonad(inputs[i])) { - ++total_num; - } - } - if (i == inputs.size()) { - return total_num; - } - } else { - return 1; - } - } - return 0; -} - -FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node) { - auto front_node = GetFrontNodeByBackendNode(node); - // If the front node is nullptr, we can check its inputs. - if (front_node == nullptr) { - if (node->isa()) { - const auto &cnode = node->cast(); - const auto &inputs = cnode->inputs(); - - for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { - const auto &func_graph = FetchFuncGraphByNode(inputs[i]); - if (func_graph != nullptr) { - return func_graph; - } - } - } else { - return nullptr; - } - } - - const auto &func_graph = front_node->func_graph(); - return func_graph; -} - -AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) { - if (backend_node->func_graph() == nullptr) { - return nullptr; - } - auto kernel_graph = dynamic_cast(backend_node->func_graph().get()); - if (kernel_graph == nullptr) { - return nullptr; - } - return kernel_graph->GetFrontAnfByBackendAnf(backend_node); -} - KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) { const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node); if (front_node != nullptr) { @@ -700,65 +195,57 @@ KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const } const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node); if (front_node_with_index.first == nullptr) { - MS_LOG(EXCEPTION) << "Invalid parameter of kernel graph, parameter:" << AnfAlgo::GetNodeDebugString(backend_node); + MS_LOG(EXCEPTION) << "Cannot find front node for backend node:" << backend_node->DebugString() + << " in graph:" << graph->ToString(); } return front_node_with_index; } -FuncGraphPtr GetFuncgraphByBackendNode(const 
AnfNodePtr &backend_node) { - auto front_node = GetFrontNodeByBackendNode(backend_node); - if (front_node == nullptr) { - return nullptr; - } - return front_node->func_graph(); -} - void ControlNodeParser::Parse(const std::vector &control_nodes, const std::vector &graphs, - const std::vector &device_contexts, const FuncGraphPtr &root_graph) { + const std::vector &device_contexts, const FuncGraphPtr &root_graph, + const FuncGraphToKernelGraph &func_graph_to_kernel_graphs) { if (graphs.size() != device_contexts.size()) { MS_LOG(EXCEPTION) << "Graph num is not equal to device context, graph:" << graphs.size() << " device context num:" << device_contexts.size(); } - if (graphs.empty()) { + + if (control_nodes.size() <= 1) { return; } + is_inited_ = true; + root_func_graph_ = root_graph; root_graph_parameters_ = root_graph->parameters(); - RealToFormalNode real_to_formal_front_parameters; - FetchFrontToFrontParameter(control_nodes, &real_to_formal_front_parameters); + func_graph_to_kernel_graphs_ = func_graph_to_kernel_graphs; - RealToFormalNode formal_to_real_front_parameters; - for (const auto real_to_formal_front_parameter : real_to_formal_front_parameters) { - for (const auto formal_parameter : real_to_formal_front_parameter.second) { - (void)formal_to_real_front_parameters[formal_parameter].emplace_back(real_to_formal_front_parameter.first); - } - } + CreateBranchIDForCallNode(control_nodes); - FetchFrontToBackendParameter(graphs, device_contexts, real_to_formal_front_parameters, - formal_to_real_front_parameters); + ParseCallNodeToFuncGraph(control_nodes); - FetchFuncGraphToParameter(control_nodes); + FetchFrontNodeToKernelGraph(graphs); - FetchHostParameterToWeight(real_to_formal_front_parameters); + ParseFormalToRealParameter(control_nodes); + + ParseFrontToBackendParameter(graphs, device_contexts); + + FetchHostParameterToWeight(); FetchCallInputKernelGraph(graphs, device_contexts); - FetchFrontValueNode(control_nodes, graphs, device_contexts); + 
FetchFrontValueNode(); FetchFrontToBackendKernel(graphs, device_contexts); - FetchCallInputKernelGraph(graphs, device_contexts); + ParseDeviceContext(control_nodes, graphs, device_contexts, func_graph_to_kernel_graphs); - control_node_parameters_ = FetchControlNodeParameter(control_nodes, device_contexts[0]); - - FetchFuncGraphCallNum(control_nodes); - - FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters); + FetchControlNodeParameter(control_nodes); FetchAutoMonadNode(control_nodes); + + ParseFirstControlNodeForFuncGraph(control_nodes); } bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node) { @@ -767,28 +254,165 @@ bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, cons return false; } + // If the graph has a call input, all of its inputs in the graph should be Linked to its stack actor. if (IsCallInputKernelGraph(graph)) { return true; } - const AnfNodePtr &front_node = graph->GetFrontAnfByBackendAnf(node); - return (front_node != nullptr && front_node->isa() && - (!AnfAlgo::IsParameterWeight(front_node->cast()))); + // Parameter input should be Linked to its entrance actor. 
+ const auto &front_node = graph->GetFrontAnfByBackendAnf(node); + return front_node != nullptr && front_node->isa() && + (!AnfAlgo::IsParameterWeight(front_node->cast())); } -std::vector ControlNodeParser::GetBackendInputByParameter(const AnfNodePtr ¶meter) { - return formal_to_real_parameters_[parameter]; +void ControlNodeParser::ParseDeviceContext(const std::vector &control_nodes, + const std::vector &kernel_graphs, + const std::vector &device_contexts, + const FuncGraphToKernelGraph &func_graph_to_kernel_graphs) { + if (device_contexts.empty()) { + MS_LOG(EXCEPTION) << "Invalid device contexts."; + } + + ParseDeviceContextForFuncGraph(control_nodes, kernel_graphs, device_contexts, func_graph_to_kernel_graphs); + ParseDeviceContextForControlNode(device_contexts[0]); } -std::set ControlNodeParser::FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output) { - std::set call_nodes; - std::set switch_nodes; - std::set results; - FetchBackendOutputByFrontOutput(front_output, &call_nodes, &switch_nodes, &results); - return results; +void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector &control_nodes, + const std::vector &kernel_graphs, + const std::vector &device_contexts, + const FuncGraphToKernelGraph &func_graph_to_kernel_graphs) { + std::unordered_map kernel_graph_to_device_context; + for (size_t i = 0; i < kernel_graphs.size(); ++i) { + kernel_graph_to_device_context[kernel_graphs[i]] = device_contexts[i]; + } + const auto &default_context = device_contexts[0]; + + // Collect the device context type of the parameter in the kernel graph as the type of the real parameters. 
+ for (const auto &func_graph_to_kernel_graph : func_graph_to_kernel_graphs) { + const auto &func_graph = func_graph_to_kernel_graph.first; + const auto &front_parameters = func_graph->parameters(); + std::vector parameter_device_contexts(front_parameters.size(), nullptr); + std::unordered_map front_parameter_to_device_context; + + for (const auto &kernel_graph : func_graph_to_kernel_graph.second) { + const auto &backend_parameters = kernel_graph->parameters(); + + for (const auto &backend_parameter : backend_parameters) { + const auto &front_parameter = kernel_graph->GetBackendAnfByFrontAnf(backend_parameter); + if (front_parameter != nullptr && front_parameter->isa()) { + front_parameter_to_device_context[front_parameter] = kernel_graph_to_device_context[kernel_graph]; + } + } + } + + for (size_t i = 0; i < front_parameters.size(); ++i) { + const auto &front_parameter = front_parameters[i]; + const auto &iter = front_parameter_to_device_context.find(front_parameter); + if (iter != front_parameter_to_device_context.end()) { + parameter_device_contexts[i] = iter->second; + } + } + func_graph_to_device_contexts_[func_graph] = parameter_device_contexts; + } + + // If there is no kernel in funcgraph, the parameter uses the default device context type. + FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs(); + for (auto sub_graph : sub_graphs) { + if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) { + func_graph_to_device_contexts_[sub_graph] = + std::vector(sub_graph->parameters().size(), default_context); + } + } } -int ControlNodeParser::GetBranchIDByCallNode(const AnfNodePtr &call_node) { +void ControlNodeParser::ParseDeviceContextForControlNode(const DeviceContext *default_context) { + // Collect the call realationship between funcgraphs. 
+ FuncGraphCallRelation func_graph_call_relation; + for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) { + const auto &call_node = call_node_to_func_graphs.first; + MS_EXCEPTION_IF_NULL(call_node); + const auto &func_graph = call_node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + func_graph_call_relation[func_graph].emplace_back(call_node_to_func_graphs.second); + } + + // Topologically sort all funcgraphs according to the function call relationship. + const auto &topo_sort_func_graphs = TopoSortForFuncGraph(root_func_graph_, &func_graph_call_relation); + + // Deduces the device context type of funcgraph outputs according to the topological order. + for (const auto &func_graph : topo_sort_func_graphs) { + MS_EXCEPTION_IF_NULL(func_graph); + const auto &return_node = func_graph->return_node(); + MS_EXCEPTION_IF_NULL(return_node); + const auto &cnode = return_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + const auto &inputs = cnode->inputs(); + const auto output_nodes = FetchAllOutputWithIndex(inputs[kReturnInputPos]); + std::vector return_device_contexts; + + for (const auto &output_node : output_nodes) { + if (output_node.first->isa()) { + // If the output is parameter, get the device context type from the formal parameter. 
+ const auto &iter = find(func_graph->parameters().begin(), func_graph->parameters().end(), output_node.first); + if (iter == func_graph->parameters().end()) { + MS_LOG(EXCEPTION) << "Invalid parameter:" << output_node.first->DebugString() + << " for func_graph:" << func_graph->ToString(); + } + const auto &func_graph_iter = func_graph_to_device_contexts_.find(func_graph); + if (func_graph_iter == func_graph_to_device_contexts_.end()) { + MS_LOG(EXCEPTION) << "Cannot find device context for funcgraph:" << func_graph->ToString(); + } + return_device_contexts.emplace_back(func_graph_iter->second[iter - func_graph->parameters().begin()]); + } else if (output_node.first->isa()) { + // If the output is parameter, used the default context type. + return_device_contexts.emplace_back(default_context); + } else if (AnfAlgo::IsCallNode(output_node.first)) { + // If the output is call node, get the device context type by the output of funcgraph. + const auto &func_graphs = call_node_to_func_graphs_[output_node.first]; + std::vector call_device_contexts; + for (const auto &graph : func_graphs) { + MS_EXCEPTION_IF_NULL(graph); + const auto &node = graph->return_node(); + MS_EXCEPTION_IF_NULL(node); + const auto &iter = control_node_to_device_contexts_.find(node); + if (iter != control_node_to_device_contexts_.end()) { + call_device_contexts = iter->second; + break; + } + } + // Since funcgraph has been topo-sorted according to the calling relationship, when there is a call node in + // the output, the output type of the funcgraph called by it should have been determined, if not, an exception + // will be thrown. 
+ if (call_device_contexts.empty() || call_device_contexts.size() <= output_node.second) { + MS_LOG(EXCEPTION) << "Cannot find device context for call node:" << output_node.first->DebugString() + << " device contexts size:" << call_device_contexts.size(); + } + return_device_contexts.emplace_back(call_device_contexts[output_node.second]); + } else if (output_node.first->isa()) { + // If the output is a cnode, get the device context type by the kernel. + const auto &iter = front_to_backend_kernels_.find(output_node); + if (iter == front_to_backend_kernels_.end()) { + MS_LOG(EXCEPTION) << "Cannot find backend kernel for cnode:" << output_node.first->DebugString(); + } + return_device_contexts.emplace_back(iter->second.second); + } else { + MS_LOG(EXCEPTION) << "Invalid node for return:" << output_node.first->DebugString(); + } + } + control_node_to_device_contexts_[return_node] = return_device_contexts; + } +} + +void ControlNodeParser::FetchFrontNodeToKernelGraph(const std::vector &graphs) { + for (const auto &graph : graphs) { + const auto &graph_outputs = graph->graph_output_map(); + for (const auto &backend_to_front : graph_outputs) { + front_node_to_kernel_graph_[backend_to_front.second.first] = graph; + } + } +} + +int ControlNodeParser::FetchBranchIDByCallNode(const AnfNodePtr &call_node) { MS_EXCEPTION_IF_NULL(call_node); if (call_node_to_branch_id_.find(call_node) == call_node_to_branch_id_.end()) { @@ -797,6 +421,14 @@ int ControlNodeParser::GetBranchIDByCallNode(const AnfNodePtr &call_node) { return call_node_to_branch_id_[call_node]; } +FuncGraphPtr ControlNodeParser::FetchKernelGraphByFrontNode(const AnfNodePtr &kernel) { + const auto &iter = front_node_to_kernel_graph_.find(kernel); + if (iter == front_node_to_kernel_graph_.end()) { + return nullptr; + } + return iter->second; +} + bool ControlNodeParser::IsCallInputKernelGraph(const KernelGraphPtr &graph) { if (call_input_kernel_graphs_.find(graph) == call_input_kernel_graphs_.end()) { return false; @@ 
-804,316 +436,171 @@ bool ControlNodeParser::IsCallInputKernelGraph(const KernelGraphPtr &graph) { return true; } -bool ControlNodeParser::IsKernelInRootFuncGraph(const AnfNodePtr &kernel) { - if (kernel == nullptr) { - return true; +KernelWithIndex ControlNodeParser::FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index) { + const auto &iter = front_to_backend_kernels_.find(node_with_index); + if (iter != front_to_backend_kernels_.end()) { + return iter->second.first; } - - const auto &graph = kernel->func_graph(); - if (kernel != nullptr && graph != nullptr) { - const auto &kernel_graph = dynamic_cast(graph.get()); - if (kernel_graph == nullptr) { - return true; - } - - const auto func_graph = kernel_graph->GetFuncGraph(); - if (func_graph != nullptr && func_graph != root_func_graph_) { - return false; - } - } - - return true; + return {}; } -size_t ControlNodeParser::GetCallNumByFuncGraph(const FuncGraphPtr &func_graph) { - if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) { - MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString(); - } - - return func_graph_to_call_num_[func_graph]; -} - -std::vector ControlNodeParser::FetchAllBranchOutputs(const FuncGraphPtr &func_graph) { - std::vector call_nodes; - return FetchFuncGraphOutput(func_graph, &call_nodes); -} - -DeviceContext *ControlNodeParser::GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node) { - auto iter = std::find_if( - front_value_nodes_.begin(), front_value_nodes_.end(), - [value_node](const auto &front_node_with_context) { return front_node_with_context.first == value_node; }); - if (iter != front_value_nodes_.end()) { - return iter->second; - } - return nullptr; -} - -AnfNodePtr ControlNodeParser::FetchBackendNodebyWeightNode(const AnfNodePtr &node) { - for (const auto &host_parameter_to_weight : host_parameter_to_weights_) { - for (const auto &front_weight : host_parameter_to_weight.second) { - if (front_weight == node) { - const auto 
&iter = front_to_backend_parameters_.find(host_parameter_to_weight.first); - if (iter != front_to_backend_parameters_.end()) { - return iter->second.first; - } - } - } - } - - return nullptr; -} - -void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, - std::vector *value_nodes) { - const auto &cnode = switch_node->cast(); - const auto &inputs = cnode->inputs(); - if (inputs.size() != kSwitchInputNum) { - MS_LOG(EXCEPTION) << "Invalid switch node input num:" << inputs.size(); - } - - for (const auto &input : inputs) { - if (input->isa()) { - const auto &node_value = input->cast()->value(); - if (node_value->isa()) { - (void)((*value_nodes).emplace_back(input)); - } - } else if (AnfAlgo::IsCallNode(input)) { - // If input is a call not, should check the switch node in its input. - const auto &call_node = input->cast(); - const auto &call_inputs = call_node->inputs(); - if (call_inputs.empty() || (!AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch))) { +void ControlNodeParser::FetchFrontValueNode() { + for (const auto &formal_to_real_parameter : formal_to_real_parameters_) { + for (const auto &real_parameter_with_index : formal_to_real_parameter.second) { + const auto &real_parameter = real_parameter_with_index.first; + if (!real_parameter->isa()) { continue; } - FetchValueNodeBySwitchNode(call_inputs[0], value_nodes); - } else if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) { - const auto &partial_node = input->cast(); - const auto &partial_inputs = partial_node->inputs(); - if (partial_inputs.size() <= kPartialFuncGraphPos) { - MS_LOG(EXCEPTION) << "Invalid partial node input num:" << partial_inputs.size(); - } - // if input is a partial node, get the value node in its funcgraph. 
- const auto &func_graph = GetValueNode(partial_inputs[kPartialFuncGraphPos]); - if (func_graph->output()->isa()) { - (void)((*value_nodes).emplace_back(func_graph->output())); + const auto &iter = front_to_backend_parameters_.find({real_parameter, 0}); + if (iter != front_to_backend_parameters_.end() && (!iter->second.empty())) { + front_value_nodes_.emplace(real_parameter, iter->second.begin()->second); } } } } -void ControlNodeParser::FetchFrontValueNode(const std::vector &control_nodes, - const std::vector &graphs, - const std::vector &device_contexts) { +void ControlNodeParser::ParseFormalToRealParameter(const std::vector &control_nodes) { + std::unordered_map> formal_to_real_parameters; + + // The actual parameters of the function are divided into two parts: + // 1. Input of partial node. + // 2. Input of call node. + for (const auto &node : control_nodes) { + if (AnfAlgo::IsCallNode(node)) { + const auto &cnode = node->cast(); + const auto &inputs = cnode->inputs(); + const auto &func_graphs = FetchFuncGraphbyCallNode(node); + for (const auto func_graph : func_graphs) { + const auto ¶meters = func_graph->parameters(); + for (size_t i = inputs.size() - 1, j = parameters.size() - 1; i >= kCallInputStartPos && j >= 0; --i, --j) { + std::set real_parameters; + std::set invalid_call_nodes; + MS_EXCEPTION_IF_NULL(inputs[i]); + MS_EXCEPTION_IF_NULL(parameters[j]); + FetchRealParameterByNode({inputs[i], 0}, &real_parameters, &invalid_call_nodes); + if (real_parameters.empty()) { + MS_LOG(EXCEPTION) << "Failed to find real parameter for formal parameter:" << inputs[i]->DebugString(); + } + formal_to_real_parameters[parameters[j]].insert(real_parameters.begin(), real_parameters.end()); + } + } + } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { + const auto &cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + const auto &inputs = cnode->inputs(); + if (inputs.size() <= kPartialFuncGraphPos || (!inputs[kPartialFuncGraphPos]->isa()) || + 
(!IsValueNode(inputs[kPartialFuncGraphPos]))) { + MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString(); + } + const auto &func_graph = GetValueNode(inputs[kPartialFuncGraphPos]); + MS_EXCEPTION_IF_NULL(func_graph); + const auto ¶meters = func_graph->parameters(); + if (inputs.size() - kPartialInputStartPos > parameters.size()) { + MS_LOG(EXCEPTION) << "Invalid partial input size:" << inputs.size() + << " formal parameter size:" << parameters.size(); + } + for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) { + std::set real_parameters; + std::set invalid_call_nodes; + MS_EXCEPTION_IF_NULL(inputs[i]); + MS_EXCEPTION_IF_NULL(parameters[i - kPartialInputStartPos]); + FetchRealParameterByNode({inputs[i], 0}, &real_parameters, &invalid_call_nodes); + if (real_parameters.empty()) { + MS_LOG(EXCEPTION) << "Failed to find real parameter for formal parameter:" << inputs[i]->DebugString(); + } + formal_to_real_parameters[parameters[i - kPartialInputStartPos]].insert(real_parameters.begin(), + real_parameters.end()); + } + } + } + + // When the real parameter is also a parameter, the corresponding actual parameter needs to be obtained recursively. 
+ for (const auto &formal_to_real_parameter : formal_to_real_parameters) { + const auto &formal_parameter = formal_to_real_parameter.first; + const auto &real_parameters = formal_to_real_parameter.second; + std::set total_real_parameters = real_parameters; + for (const auto &real_parameter : real_parameters) { + if (real_parameter.first->isa()) { + std::set invalid_real_parameter{formal_parameter}; + ParseAllRealParameterByFormalParameter(real_parameter.first, formal_to_real_parameters, &total_real_parameters, + &invalid_real_parameter); + real_to_formal_parameters_[real_parameter.first].emplace(formal_parameter); + } else { + total_real_parameters.emplace(real_parameter); + } + } + std::swap(formal_to_real_parameters_[formal_parameter], total_real_parameters); + } +} + +void ControlNodeParser::ParseAllRealParameterByFormalParameter(const AnfNodePtr &formal_parameter, + const FormalToRealParameter &formal_to_real_parameters, + std::set *total_real_parameters, + std::set *invalid_real_parameter) { + if (invalid_real_parameter->find(formal_parameter) != invalid_real_parameter->end()) { + return; + } + invalid_real_parameter->emplace(formal_parameter); + + // Get all the actual parameters corresponding to parameter recursively. 
+ const auto &dst_iter = formal_to_real_parameters_.find(formal_parameter); + if (dst_iter != formal_to_real_parameters_.end()) { + total_real_parameters->insert(dst_iter->second.begin(), dst_iter->second.end()); + return; + } + const auto &src_iter = formal_to_real_parameters.find(formal_parameter); + if (src_iter == formal_to_real_parameters.end()) { + const auto &func_graph = formal_parameter->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + if (func_graph == root_func_graph_) { + return; + } + MS_LOG(EXCEPTION) << "Invalid formal parameter:" << formal_parameter->DebugString(); + } + const auto &real_parameters = src_iter->second; + for (const auto &real_parameter : real_parameters) { + MS_EXCEPTION_IF_NULL(real_parameter.first); + if (real_parameter.first->isa()) { + ParseAllRealParameterByFormalParameter(real_parameter.first, formal_to_real_parameters, total_real_parameters, + invalid_real_parameter); + } else { + total_real_parameters->emplace(real_parameter); + } + } +} + +void ControlNodeParser::FetchControlNodeParameter(const std::vector &control_nodes) { for (const auto &control_node : control_nodes) { CNodePtr cnode = control_node->cast(); - auto inputs = cnode->inputs(); - if (inputs[0]->isa() && IsValueNode(inputs[0])) { - auto func_graph = GetValueNode(inputs[0]); - const auto parameters = func_graph->parameters(); - if (parameters.size() != inputs.size() - kCallInputStartPos) { - MS_LOG(EXCEPTION) << "Invalid parameters num, need:" << parameters.size() - << " has:" << inputs.size() - kCallInputStartPos; - } - for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { - if (inputs[i]->isa()) { - const auto &node_value = inputs[i]->cast()->value(); - if (!node_value->isa()) { - continue; - } - if (front_to_backend_parameters_.find(parameters[i - kCallInputStartPos]) == - front_to_backend_parameters_.end()) { - MS_LOG(INFO) << "Cannot find backend parameter for front parameter:" - << AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos]) 
- << ", used the default format"; - CreateDeviceTensorForFrontParameter(inputs[i], device_contexts[0]); - (void)front_value_nodes_.emplace_back(inputs[i], device_contexts[0]); - continue; - } - - const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first; - const auto &device_context = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].second; - CreateDeviceTensorForValueNode(inputs[i], backend_node, device_context); - (void)front_value_nodes_.emplace_back(inputs[i], device_context); - } - } - } - } - - for (size_t index = 0; index < graphs.size(); ++index) { - const auto &graph = graphs[index]; - MS_EXCEPTION_IF_NULL(graph); - - for (const auto ¶meter : graph->input_nodes()) { - MS_EXCEPTION_IF_NULL(parameter); - - if (IsInternalParameter(parameter, graph)) { - auto front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter); - MS_EXCEPTION_IF_NULL(front_node_with_index.first); - const auto &front_output_with_index = - AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false); - auto front_output_node = front_output_with_index.first; - MS_EXCEPTION_IF_NULL(front_output_node); - if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) { - std::vector value_nodes; - FetchValueNodeBySwitchNode(front_output_node, &value_nodes); - for (const auto value_node : value_nodes) { - CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]); - (void)front_value_nodes_.emplace_back(value_node, device_contexts[index]); - } - } - } - } - } - - // When funcgraph called by call node returns to the value node, device addresses should be created for these - // value nodes. 
- for (const auto &call_node_to_backend_parameter : call_node_to_backend_parameters_) { - const auto func_graphs = FetchFuncGraphbyCallNode(call_node_to_backend_parameter.first.first); - for (const auto &func_graph : func_graphs) { - const auto &output = func_graph->output(); - if (output->isa() && GetFrontValueNodeDeviceContext(output) == nullptr) { - const auto &device_context = call_node_to_backend_parameter.second.second; - CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context); - (void)front_value_nodes_.emplace_back(output, device_context); - } - } - } -} - -void ControlNodeParser::FetchFrontToFrontParameter( - const std::vector &control_nodes, - std::unordered_map> *front_to_front_parameter) { - // Function used to collect the input of call node. - const auto &call_input_parse = [front_to_front_parameter](const std::vector ¶meters, - const std::vector &call_inputs, - const size_t call_input_start_pos) { - for (size_t i = 0; i < call_inputs.size(); ++i) { - if (call_inputs[i]->isa()) { - (*front_to_front_parameter)[call_inputs[i]].push_back(parameters[i + call_input_start_pos]); - } - } - }; - - // Function used to collect the input of partial node. - const auto &partial_input_parse = [call_input_parse, front_to_front_parameter]( - const AnfNodePtr &partial_node, const std::vector &call_inputs) { - const auto &cnode = partial_node->cast(); const auto &inputs = cnode->inputs(); - const auto &func_graph = GetValueNode(inputs[kPartialFuncGraphPos]); - const auto ¶meters = func_graph->parameters(); - for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) { - if (inputs[i]->isa()) { - (*front_to_front_parameter)[inputs[i]].push_back(parameters[i - kPartialInputStartPos]); - } - } - call_input_parse(parameters, call_inputs, inputs.size() - kPartialInputStartPos); - }; - - // Function used to collect the input of switch node. 
- const auto &switch_input_parse = [&](const AnfNodePtr &switch_node, const std::vector &call_inputs) { - CNodePtr cnode = switch_node->cast(); - const auto &switch_inputs = cnode->inputs(); - if (AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitch)) { - // Parse the switch node. The switch node has two partial node inputs. - if (AnfAlgo::CheckPrimitiveType(switch_inputs[kSwitchTrueBranchPos], prim::kPrimPartial)) { - partial_input_parse(switch_inputs[kSwitchTrueBranchPos], call_inputs); - partial_input_parse(switch_inputs[kSwitchFalseBranchPos], call_inputs); - } - } else { - // Parse the switchlayer node. The switchlayer node has a maketuple node input, which is a tuple of funcgraphs. - // call_inputs will be the input of these funcgraphs. - const auto &tuple_node = switch_inputs[kSwitchLayerBranchPos]->cast(); - const auto &tuple_inputs = tuple_node->inputs(); - for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { - const auto &input = tuple_inputs[i]; - if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) { - partial_input_parse(input, call_inputs); - } else { - auto func_graph = GetValueNode(input); - call_input_parse(func_graph->parameters(), call_inputs, 0); + if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) { + break; + } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) { + for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) { + if (inputs[i]->isa()) { + (void)control_node_parameters_.emplace_back(inputs[i]); } } - } - }; - - for (const auto &node : control_nodes) { - CNodePtr cnode = node->cast(); - const auto &inputs = cnode->inputs(); - if (inputs[0]->isa() && IsValueNode(inputs[0])) { - // Call node which the first input node is a valuenode of funcgraph. 
- const auto &func_graph = GetValueNode(inputs[0]); - const auto ¶meters = func_graph->parameters(); + } else if (cnode->input(0)->isa() || IsValueNode(cnode->input(0))) { for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { if (inputs[i]->isa()) { - (*front_to_front_parameter)[inputs[i]].push_back(parameters[i - kCallInputStartPos]); + (void)control_node_parameters_.emplace_back(inputs[i]); } } - } else if (inputs[0]->isa()) { - // Call node which the first input node is a switch or switchlayer node. - if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch) || - AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { - std::vector call_inputs; - call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end()); - switch_input_parse(inputs[0], call_inputs); - } else if (AnfAlgo::IsCallNode(inputs[0])) { - continue; - } else { - MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:" - << AnfAlgo::GetNodeDebugString(inputs[0]); + } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) { + if (inputs.size() != kSwitchInputNum) { + MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node); } - } - } -} - -std::vector ControlNodeParser::FetchControlNodeParameter(const std::vector &control_nodes, - DeviceContext *device_context) { - std::vector parameters = FetchParameterByControlNode(control_nodes); - - for (const auto &graph_with_device_context : call_input_kernel_graphs_) { - const auto &graph = graph_with_device_context.first; - const auto &func_graph = graph->GetFuncGraph(); - if (func_graph == nullptr) { - MS_LOG(WARNING) << "Cannot get funcgraph by kernel graph:" << graph->ToString(); - continue; - } - if (func_graph != root_func_graph_) { - continue; - } - - const auto &inputs = graph->input_nodes(); - for (const auto &input : inputs) { - const auto &front_node = graph->GetFrontAnfByBackendAnf(input); - if (front_node != nullptr && front_node->isa() && 
(!HasAbstractRef(front_node))) { - (void)parameters.emplace_back(front_node); + if (inputs[kSwitchCondPos]->isa()) { + (void)control_node_parameters_.emplace_back(inputs[kSwitchCondPos]); } - } - } - - for (const auto ¶meter : parameters) { - auto backend_iter = front_to_backend_parameters_.find(parameter); - if (backend_iter == front_to_backend_parameters_.end()) { - CreateDeviceTensorForFrontParameter(parameter, device_context); - front_to_backend_parameters_[parameter] = {parameter, device_context}; - (void)front_parameters_.emplace_back(parameter, device_context); - } - } - - return parameters; -} - -void ControlNodeParser::FetchFuncGraphCallNum(const std::vector &control_nodes) { - for (const auto &control_node : control_nodes) { - if (AnfAlgo::IsCallNode(control_node)) { - const auto &func_graphs = FetchFuncGraphbyCallNode(control_node); - - for (const auto &func_graph : func_graphs) { - MS_EXCEPTION_IF_NULL(func_graph); - - if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) { - func_graph_to_call_num_[func_graph] = 1; - } else { - func_graph_to_call_num_[func_graph]++; - } + } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) { + if (inputs.size() != kSwitchLayerInputNum) { + MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node); + } + if (inputs[kSwitchLayerCondPos]->isa()) { + (void)control_node_parameters_.emplace_back(inputs[kSwitchLayerCondPos]); } } } @@ -1130,80 +617,24 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vectorGetFrontNodeByInternalParameter(input); if (internal_parameter_with_index.first != nullptr && AnfAlgo::IsCallNode(internal_parameter_with_index.first)) { call_input_kernel_graphs_[graph] = device_context; - call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context}; } } } } -std::vector FetchInputParameterbyControlNode(const AnfNodePtr &node, std::set *switch_nodes, - std::set *call_nodes) { - 
std::vector parameters; +void ControlNodeParser::CreateBranchIDForCallNode(const std::vector &control_nodes) { + int branch_id = kMainBranchID; - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { - if ((*switch_nodes).find(node) != (*switch_nodes).end()) { - return parameters; - } - (void)(*switch_nodes).insert(node); - - const auto &cnode = node->cast(); - const auto &inputs = cnode->inputs(); - if (inputs.size() != kSwitchInputNum) { - MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(node); - } - - for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) { - if (inputs[i]->isa()) { - (void)parameters.emplace_back(inputs[i]); - } else if (AnfAlgo::IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) { - const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes); - (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end()); - } - } - } else if (AnfAlgo::IsCallNode(node)) { - if ((*call_nodes).find(node) != (*call_nodes).end()) { - return parameters; - } - (void)(*call_nodes).insert(node); - - const auto &func_graphs = FetchFuncGraphbyCallNode(node); - for (const auto &func_graph : func_graphs) { - if (func_graph->output()->isa()) { - (void)parameters.emplace_back(func_graph->output()); - } + for (const auto &control_node : control_nodes) { + // Root funcgraph does not need to create a gather actor. 
+ if (AnfAlgo::IsCallNode(control_node)) { + call_node_to_branch_id_[control_node] = ++branch_id; } } - return parameters; } -std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &graph) { - std::vector parameters; - const auto &graph_parameters = graph->input_nodes(); - - for (const auto &graph_parameter : graph_parameters) { - const auto &external_front_node = graph->GetFrontAnfByBackendAnf(graph_parameter); - const auto &internal_front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter); - const auto &internal_front_node = internal_front_node_with_index.first; - - if (external_front_node == nullptr && internal_front_node == nullptr) { - MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :" - << AnfAlgo::GetNodeDebugString(graph_parameter); - continue; - } - - const auto &front_node_with_index = - ((external_front_node != nullptr) ? KernelWithIndex(external_front_node, 0) : internal_front_node_with_index); - const auto &sub_parameters = FetchAllRealInputNodeByParameter(front_node_with_index); - (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end()); - } - - return parameters; -} - -void ControlNodeParser::FetchFrontToBackendParameter(const std::vector &graphs, - const std::vector &device_contexts, - const RealToFormalNode &real_to_formal_front_parameters, - const RealToFormalNode &formal_to_real_front_parameters) { +void ControlNodeParser::ParseFrontToBackendParameter(const std::vector &graphs, + const std::vector &device_contexts) { if (graphs.size() != device_contexts.size()) { MS_LOG(EXCEPTION) << "Graph num is not equal to device context num."; } @@ -1213,41 +644,67 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vectorinput_nodes()) { - auto front_node = graph->GetFrontAnfByBackendAnf(parameter); - if (front_node != nullptr && front_node->isa() && - front_to_backend_parameters_.find(front_node) == front_to_backend_parameters_.end()) { - 
front_to_backend_parameters_[front_node] = {parameter, device_context}; + const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter); + const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter); + if (front_node == nullptr && front_node_with_index.first == nullptr) { + MS_LOG(EXCEPTION) << "Invalid backend parameter:" << parameter->DebugString() + << " for kernel graph:" << graph->ToString(); + } + + if (front_node_with_index.first != nullptr) { + std::set real_parameters; + std::set invalid_call_nodes; + FetchRealParameterByNode(front_node_with_index, &real_parameters, &invalid_call_nodes); + for (const auto real_parameter : real_parameters) { + if (real_parameter.first->isa() || real_parameter.first->isa()) { + front_to_backend_parameters_[real_parameter.first].emplace(parameter, device_context); + } + } + } else { + front_to_backend_parameters_[front_node].emplace(parameter, device_context); } } } - // This for loop cannot be combined with the for loop above, because the relationship between front - // and backend needs to be consistent with HostDataSource. - for (size_t i = 0; i < graphs.size(); ++i) { - const auto &graph = graphs[i]; - auto device_context = device_contexts[i]; - for (const auto ¶meter : graph->input_nodes()) { - const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(parameter); - - if (internal_front_node.first != nullptr) { - std::set call_nodes; - std::set switch_nodes; - const auto &front_paramters = - FetchInputParameterbyControlNode(internal_front_node.first, &switch_nodes, &call_nodes); - for (const auto &front_paramter : front_paramters) { - if (front_to_backend_parameters_.find(front_paramter) == front_to_backend_parameters_.end()) { - front_to_backend_parameters_[front_paramter] = {parameter, device_context}; - } + // Get the corresponding backend node for the real parameter according to the relationship between real + // parameter and formal parameter. 
+ for (const auto &front_to_backend_parameters : front_to_backend_parameters_) { + const auto &front_parameter = front_to_backend_parameters.first; + const auto &backend_parameters = front_to_backend_parameters.second; + const auto &iter = formal_to_real_parameters_.find(front_parameter); + if (iter != formal_to_real_parameters_.end()) { + for (const auto &real_parameter_with_index : iter->second) { + const auto &real_parameter = real_parameter_with_index.first; + if (real_parameter->isa()) { + front_to_backend_parameters_[real_parameter].insert(backend_parameters.begin(), backend_parameters.end()); } } } } } -void ControlNodeParser::FetchHostParameterToWeight(const RealToFormalNode &front_to_front_parameters) { - for (const auto &pair : front_to_front_parameters) { - std::vector dest_nodes; - FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameters); +void ControlNodeParser::ParseCallNodeToFuncGraph(const std::vector &control_nodes) { + for (const auto &control_node : control_nodes) { + MS_EXCEPTION_IF_NULL(control_node); + + if (AnfAlgo::IsCallNode(control_node)) { + call_node_to_func_graphs_[control_node] = AnfAlgo::GetFuncGraphbyCallNode(control_node); + } + } +} + +const std::set &ControlNodeParser::FetchFuncGraphbyCallNode(const AnfNodePtr &control_node) { + const auto &iter = call_node_to_func_graphs_.find(control_node); + if (iter == call_node_to_func_graphs_.end()) { + MS_LOG(EXCEPTION) << "Invalid call node:" << control_node->DebugString(); + } + return iter->second; +} + +void ControlNodeParser::FetchHostParameterToWeight() { + for (const auto &pair : real_to_formal_parameters_) { + std::set dest_nodes; + FetchWeightbyHostParameter(pair.first, &dest_nodes, real_to_formal_parameters_); host_parameter_to_weights_[pair.first] = dest_nodes; if (std::find(root_graph_parameters_.begin(), root_graph_parameters_.end(), pair.first) != @@ -1259,51 +716,6 @@ void ControlNodeParser::FetchHostParameterToWeight(const RealToFormalNode &front } } 
-FuncGraphPtr ControlNodeParser::FetchKernelGraphByFrontNode(const AnfNodePtr &kernel) { - const auto &iter = front_node_to_kernel_graph_.find(kernel); - if (iter == front_node_to_kernel_graph_.end()) { - return nullptr; - } - return iter->second; -} - -void ControlNodeParser::FetchFuncGraphToParameter(const std::vector &control_nodes) { - for (const auto &control_node : control_nodes) { - const auto &cnode = control_node->cast(); - const auto &inputs = cnode->inputs(); - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(control_node); - } - - // Call node which the first input is a cnode. - if (inputs[0]->isa()) { - const auto &switch_cnode = inputs[0]->cast(); - - if (AnfAlgo::CheckPrimitiveType(switch_cnode, prim::kPrimSwitch)) { - // Switch node. - FetchParameterBySwitchNode(inputs[0], &func_graph_to_parameters_); - } else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { - // Switchlayer node. - FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_); - } else if (AnfAlgo::IsCallNode(inputs[0])) { - continue; - } else { - MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString(); - } - } else if (inputs[0]->isa() && IsValueNode(inputs[0])) { - // Call node which the first input is a value node of funcgraph. 
- const auto &func_graph = GetValueNode(inputs[0]); - std::vector parameters; - for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { - if (CheckValidFuncGraphInput(inputs[i])) { - (void)parameters.emplace_back(inputs[i]); - } - } - (void)func_graph_to_parameters_[func_graph].emplace_back(parameters); - } - } -} - void ControlNodeParser::FetchFrontToBackendKernel(const std::vector &graphs, const std::vector &device_contexts) { for (size_t i = 0; i < graphs.size(); ++i) { @@ -1312,15 +724,13 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vectorexecution_order(); for (auto &kernel : execution_order) { - if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) { - auto front_node = graph->GetFrontAnfByBackendAnf(kernel); - if (front_node != nullptr) { - for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(kernel); ++j) { - front_to_backend_kernels_[{front_node, j}] = {{kernel, j}, device_context}; - MS_LOG(DEBUG) << "Add front to backend kernel, front:" << AnfAlgo::GetNodeDebugString(front_node) - << "index:" << j << " addr:" << front_node - << " second:" << AnfAlgo::GetNodeDebugString(kernel) << "index:" << j << " addr:" << kernel; - } + auto front_node = graph->GetFrontAnfByBackendAnf(kernel); + if (front_node != nullptr) { + for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(kernel); ++j) { + front_to_backend_kernels_[{front_node, j}] = {{kernel, j}, device_context}; + MS_LOG(DEBUG) << "Add front to backend kernel, front:" << AnfAlgo::GetNodeDebugString(front_node) + << "index:" << j << " addr:" << front_node << " second:" << AnfAlgo::GetNodeDebugString(kernel) + << "index:" << j << " addr:" << kernel; } } } @@ -1332,182 +742,6 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector *call_nodes, - std::set *switch_nodes, - std::set *results) { - if (front_output->isa()) { - (void)(*results).emplace(front_output, 0); - - const auto &iter = formal_to_real_parameters_.find(front_output); - if (iter != 
formal_to_real_parameters_.end()) { - for (const auto &node : iter->second) { - (void)(*results).emplace(node); - } - } - } else if (front_output->isa()) { - // Output is a parameter. - const auto iter = formal_to_real_parameters_.find(front_output); - if (iter != formal_to_real_parameters_.end()) { - for (const auto &node : iter->second) { - (void)(*results).emplace(node); - } - } else { - MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output); - } - } else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimSwitch)) { - // Output is a switch. - const auto &switch_outputs = FetchOutputBySwitchNode(front_output, call_nodes, switch_nodes); - - for (const auto &switch_output : switch_outputs) { - FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results); - } - } else if (AnfAlgo::IsCallNode(front_output)) { - // Output is a call. - const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes); - - for (const auto &call_output : call_outputs) { - FetchBackendOutputByFrontOutput(call_output, call_nodes, switch_nodes, results); - } - } else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimMakeTuple)) { - // Output is a make tuple. - const auto &cnode = front_output->cast(); - const auto &inputs = cnode->inputs(); - - for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { - FetchBackendOutputByFrontOutput(inputs[i], call_nodes, switch_nodes, results); - } - } else if (front_output->isa()) { - // Output is a kernel. 
- const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0)); - if (iter != front_to_backend_kernels_.end()) { - (void)(*results).emplace(iter->second.first); - } else { - MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output); - } - } else { - MS_LOG(EXCEPTION) << "Invalid front node:" << AnfAlgo::GetNodeDebugString(front_output); - } -} - -KernelWithIndex ControlNodeParser::FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index) { - const auto &iter = front_to_backend_kernels_.find(node_with_index); - if (iter != front_to_backend_kernels_.end()) { - return iter->second.first; - } - return {}; -} - -void ControlNodeParser::FetchBackendInputNodebyFrontNode( - const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter, - const FrontToBackendNodeWithContext &front_to_backend_parameters) { - if (real_parameter->isa()) { - // Input node is a parameter from host data source actor. 
- std::set invalid_inputs; - std::vector front_inputs = - FetchInputNodeByParameter(real_parameter, root_graph_parameters_, &invalid_inputs, func_graph_to_parameters_); - - for (const auto &front_input : front_inputs) { - const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0); - if (node_with_index.first->isa()) { - const auto &iter = front_to_backend_parameters.find(real_parameter); - if (iter == front_to_backend_parameters.end()) { - MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first); - continue; - } - (void)formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first, 0); - } else { - const auto iter = front_to_backend_kernels_.find(node_with_index); - if (iter == front_to_backend_kernels_.end()) { - MS_LOG(EXCEPTION) << "Cannot find actor of front node:" << AnfAlgo::GetNodeDebugString(node_with_index.first); - } - (void)formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first); - } - } - } else if (real_parameter->isa()) { - (void)formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0); - } else if (AnfAlgo::IsCallNode(real_parameter)) { - const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter); - for (const auto func_graph : func_graphs) { - FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters); - } - } else { - // Input node is a cnode. 
- const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(real_parameter, 0); - const auto iter = front_to_backend_kernels_.find(node_with_index); - if (iter == front_to_backend_kernels_.end()) { - MS_LOG(EXCEPTION) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first); - } - (void)formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first); - } -} - -void ControlNodeParser::FetchBackendParameterNode(const std::vector &graphs, - const std::vector &device_contexts, - const RealToFormalNode &real_to_formal_front_parameters, - const RealToFormalNode &formal_to_real_front_parameters, - FrontToBackendNodeWithContext *front_to_backend_parameters) {} - -void ControlNodeParser::FetchBackendInputNode(const std::vector &graphs, - const std::vector &device_contexts, - const RealToFormalNode &real_to_formal_front_parameters, - const RealToFormalNode &formal_to_real_front_parameters) { - FrontToBackendNodeWithContext front_to_backend_parameters; - FetchBackendParameterNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters, - &front_to_backend_parameters); - - for (size_t i = 0; i < graphs.size(); ++i) { - const auto &graph = graphs[i]; - for (const auto &value_node : graph->graph_value_nodes()) { - auto front_node = graph->GetFrontAnfByBackendAnf(value_node); - if (front_node != nullptr) { - (void)formal_to_real_parameters_[front_node].emplace_back(value_node, 0); - } - } - } - - for (const auto &host_parameter_to_weight : host_parameter_to_weights_) { - for (const auto &front_weight : host_parameter_to_weight.second) { - const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first); - if (iter != front_to_backend_parameters_.end()) { - (void)formal_to_real_parameters_[front_weight].emplace_back(iter->second.first, 0); - } - } - } - - for (const auto &func_graph_to_parameters : func_graph_to_parameters_) { - const auto &func_graph = 
func_graph_to_parameters.first; - std::vector graph_inputs; - for (const auto &input : func_graph->get_inputs()) { - // Monad input would not send to gather actor. - if (HasAbstractMonad(input) || (input->isa() && HasAbstractRef(input))) { - continue; - } - (void)graph_inputs.emplace_back(input); - } - - // Collect all backend input node to gather, There are two situations: - // 1. The parameter from the host data source. - // 2. Output the kernel actor. - for (const auto parameters : func_graph_to_parameters.second) { - if (parameters.size() != graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size() - << " need:" << graph_inputs.size() << " func_graph:" << func_graph->ToString(); - } - - for (size_t i = 0; i < parameters.size(); ++i) { - FetchBackendInputNodebyFrontNode(parameters[i], graph_inputs[i], front_to_backend_parameters); - } - } - } - for (const auto parameter_pair : front_to_backend_parameters) { - (void)formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0); - } - for (const auto parameter_pair : front_to_backend_parameters_) { - (void)formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0); - } -} - void ControlNodeParser::FetchAutoMonadNode(const std::vector &control_nodes) { for (const auto &control_node : control_nodes) { const auto &cnode = control_node->cast(); @@ -1536,5 +770,34 @@ AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNod } return sub_front_node_to_root_front_node_[sub_front_node]; } + +bool IsFirstControlNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return true; + } + + const auto &cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + const auto &inputs = cnode->inputs(); + for (const auto &input : inputs) { + MS_EXCEPTION_IF_NULL(input); + if (AnfAlgo::IsCallNode(input) || (!IsFirstControlNode(input))) { + return false; + } + } + return 
true; +} + +void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector &control_nodes) { + for (const auto &control_node : control_nodes) { + if ((AnfAlgo::IsCallNode(control_node) || AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) && + IsFirstControlNode(control_node)) { + const auto &func_graph = control_node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + func_graph_to_first_control_nodes_[func_graph].emplace(control_node); + } + } +} } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index 6fa284ff386..f86cd1fa79b 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -56,102 +57,55 @@ constexpr size_t kSingleControlNode = 1; const char kEntranceActorNameSuffix[] = "_EntranceActor"; const char kStackActorNameSuffix[] = "_StackActor"; -using FrontToBackendNodeWithContext = std::unordered_map>; +using FrontToBackendNodeWithContext = std::unordered_map>>; using FrontToBackendKernelWithContext = std::map>; using FuncGraphToKernelGraph = std::unordered_map>; -using FuncGraphToParameter = std::unordered_map>>; -using HostParameterToWeight = std::unordered_map>; -using NodeWithDeviceContext = std::vector>; +using HostParameterToWeight = std::unordered_map>; +using NodeWithDeviceContext = std::set>; using RealToFormalNode = std::unordered_map>; +using FormalToRealParameter = std::unordered_map>; +using RealToFormalParameter = std::unordered_map>; +using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; using FrontNodeToKernelGraph = std::unordered_map; - -// Check if the call node is the input of another call node. 
-bool IsSubCallNode(const AnfNodePtr &node); - -// Recursive interface, find the real output of funcgraph called by call node. -AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set *call_nodes); +using FuncGraphCallRelation = std::unordered_map>>; // Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph, // it is determined whether it is a weight. bool HasAbstractRef(const AnfNodePtr &node); - -// Recursive interface, get the funcgraph which the node belongs, if the node has a front node, return the funcgraph -// which the front node belongs, if not, find the funcgraph which the input of the node belongs. -FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node); - -// Recursive interface, get the number of output nodes of funcgraph called by call node. -size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector *call_nodes); - -// Get front node by backend node. -AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node); - // Get the front node corresponding to the backend node, if the front node is not a parameter node, return the // corresponding cnode. KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph); -// Get the funcgraph to which the node belongs. -FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node); - -// Find all funcgraphs that the call node will call. -std::vector FetchFuncGraphbyCallNode(const AnfNodePtr &node); - -// Get parameters in kernel graph. -std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &graph); - // ControlNodeParser is used to parse control nodes, and get the edges between nodes. class ControlNodeParser { public: // Parse the control node and put the results of the parsing into member variables. 
void Parse(const std::vector &control_nodes, const std::vector &graphs, - const std::vector &device_contexts, const FuncGraphPtr &root_graph); + const std::vector &device_contexts, const FuncGraphPtr &root_graph, + const FuncGraphToKernelGraph &func_graph_to_kernel_graphs); bool IsInited() { return is_inited_; } + // Check whether there is a call node in the front input nodes of the kernel graph. + bool IsCallInputKernelGraph(const KernelGraphPtr &graph); + // Check whether the data arrow of the kernel actor needs to be connected to the control actor. + // There are two situations: + // 1. In control flow, the parameter input needs to be connected to the entrance actor of the funcgraph. + // 2. In the kernel graph with call node input, the data arrow needs to be connected to the stack actor. bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node); + const std::vector &control_node_parameters() const { return control_node_parameters_; } const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; } const HostParameterToWeight &host_parameter_to_weights() const { return host_parameter_to_weights_; } const NodeWithDeviceContext &front_value_nodes() const { return front_value_nodes_; } - // Get the output of funcgraph, usually there is only one output node, In the control flow, there are - // multiple branch outputs, there will be multiple output nodes. - std::vector FetchAllBranchOutputs(const FuncGraphPtr &func_graph); - - // Get all possible input nodes of the output node. When the switch actor is the output, it need to send the node - // which device address belongs, so switch actor need to get all the possible nodes. - std::set FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output); - - // Get the device context corresponding to the value node. - DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node); - - // Get the branch id corresponding to call node. 
- int GetBranchIDByCallNode(const AnfNodePtr &call_node); - - // Get the number of calls to funcgraph - size_t GetCallNumByFuncGraph(const FuncGraphPtr &func_graph); - - // Get all possible input nodes of the output node. When the gather actor is the output, it need to send the node - // which device address belongs, so gather actor need to get all the possible nodes. - std::vector GetBackendInputByParameter(const AnfNodePtr ¶meter); - - // Check whether there is a call node in the front input nodes of the kernel graph. - bool IsCallInputKernelGraph(const KernelGraphPtr &graph); - - // Check whether the kernel actor belongs to the root graph. - // In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output should - // be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be sent to the - // output switch actor. - bool IsKernelInRootFuncGraph(const AnfNodePtr &kernel); - - // Get the backend node corresponding to the weight node in the subgraph. - AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node); - - KernelWithIndex GetBackendKernelByFrontKernel(const KernelWithIndex &front_node_with_index) { - return front_to_backend_kernels_[front_node_with_index].first; - } - - AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node); - KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index); + // Fetch all funcgraphs that the call node may call. + const std::set &FetchFuncGraphbyCallNode(const AnfNodePtr &control_node); + // Fetch the branch id corresponding to funcgraph. + int FetchBranchIDByCallNode(const AnfNodePtr &call_node); + // Fetch the funcgraph which the kernel belongs. FuncGraphPtr FetchKernelGraphByFrontNode(const AnfNodePtr &kernel); + // Fetch the backend kernel of front node. 
+ KernelWithIndex FetchBackendNodeByFrontNode(const KernelWithIndex &node_with_index); private: friend class GraphScheduler; @@ -160,134 +114,120 @@ class ControlNodeParser { // value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for // them separately during initialization. // The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to. - void FetchFrontValueNode(const std::vector &control_nodes, const std::vector &graphs, - const std::vector &device_contexts); - // Create branch id for all subgraphs in the control flow. - void CreateBranchIDForFuncGraph(const std::vector &control_nodes); - // Find all value nodes in the switch recursively. - void FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, std::vector *value_nodes); - // Fetch all the relationships between front parameters and backend parameters.The front parameters + void FetchFrontValueNode(); + // Create branch id for all call node in the control flow. + void CreateBranchIDForCallNode(const std::vector &control_nodes); + + // Parse all the relationships between front parameters and backend parameters.The front parameters // include two parts: // 1. The parameter from kernel graph. // 2. The parameter from control nodes. - void FetchFrontToBackendParameter(const std::vector &graphs, - const std::vector &device_contexts, - const RealToFormalNode &real_to_formal_front_parameters, - const RealToFormalNode &formal_to_real_front_parameters); - // Get the relationship between the front and backend of the executable kernel in all kernel graphs. - void FetchFrontToBackendKernel(const std::vector &graphs, - const std::vector &device_contexts); - // Get inputs of control node which come from the host actor. These inputs generally come from the partial - // nodes and call nodes of the root funcgraph. 
- std::vector FetchControlNodeParameter(const std::vector &control_nodes, - DeviceContext *device_context); - // Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node, - // and the input of the call node is the input parameter of the corresponding funcgraph. - void FetchFuncGraphToParameter(const std::vector &control_nodes); - // Get all the front weight parameters related to the weight in the host parameter. - void FetchHostParameterToWeight(const RealToFormalNode &real_to_formal_front_parameters); + void ParseFrontToBackendParameter(const std::vector &graphs, + const std::vector &device_contexts); // The relationship between front parameters indicates that the parameter is directly used as the input of the // funcgraph. There are two situations: // 1. The parameter is used as the input of the call node, // 2. The parameter is used as the input of the partial and will be input to the funcgraph of the partial in the // subsequent call node. - void FetchFrontToFrontParameter(const std::vector &control_nodes, - std::unordered_map> *front_to_front_parameter); - // Get the number of calls to all subgraphs in the whole funcgraph. - void FetchFuncGraphCallNum(const std::vector &control_nodes); + void ParseFormalToRealParameter(const std::vector &control_nodes); + // Recursively get all the real parameters corresponding to the formal parameters. + void ParseAllRealParameterByFormalParameter(const AnfNodePtr &formal_parameter, + const FormalToRealParameter &formal_to_real_parameters, + std::set *total_real_parameters, + std::set *invalid_real_parameter); + + // Parse the device context of the control node. In a heterogeneous scenario, different device contexts need to be + // copied between different device memories. The analysis steps: + // 1. Get the device context of the funcgraph parameter according to the device type of the kernel in the funcgraph. + // 2. 
Determine the type of device context output by funcgraph according to the call relationship of funcgrpah. + void ParseDeviceContext(const std::vector &control_nodes, + const std::vector &kernel_graphs, + const std::vector &device_contexts, + const FuncGraphToKernelGraph &func_graph_to_kernel_graphs); + void ParseDeviceContextForFuncGraph(const std::vector &control_nodes, + const std::vector &kernel_graphs, + const std::vector &device_contexts, + const FuncGraphToKernelGraph &func_graph_to_kernel_graphs); + void ParseDeviceContextForControlNode(const DeviceContext *default_context); + + // In the actor model, when the funcgraph comes to an end temporarily, the exit of the funcgraph needs to notify + // the entrance actor so that it can process next parameters. This is used to obtain the nodes corresponding to all + // actors in the funcgraph that need to send control messages to the entrance. + // These node are control nodes without control node input in the topological sort of the funcgraph. + void ParseFirstControlNodeForFuncGraph(const std::vector &control_nodes); + // Parse all funcgraphs that call nodes may call. + void ParseCallNodeToFuncGraph(const std::vector &control_nodes); + + // Get the relationship between the front and backend of the executable kernel in all kernel graphs. + void FetchFrontToBackendKernel(const std::vector &graphs, + const std::vector &device_contexts); + void FetchFrontNodeToKernelGraph(const std::vector &graphs); + // nodes and call nodes of the root funcgraph. + void FetchControlNodeParameter(const std::vector &control_nodes); + // Get all the front weight parameters related to the weight in the host parameter. + void FetchHostParameterToWeight(); // Get all the kernel graphs where the input node has a call node. void FetchCallInputKernelGraph(const std::vector &graphs, const std::vector &device_contexts); - // Get the relationship of all real and formal nodes in the whole funcgraph. 
- void FetchBackendInputNode(const std::vector &graphs, - const std::vector &device_contexts, - const RealToFormalNode &real_to_formal_front_parameters, - const RealToFormalNode &formal_to_real_front_parameters); - // Get the relationship of all real and formal parameters in the whole funcgraph. - void FetchBackendParameterNode(const std::vector &graphs, - const std::vector &device_contexts, - const RealToFormalNode &real_to_formal_front_parameters, - const RealToFormalNode &formal_to_real_front_parameters, - FrontToBackendNodeWithContext *front_to_backend_parameters); - // Get all possible input node of real parameter. - void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter, - const FrontToBackendNodeWithContext &front_to_backend_parameters); - // Recursive interface, get all Backend node by front_output. - void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set *call_nodes, - std::set *switch_nodes, std::set *results); - // Get the dependency between kernel and call node in auto monad. void FetchAutoMonadNode(const std::vector &control_nodes); + // Fetch the formal parameter in root graph by parameters in subgraph. + AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node); + + // In control flow, funcgraph will be cut into multiple kernel graphs for execution, and this relationship is recorded + // in this map. + FuncGraphToKernelGraph func_graph_to_kernel_graphs_; + // The kernel graph to which the front node belongs after the funcgraph is cut. + FrontNodeToKernelGraph front_node_to_kernel_graph_; + // The front to backend parameters is used to build and link the host data source actor in the control flow scenario. FrontToBackendNodeWithContext front_to_backend_parameters_; - - // The relationship between all real parameters and formal parameters in the entire func_graph. - // In control flow, the control actor will be the output actor. 
Since the actor needs to send the node to the output - // actor, it is necessary to save all the real parameters corresponding to the formal parameters in the control actor. - // When the control actor receives the device address, it can find the corresponding input node. - std::unordered_map> formal_to_real_parameters_; - // Relationship between the front and backend of the executable kernel in all kernel graphs. FrontToBackendKernelWithContext front_to_backend_kernels_; - // The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize - // the input node of gather. - FuncGraphToParameter func_graph_to_parameters_; + // Relationship between formal parameters and real parameters. + FormalToRealParameter formal_to_real_parameters_; + RealToFormalParameter real_to_formal_parameters_; - // The relationship between the valuenode inputs of the call node and the backend parameter - std::map> call_node_to_backend_parameters_; - - // Branch id of call node. + // Branch id of funcgraph. // In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to - // different places. Therefore, a branch id is created for each call node. When funcgraph is called, the branch - // id needs to be sent to the entrance actor corresponding to the funcgraph, and then send the branch id to its - // output switch actor. + // different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch + // id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id + // to its output switch actor. std::unordered_map call_node_to_branch_id_; - + std::unordered_map> call_node_to_func_graphs_; // host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph. // When initializing the weights, all related weights need to be recorded as the same device tensor. 
HostParameterToWeight host_parameter_to_weights_; std::unordered_map sub_front_node_to_root_front_node_; - // The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the // input of the control node. NodeWithDeviceContext front_value_nodes_; - // The front value node saves all parameters that are not in the kernel graph. These nodes are generally the - // output of subgraph, or the switch condition node. - NodeWithDeviceContext front_parameters_; // Parameters of control node which come from the host actor. std::vector control_node_parameters_; - // The number of calls to func_graph. - std::unordered_map func_graph_to_call_num_; - // In control flow, funcgraph will be divided into multiple kernel graphs. This map records this correspondence. - FuncGraphToKernelGraph func_graph_to_kernel_graphs_; - // In control flow, if there is a call node in the funcgraph, it means that when the funcgraph executes to the call, - // it needs to jump to another funcgraph. At this time, the funcgraph needs to process other real parameters, so - // these nodes need to send control arrows to the entrance actor to tell it to continue processing other parameters, - // these nodes are recorded in this map. - std::unordered_map> func_graph_to_first_control_nodes_; // The kernel graph of call exists in the front input node. // In the scene of funcgrarph recursive call, general input and call input are passed recursively, so a gather actor // is created for kernel graph which has a call input. std::unordered_map call_input_kernel_graphs_; + // The dependency between kernel and call node in auto monad. + std::unordered_map kernel_to_call_nodes_; + // Control nodes without a control node input in the topological sorting of funcgraph. 
+ std::unordered_map> func_graph_to_first_control_nodes_; + + // In heterogeneous scenario, each parameter has its own device context type, so the device context corresponding + // to the type needs to be parsed in advance so that it can add some copy operation in the scheduler. + // 1. The device context type of the formal parameters of funcgraph. + std::unordered_map> func_graph_to_device_contexts_; + // 2. The device context type of the control node inputs. + std::unordered_map> control_node_to_device_contexts_; + + // Is control flow enable. + bool is_inited_{false}; + // Root funcgraph and its parameters. FuncGraphPtr root_func_graph_; std::vector root_graph_parameters_; - - // The dependency between kernel and call node in auto monad. - std::unordered_map kernel_to_call_nodes_; - // Call node will call different funcgraphs according to the input partial node, and this relationship is recorded - // in this map. - std::unordered_map> call_node_to_func_graphs_; - // In heterogeneous scenarios, different formal parameters of funcgraph will have different contexts. In order to - // ensure that there is no copy actor between control actors, the device context type corresponding to each formal - // parameter needs to be derived in the parser and recorded in this map. - std::unordered_map> func_graph_to_device_contexts_; - std::unordered_map> control_node_to_device_contexts_; - // Record which kernel graph the front node is in. 
- FrontNodeToKernelGraph front_node_to_kernel_graph_; - bool is_inited_{false}; }; using ControlNodeParserPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc index 673f3b07933..a067e04e31c 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc @@ -127,7 +127,8 @@ std::vector ControlNodeScheduler::BuildGatherActor(const GraphCo // The gather actor corresponding to a call node needs to set the branch id. if (AnfAlgo::IsCallNode(control_node)) { - gather_actor->output_branch_id_ = graph_compiler_info.control_node_parser_->GetBranchIDByCallNode(control_node); + gather_actor->output_branch_id_ = + graph_compiler_info.control_node_parser_->FetchBranchIDByCallNode(control_node); } } } @@ -404,7 +405,7 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont auto actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto exit_actor = dynamic_cast(actor); - size_t branch_id = parser->GetBranchIDByCallNode(from_node); + size_t branch_id = parser->FetchBranchIDByCallNode(from_node); LinkDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second, branch_id); } diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index 056cb941db1..b5fa9f01aa2 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -92,6 +92,29 @@ std::vector CollectActors(const ActorSet *actor_set) { if (actor_set->output_actor_ != nullptr) { (void)actors.emplace_back(static_cast(actor_set->output_actor_)); } + if (actor_set->control_actors_ != nullptr) { + const auto &control_actor_set = actor_set->control_actors_; + for (auto &switch_actor : control_actor_set->switch_actors_) { + 
MS_EXCEPTION_IF_NULL(switch_actor); + (void)actors.emplace_back(static_cast(switch_actor)); + } + for (auto &gather_actor : control_actor_set->gather_actors_) { + MS_EXCEPTION_IF_NULL(gather_actor); + (void)actors.emplace_back(static_cast(gather_actor)); + } + for (auto &entrance_actor : control_actor_set->entrance_actors_) { + MS_EXCEPTION_IF_NULL(entrance_actor); + (void)actors.emplace_back(static_cast(entrance_actor)); + } + for (auto &exit_actor : control_actor_set->exit_actors_) { + MS_EXCEPTION_IF_NULL(exit_actor); + (void)actors.emplace_back(static_cast(exit_actor)); + } + for (auto &stack_actor : control_actor_set->stack_actors_) { + MS_EXCEPTION_IF_NULL(stack_actor); + (void)actors.emplace_back(static_cast(stack_actor)); + } + } return actors; } @@ -487,6 +510,8 @@ std::vector GraphScheduler::BuildDataSourceActor(const Graph (void)host_queue_ds_actor->data_nodes_.emplace_back(input_node); (void)host_queue_ds_actor->device_contexts_.emplace_back(device_context); (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position); + // In control flow, need to rely on the front node to find the location of the corresponding real parameter. 
+ (void)host_queue_ds_actor->data_node_position_map_.emplace(front_node, data_node_position); (void)front_node_position_temp_map.emplace(front_node, data_node_position); data_node_position++; } @@ -525,7 +550,7 @@ std::vector GraphScheduler::BuildDataSourceActor(const Graph continue; } auto backend_iter = front_to_backend_parameter.find(parameter); - if (backend_iter == front_to_backend_parameter.end()) { + if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) { MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter); } @@ -538,15 +563,20 @@ std::vector GraphScheduler::BuildDataSourceActor(const Graph (void)data_source_actors.emplace_back(host_queue_ds_actor); } - const auto &backend_node = backend_iter->second.first; + if (host_queue_ds_actor->data_node_position_map_.find(parameter) != + host_queue_ds_actor->data_node_position_map_.end()) { + continue; + } + + const auto &backend_node = backend_iter->second.begin()->first; auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node); if (iter != host_queue_ds_actor->data_nodes_.end()) { (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, iter - host_queue_ds_actor->data_nodes_.begin()); } else { (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size()); - (void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.first); - (void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second); + (void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.begin()->first); + (void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.begin()->second); } } @@ -1297,15 +1327,17 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun (void)no_output_actors.emplace_back(super_actor.get()); } } - for (auto &kernel_actor : 
actor_set->kernel_actors_) { - // The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor. - if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) && - parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) { - MS_EXCEPTION_IF_NULL(kernel_actor->kernel_); - MS_LOG(INFO) << kernel_actor->kernel_->fullname_with_scope() << " is not real used by other nodes."; - (void)no_output_actors.emplace_back(kernel_actor.get()); + + // In control flow scenario, no output actor needs to be connected to the corresponding exit actor, not loop count. + if (!parser->IsInited()) { + for (auto &kernel_actor : actor_set->kernel_actors_) { + // The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor. + if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) { + (void)no_output_actors.emplace_back(kernel_actor.get()); + } } } + for (auto &data_actor : actor_set->data_source_actors_) { if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) { (void)no_output_actors.emplace_back(data_actor.get()); @@ -1332,7 +1364,9 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info) { - if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) { + if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep || + (graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited())) { + // In control flow, the exit actor of the root graph sends output data to the output actor. 
return; } MS_EXCEPTION_IF_NULL(to_actor); @@ -1706,6 +1740,7 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf DumpCopyActors(actor_set->copy_actors_, ofs); DumpLoopCountActor(actor_set->loop_count_actor_, ofs); DumpOutputActor(actor_set->output_actor_, ofs); + DumpControlActors(actor_set->control_actors_, ofs); } void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const { diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 904fb7ce68c..71f0c24eaac 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -379,6 +379,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) { // Compile root graph. graph_id_to_device_context_.clear(); + func_graph_to_kernel_graph_ids_.clear(); control_nodes_.clear(); auto subgraph_need_compile = CompileGraph(root_graph); @@ -476,6 +477,10 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu } graph_id_to_device_context_[graph_id] = device_context; + + const auto &func_graph = segment->nodes_[0]->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + func_graph_to_kernel_graph_ids_[func_graph].emplace_back(graph_id); } else { // Compile the cut node. 
auto cut_node = segment->nodes_[0]; @@ -971,8 +976,18 @@ std::unique_ptr MindRTBackend::ConstructGraphCompilerInfo(con (void)name.append("_").append(std::to_string(graph_id_to_context.first)); } + FuncGraphToKernelGraph func_graph_to_kernel_graphs; + for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) { + const auto &func_graph = func_graph_to_kernel_graph_ids.first; + for (const auto &graph_id : func_graph_to_kernel_graph_ids.second) { + const auto &kernel_graph = graph_compiler_->Fetch(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graph); + } + } + auto parser = std::make_shared(); - parser->Parse(control_nodes_, graphs, device_contexts, root_graph); + parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs); runtime::KernelMapPosition outputs_order; size_t outputs_num = 0; diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index 22b70956ae5..e42ecf757aa 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -42,6 +42,7 @@ using ActorInfo = runtime::ActorInfo; using GraphCompiler = runtime::GraphCompiler; using GraphCompilerInfo = runtime::GraphCompilerInfo; using ControlNodeParser = runtime::ControlNodeParser; +using FuncGraphToKernelGraph = runtime::FuncGraphToKernelGraph; using ControlNodeParserPtr = runtime::ControlNodeParserPtr; using KernelWithIndex = session::KernelWithIndex; @@ -157,6 +158,8 @@ class MindRTBackend : public Backend { // node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to // the corresponding device_context. std::map graph_id_to_device_context_; + // Funcgraph will be cut into multiple kernel graphs, and the map is used to save the correspondence. + std::map> func_graph_to_kernel_graph_ids_; std::map graph_info_to_device_context_; std::vector control_nodes_;