From 5b24178bc30cccb7bb154bdb3a4ac1184ea2465e Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Fri, 25 Jun 2021 18:57:30 +0800 Subject: [PATCH] fix funcgraph recursive --- .../runtime/framework/actor/gather_actor.cc | 66 +- .../runtime/framework/actor/gather_actor.h | 5 +- .../runtime/framework/control_node_parser.cc | 925 ++++++++++++++---- .../runtime/framework/control_node_parser.h | 122 ++- .../runtime/framework/graph_scheduler.cc | 5 +- 5 files changed, 806 insertions(+), 317 deletions(-) diff --git a/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc b/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc index 9d51163ec91..61662dafded 100644 --- a/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc @@ -63,70 +63,14 @@ void GatherActor::RunOpData(OpData *input_data, OpContext &origin_parameters_order, - const FrontToBackendNodeWithContext &front_to_backend_parameters, - const FuncGraphToParameter &func_graph_to_parameters, - const std::unordered_map &front_to_backend_kernel) { - const auto func_iter = func_graph_to_parameters.find(func_graph); - if (func_iter == func_graph_to_parameters.end()) { - return; - } - - std::vector graph_inputs; +void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser) { for (const auto &input : func_graph->get_inputs()) { // Monad input would not send to gather actor. - if (!HasAbstractMonad(input)) { - graph_inputs.emplace_back(input); - } - } - - // Collect all backend input node to gather, There are two situations: - // 1. The parameter from the host data source. - // 2. Output the kernel actor. - for (const auto parameters : func_iter->second) { - if (parameters.size() != graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size() << " need:" << graph_inputs.size() - << " func_graph:" << func_iter->first->ToString(); - } - - for (size_t i = 0; i < parameters.size(); ++i) { - if (parameters[i]->isa()) { - // Input node is a parameter from host data source actor. - std::vector invalid_inputs; - std::vector front_inputs = - FetchInputNodeByParameter(parameters[i], origin_parameters_order, &invalid_inputs, func_graph_to_parameters); - - for (const auto &front_input : front_inputs) { - const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0); - - if (node_with_index.first->isa()) { - const auto &iter = front_to_backend_parameters.find(parameters[i]); - if (iter == front_to_backend_parameters.end()) { - MS_LOG(EXCEPTION) << "Cannot find backend node of node:" - << AnfAlgo::GetNodeDebugString(node_with_index.first); - } - front_to_backend_parameter_[graph_inputs[i]].push_back({iter->second.first, 0}); - } else { - const auto iter = front_to_backend_kernel.find(node_with_index.first); - if (iter == front_to_backend_kernel.end()) { - MS_LOG(EXCEPTION) << "Cannot find actor of front node:" - << AnfAlgo::GetNodeDebugString(node_with_index.first); - } - front_to_backend_parameter_[graph_inputs[i]].push_back({iter->second, node_with_index.second}); - } - } - } else { - // Input node is a cnode. - const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(parameters[i], 0); - const auto iter = front_to_backend_kernel.find(node_with_index.first); - if (iter == front_to_backend_kernel.end()) { - MS_LOG(EXCEPTION) << "Cannot find backend node of node:" - << AnfAlgo::GetNodeDebugString(node_with_index.first); - } - front_to_backend_parameter_[graph_inputs[i]].push_back({iter->second, node_with_index.second}); - } + if (HasAbstractMonad(input) || + (input->isa() && AnfAlgo::IsParameterWeight(input->cast()))) { + continue; } + front_to_backend_parameter_[input] = parser->GetBackendInputByParameter(input); } } diff --git a/mindspore/ccsrc/runtime/framework/actor/gather_actor.h b/mindspore/ccsrc/runtime/framework/actor/gather_actor.h index da21f8192e2..c2aa0bd057a 100644 --- a/mindspore/ccsrc/runtime/framework/actor/gather_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/gather_actor.h @@ -56,10 +56,7 @@ class GatherActor : public OpActor { friend class GraphScheduler; // Collect the inputs of gather actor. - void FetchBackendInputNode(const FuncGraphPtr &func_graph, const std::vector &origin_parameters_order, - const FrontToBackendNodeWithContext &front_to_backend_parameters, - const FuncGraphToParameter &func_graph_to_parameters, - const std::unordered_map &front_to_backend_kernel); + void FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser); void FetchInputDeviceTensor(OpContext *context); // Check whether satisfy the condition for launch. bool CheckLaunchCondition(OpContext *context) const; diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index c3000277b0a..b0690b840d0 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -16,6 +16,7 @@ #include "runtime/framework/control_node_parser.h" #include "runtime/framework/actor/switch_actor.h" +#include "runtime/framework/actor/gather_actor.h" #include "ir/tensor.h" namespace mindspore { @@ -40,8 +41,14 @@ void FetchWeightbyHostParameter(const AnfNodePtr &node, std::vector } } +// Check whether the input is a valid parameter. bool CheckValidFuncGraphInput(const AnfNodePtr &node) { - return (!IsPersistentDeviceTensor(node)) && (!HasAbstractMonad(node)); + if (HasAbstractMonad(node)) { + return false; + } else if (node->isa()) { + return !HasAbstractRef(node); + } + return true; } // Get the funcgraph in partial node. @@ -50,7 +57,7 @@ FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) { return GetValueNode(partial_inputs[1]); } -// Get the corresponding relationship between funcgraph and parameters in the switch node. +// Get the relationship between funcgraph and parameters in the switch node. void FetchParameterBySwitchNode(const AnfNodePtr &switch_node, FuncGraphToParameter *graph_to_real_parameters) { const auto &switch_cnode = switch_node->cast(); const auto &switch_inputs = switch_cnode->inputs(); @@ -83,16 +90,23 @@ void FetchParameterBySwitchLayerNode(const AnfNodePtr &switch_layer_node, const } auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast()->inputs(); + + // Get the parameter corresponding to each funcgraph in make tuple. for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) { + // Tuple branch is a partial node. const auto &func_graph = GetFuncGraphFromPartial(tuple_inputs[i]); std::vector parameters; const auto &partial_inputs = tuple_inputs[i]->cast()->inputs(); + + // Get inputs in partial node. for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) { if (CheckValidFuncGraphInput(partial_inputs[j])) { parameters.emplace_back(partial_inputs[j]); } } + + // Get inputs in call node. for (size_t j = kCallInputStartPos; j < call_inputs.size(); ++j) { if (CheckValidFuncGraphInput(call_inputs[j])) { parameters.emplace_back(call_inputs[j]); @@ -100,13 +114,17 @@ void FetchParameterBySwitchLayerNode(const AnfNodePtr &switch_layer_node, const } (*graph_to_real_parameters)[func_graph].emplace_back(parameters); } else if (tuple_inputs[i]->isa() && IsValueNode(tuple_inputs[i])) { + // Tuple branch is a call node. const auto &func_graph = GetValueNode(tuple_inputs[i]); std::vector parameters; + + // Get inputs in call node. for (size_t j = kCallInputStartPos; j < call_inputs.size(); ++j) { if (CheckValidFuncGraphInput(call_inputs[j])) { parameters.emplace_back(call_inputs[j]); } } + (*graph_to_real_parameters)[func_graph].emplace_back(parameters); } } @@ -148,8 +166,221 @@ void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodeP MS_EXCEPTION_IF_NULL(address); AnfAlgo::SetOutputAddr(address, 0, front_node.get()); } + +// Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding +// backend parameter, then recursively find the backend parameters of other front parameters corresponding to the +// front_node. +std::pair FetchBackendNodeByFrontNode( + const AnfNodePtr &front_node, + const std::unordered_map> &real_to_formal_front_parameters, + const std::unordered_map> &formal_to_real_front_parameters, + const std::unordered_map> &front_to_backend_parameter, + std::set *invalid_node) { + // Check whether the front_node has been looked for. + if ((*invalid_node).find(front_node) != (*invalid_node).end()) { + return std::pair(); + } + (*invalid_node).insert(front_node); + + const auto front_to_backend_iter = front_to_backend_parameter.find(front_node); + if (front_to_backend_iter != front_to_backend_parameter.end()) { + return front_to_backend_iter->second; + } + + const auto &real_to_formal_iter = real_to_formal_front_parameters.find(front_node); + if (real_to_formal_iter == real_to_formal_front_parameters.end()) { + return std::pair(); + } + for (const auto &next_node : real_to_formal_iter->second) { + auto banckend_node = + FetchBackendNodeByFrontNode(next_node, real_to_formal_front_parameters, formal_to_real_front_parameters, + front_to_backend_parameter, invalid_node); + if (banckend_node.first != nullptr) { + return banckend_node; + } + } + + const auto &formal_to_real_iter = formal_to_real_front_parameters.find(front_node); + if (formal_to_real_iter == formal_to_real_front_parameters.end()) { + return std::pair(); + } + for (const auto &next_node : formal_to_real_iter->second) { + auto banckend_node = + FetchBackendNodeByFrontNode(next_node, real_to_formal_front_parameters, formal_to_real_front_parameters, + front_to_backend_parameter, invalid_node); + if (banckend_node.first != nullptr) { + return banckend_node; + } + } + return std::pair(); +} + +// Fetch all backend input nodes by parameter for gather actor. +std::vector FetchInputNodeByParameter(const AnfNodePtr ¶meter, + const std::vector &host_ds_parameters, + std::set *invalid_inputs, + const FuncGraphToParameter &graph_to_real_parameters) { + std::vector input_nodes; + + // If the node has been collected, skip it. + if (find((*invalid_inputs).begin(), (*invalid_inputs).end(), parameter) != (*invalid_inputs).end()) { + return input_nodes; + } + + // Record the node which has been collected. + (*invalid_inputs).insert(parameter); + + // If the parameter node is a parameter of host data source actor, return it. + if (find(host_ds_parameters.begin(), host_ds_parameters.end(), parameter) != host_ds_parameters.end()) { + input_nodes.emplace_back(parameter); + return input_nodes; + } + + // Check the parameter which send to its funcgraph. + const auto &func_graph = parameter->func_graph(); + if (graph_to_real_parameters.find(func_graph) == graph_to_real_parameters.end()) { + return input_nodes; + } + + std::vector self_inputs; + for (const auto &input : func_graph->get_inputs()) { + // Monad input need not send to funcgraph. + if (HasAbstractMonad(input) || HasAbstractRef(input)) { + continue; + } + self_inputs.emplace_back(input); + } + + const auto iter = find(self_inputs.begin(), self_inputs.end(), parameter); + if (iter == self_inputs.end()) { + MS_LOG(EXCEPTION) << "Cannot find parameter node:" << AnfAlgo::GetNodeDebugString(parameter); + } + size_t pos = iter - self_inputs.begin(); + + for (const auto parameters : graph_to_real_parameters.at(func_graph)) { + if (parameters.size() != self_inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid input num:" << parameters.size() << " and:" << self_inputs.size() + << " for func_graph:" << func_graph->ToString(); + } + const auto input = parameters[pos]; + if (input->isa()) { + input_nodes.emplace_back(input); + } else if (input->isa()) { + // If input is a parameter, you need to find its input recursively. + auto inputs = FetchInputNodeByParameter(input, host_ds_parameters, invalid_inputs, graph_to_real_parameters); + input_nodes.insert(input_nodes.end(), inputs.begin(), inputs.end()); + } + } + return input_nodes; +} + +// Find the output of the funcgraph, if the output is a call node, return the output of the funcgraph +// called by the call node. +std::vector FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std::vector *call_nodes) { + std::vector outputs; + const auto &output = func_graph->output(); + const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0, false, {prim::kPrimTupleGetItem}); + if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) { + return outputs; + } + if (!IsCallNode(real_output.first)) { + outputs.push_back(real_output.first); + return outputs; + } + + (*call_nodes).push_back(real_output.first); + std::vector func_graphs = FetchFuncGraphbyCallNode(real_output.first); + for (const auto &graph : func_graphs) { + auto single_outputs = FetchFuncGraphOutput(graph, call_nodes); + outputs.insert(outputs.end(), single_outputs.begin(), single_outputs.end()); + } + return outputs; +} +std::vector FetchOutputBySwitchNode(const AnfNodePtr &switch_node, std::set *call_nodes, + std::set *switch_nodes); + +// Recursive interface, get all possible output nodes of call node. +std::vector FetchOutputByCallNode(const AnfNodePtr &call_node, std::set *call_nodes, + std::set *switch_nodes) { + std::vector outputs; + if ((*call_nodes).find(call_node) != (*call_nodes).end()) { + return outputs; + } + (*call_nodes).insert(call_node); + + const auto func_graphs = FetchFuncGraphbyCallNode(call_node); + + for (const auto func_graph : func_graphs) { + if (func_graph->output()->isa()) { + outputs.push_back(func_graph->output()); + } else { + std::vector sub_call_nodes; + const std::vector graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes); + for (const auto &graph_output : graph_outputs) { + if (graph_output->isa()) { + outputs.push_back(graph_output); + } else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) { + const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes); + outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); + } else if (IsCallNode(graph_output)) { + const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes); + outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); + } else if (graph_output->isa()) { + outputs.emplace_back(graph_output); + } else { + MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output); + } + } + } + } + + return outputs; +} + +// Recursive interface, get all possible output nodes of switch node. +std::vector FetchOutputBySwitchNode(const AnfNodePtr &switch_node, std::set *call_nodes, + std::set *switch_nodes) { + std::vector outputs; + if ((*switch_nodes).find(switch_node) != (*switch_nodes).end()) { + return outputs; + } + (*switch_nodes).insert(switch_node); + + if (!switch_node->isa()) { + MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(switch_node); + } + const auto &inputs = switch_node->cast()->inputs(); + if (inputs.size() != kSwitchInputNum) { + MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(switch_node); + } + + for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) { + if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimPartial)) { + continue; + } else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) { + const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes); + outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); + } else if (IsCallNode(inputs[i])) { + const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes); + outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); + } else { + outputs.emplace_back(inputs[i]); + } + } + + return outputs; +} } // namespace +// Return true if the node has Ref abstract. +bool HasAbstractRef(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + auto &abs = node->abstract(); + return (abs != nullptr) && abs->isa(); +} + bool IsCallNode(const AnfNodePtr &node) { if (!node->isa()) { return false; @@ -159,10 +390,30 @@ bool IsCallNode(const AnfNodePtr &node) { return inputs[0]->isa() || (inputs[0]->isa() && IsValueNode(inputs[0])); } -std::vector FetchFuncGraphbyCallNode(const CNodePtr &node) { - std::vector func_graphs; - const auto &call_inputs = node->inputs(); +std::vector FetchInputsByMakeTuple(const AnfNodePtr &node) { + std::vector parameters; + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { + const auto ¶meter = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first; + parameters.emplace_back(parameter); + return parameters; + } + const auto &inputs = node->cast()->inputs(); + for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { + const auto &sub_parameters = FetchInputsByMakeTuple(inputs[i]); + parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end()); + } + + return parameters; +} + +std::vector FetchFuncGraphbyCallNode(const AnfNodePtr &node) { + std::vector func_graphs; + if (!node->isa()) { + return func_graphs; + } + + const auto &call_inputs = node->cast()->inputs(); if (call_inputs[0]->isa()) { const auto &cnode = call_inputs[0]->cast(); const auto &cnode_inputs = cnode->inputs(); @@ -176,9 +427,9 @@ std::vector FetchFuncGraphbyCallNode(const CNodePtr &node) { AnfAlgo::CheckPrimitiveType(cnode_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple)) { const auto &tuple_inputs = cnode_inputs[kSwitchLayerBranchPos]->cast()->inputs(); - for (size_t i = 1; i < tuple_inputs.size(); ++i) { + for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) { - func_graphs.emplace_back(GetFuncGraphFromPartial(cnode_inputs[i])); + func_graphs.emplace_back(GetFuncGraphFromPartial(tuple_inputs[i])); } else if (IsValueNode(tuple_inputs[i])) { func_graphs.emplace_back(GetValueNode(tuple_inputs[i])); } @@ -194,6 +445,51 @@ std::vector FetchFuncGraphbyCallNode(const CNodePtr &node) { return func_graphs; } +size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector *call_nodes) { + if (!IsCallNode(node)) { + MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node); + } + if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) { + return 0; + } + (*call_nodes).emplace_back(node); + + const auto &func_graphs = FetchFuncGraphbyCallNode(node); + for (const auto &func_graph : func_graphs) { + const auto &output = func_graph->output(); + const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0); + + if (IsCallNode(real_output.first)) { + size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes); + if (output_num > 0) { + return output_num; + } + } else if (AnfAlgo::CheckPrimitiveType(real_output.first, prim::kPrimMakeTuple)) { + size_t total_num = 0; + const auto &tuple_cnode = real_output.first->cast(); + const auto &inputs = tuple_cnode->inputs(); + size_t i = 1; + for (; i < inputs.size(); ++i) { + if (IsCallNode(inputs[i])) { + size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes); + if (call_output_num == 0) { + break; + } + total_num += call_output_num; + } else { + ++total_num; + } + } + if (i == inputs.size()) { + return total_num; + } + } else { + return 1; + } + } + return 0; +} + FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node) { auto front_node = GetFrontNodeByBackendNode(node); @@ -202,7 +498,8 @@ FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node) { if (node->isa()) { const auto &cnode = node->cast(); const auto &inputs = cnode->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { + + for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { const auto &func_graph = FetchFuncGraphByNode(inputs[i]); if (func_graph != nullptr) { return func_graph; @@ -228,6 +525,18 @@ AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) { return kernel_graph->GetFrontAnfByBackendAnf(backend_node); } +AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) { + const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node); + if (front_node != nullptr) { + return front_node; + } + const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node); + if (front_node_with_index.first == nullptr) { + MS_LOG(EXCEPTION) << "Invalid parameter of kernel graph, parameter:" << AnfAlgo::GetNodeDebugString(backend_node); + } + return front_node_with_index.first; +} + FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) { auto front_node = GetFrontNodeByBackendNode(backend_node); if (front_node == nullptr) { @@ -236,71 +545,84 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) { return front_node->func_graph(); } -std::vector FetchInputNodeByParameter(const AnfNodePtr ¶meter, - const std::vector &host_ds_parameters, - std::vector *invalid_inputs, - const FuncGraphToParameter &graph_to_real_parameters) { - std::vector input_nodes; - - // If the node has been collected, skip it. - if (find((*invalid_inputs).begin(), (*invalid_inputs).end(), parameter) != (*invalid_inputs).end()) { - return input_nodes; - } - - // Record the node which has been collected. - (*invalid_inputs).emplace_back(parameter); - - // If the parameter node is a parameter of host data source actor, return it. - if (find(host_ds_parameters.begin(), host_ds_parameters.end(), parameter) != host_ds_parameters.end()) { - input_nodes.emplace_back(parameter); - return input_nodes; - } - - // Check the parameter which send to its funcgraph. - const auto &func_graph = parameter->func_graph(); - if (graph_to_real_parameters.find(func_graph) == graph_to_real_parameters.end()) { - return input_nodes; - } - - std::vector self_inputs; - for (const auto &input : func_graph->get_inputs()) { - // Monad input need not send to funcgraph. - if (!HasAbstractMonad(input)) { - self_inputs.emplace_back(input); - } - } - - size_t pos = find(self_inputs.begin(), self_inputs.end(), parameter) - self_inputs.begin(); - for (const auto parameters : graph_to_real_parameters.at(func_graph)) { - const auto input = parameters[pos]; - if (input->isa()) { - input_nodes.emplace_back(input); - } else if (input->isa()) { - // If input is a parameter, you need to find its input recursively. - auto inputs = FetchInputNodeByParameter(input, host_ds_parameters, invalid_inputs, graph_to_real_parameters); - input_nodes.insert(input_nodes.end(), inputs.begin(), inputs.end()); - } - } - return input_nodes; -} - void ControlNodeParser::Parse(const std::vector &control_nodes, const std::vector &graphs, const std::vector &device_contexts, const FuncGraphPtr &root_graph) { - FetchFrontToBackendParameterMap(graphs, device_contexts, control_nodes); + root_graph_parameters_ = root_graph->parameters(); - FetchFuncGraphToParameterMap(control_nodes); + CreateBranchIDForFuncGraph(control_nodes); - FetchHostParameterToWeightMap(control_nodes); + FetchFrontToBackendParameter(graphs, device_contexts, control_nodes); - FetchFrontValueNode(graphs, device_contexts); + FetchFuncGraphToParameter(control_nodes); + + FetchHostParameterToWeight(control_nodes); + + FetchFrontValueNode(control_nodes, graphs, device_contexts); + + FetchFrontToBackendKernel(graphs, device_contexts); - // Get inputs of control node which come from the host actor. control_node_parameters_ = FetchControlNodeParameter(control_nodes); + FetchFuncGraphCallNum(control_nodes); + + FetchCallInputKernelGraph(graphs, device_contexts); + + FetchBackendInputNode(); + front_output_nodes_ = FetchAllBranchOutputs(root_graph); } -void ControlNodeParser::FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node, +std::vector ControlNodeParser::GetBackendInputByParameter(const AnfNodePtr ¶meter) { + return formal_to_real_parameters_[parameter]; +} + +std::vector ControlNodeParser::FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output) { + std::set call_nodes; + std::set switch_nodes; + return FetchBackendOutputByFrontOutput(front_output, &call_nodes, &switch_nodes); +} + +int ControlNodeParser::GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + + if (func_graph_to_branch_id_.find(func_graph) == func_graph_to_branch_id_.end()) { + MS_LOG(EXCEPTION) << "Invalid branch id for funcgraph:" << func_graph->ToString(); + } + return func_graph_to_branch_id_[func_graph]; +} + +bool ControlNodeParser::IsCallInputKernelGraph(const KernelGraphPtr &graph) { + if (call_input_kernel_graphs_.find(graph) == call_input_kernel_graphs_.end()) { + return false; + } + return true; +} + +size_t ControlNodeParser::GetCallNumByFuncGraph(const FuncGraphPtr &func_graph) { + if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) { + MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString(); + } + + return func_graph_to_call_num_[func_graph]; +} + +std::vector ControlNodeParser::FetchAllBranchOutputs(const FuncGraphPtr &func_graph) { + std::vector call_nodes; + return FetchFuncGraphOutput(func_graph, &call_nodes); +} + +DeviceContext *ControlNodeParser::GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node) { + auto iter = std::find_if( + front_value_nodes_.begin(), front_value_nodes_.end(), + [value_node](const auto &front_node_with_context) { return front_node_with_context.first == value_node; }); + + if (iter != front_value_nodes_.end()) { + return iter->second; + } + return nullptr; +} + +void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, std::vector *value_nodes) { const auto &cnode = switch_node->cast(); const auto &inputs = cnode->inputs(); @@ -321,7 +643,7 @@ void ControlNodeParser::FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node if (call_inputs.empty() || (!AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch))) { continue; } - FetchValueNodeInSwitchNode(call_inputs[0], value_nodes); + FetchValueNodeBySwitchNode(call_inputs[0], value_nodes); } else if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) { const auto &partial_node = input->cast(); const auto &partial_inputs = partial_node->inputs(); @@ -338,8 +660,40 @@ void ControlNodeParser::FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node } } -void ControlNodeParser::FetchFrontValueNode(const std::vector &graphs, +void ControlNodeParser::FetchFrontValueNode(const std::vector &control_nodes, + const std::vector &graphs, const std::vector &device_contexts) { + for (const auto &control_node : control_nodes) { + CNodePtr cnode = control_node->cast(); + auto inputs = cnode->inputs(); + if (inputs[0]->isa() && IsValueNode(inputs[0])) { + auto func_graph = GetValueNode(inputs[0]); + const auto parameters = func_graph->parameters(); + if (parameters.size() != inputs.size() - kCallInputStartPos) { + MS_LOG(EXCEPTION) << "Invalid parameters num, need:" << parameters.size() + << " has:" << inputs.size() - kCallInputStartPos; + } + for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { + if (inputs[i]->isa()) { + const auto &node_value = inputs[i]->cast()->value(); + if (!node_value->isa()) { + continue; + } + if (front_to_backend_parameters_.find(parameters[i - kCallInputStartPos]) == + front_to_backend_parameters_.end()) { + MS_LOG(EXCEPTION) << "Cannot find backend parameter for front parameter:" + << AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos]); + } + + const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first; + const auto &device_context = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].second; + CreateDeviceTensorForValueNode(inputs[i], backend_node, device_context); + front_value_nodes_.push_back({inputs[i], device_context}); + } + } + } + } + for (size_t index = 0; index < graphs.size(); ++index) { const auto &graph = graphs[index]; MS_EXCEPTION_IF_NULL(graph); @@ -359,7 +713,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &g MS_EXCEPTION_IF_NULL(front_output_node); if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) { std::vector value_nodes; - FetchValueNodeInSwitchNode(front_output_node, &value_nodes); + FetchValueNodeBySwitchNode(front_output_node, &value_nodes); for (const auto value_node : value_nodes) { CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]); front_value_nodes_.push_back({value_node, device_contexts[index]}); @@ -370,138 +724,6 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &g } } -AnfNodePtr ControlNodeParser::GetCallNodeInputByPos(const AnfNodePtr &call_node, const FuncGraphPtr &func_graph, - const size_t pos) { - const auto &cnode = call_node->cast(); - const auto &inputs = cnode->inputs(); - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "Input of call node is empty, node:" << AnfAlgo::GetNodeDebugString(call_node); - } - - if (inputs[0]->isa()) { - // The first input of call node is a switch node. - if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch)) { - if (inputs.size() != kSwitchInputNum) { - MS_LOG(EXCEPTION) << "Invalid input num of switch node, num:" << inputs.size(); - } - - auto switch_cnode = inputs[0]->cast(); - auto switch_inputs = switch_cnode->inputs(); - auto true_partial_cnode = switch_inputs[kSwitchTrueBranchPos]->cast(); - auto partial_inputs = true_partial_cnode->inputs(); - if (partial_inputs.size() <= kPartialFuncGraphPos) { - MS_LOG(EXCEPTION) << "Invalid input num of switch node, num:" << inputs.size(); - } - auto true_func_graph = GetValueNode(partial_inputs[kPartialFuncGraphPos]); - - // If the funcgraph is not in true branch, it must be in false branch. - if (true_func_graph != func_graph) { - auto false_partial_cnode = switch_inputs[kSwitchFalseBranchPos]->cast(); - partial_inputs = false_partial_cnode->inputs(); - } - if (pos + kPartialInputStartPos >= partial_inputs.size()) { - MS_LOG(EXCEPTION) << "Invalid input pos:" << pos << " node:" << AnfAlgo::GetNodeDebugString(call_node); - } - return partial_inputs[pos + kPartialInputStartPos]; - } else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { - if (inputs.size() != kSwitchLayerInputNum) { - MS_LOG(EXCEPTION) << "Invalid input num of switchlayer node, num:" << inputs.size(); - } - - // The first input of call node is a switchlayer node. - auto switch_layer_cnode = inputs[0]->cast(); - const auto &make_tuple_node = switch_layer_cnode->input(kSwitchLayerBranchPos); - const auto &make_tuple_cnode = make_tuple_node->cast(); - const auto &tuple_inputs = make_tuple_cnode->inputs(); - - for (const auto &node : tuple_inputs) { - // tuple input is a value node of funcgraph. - if (node->isa() && IsValueNode(node)) { - auto branch_graph = GetValueNode(node); - if (branch_graph == func_graph) { - if (pos + kCallInputStartPos >= inputs.size()) { - MS_LOG(EXCEPTION) << "Invalid input pos:" << pos << " node:" << AnfAlgo::GetNodeDebugString(call_node); - } - return inputs[pos + kCallInputStartPos]; - } - } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { - // tuple input is a partial, it means that part of the input of funcgraph is on partial and part on call. - auto partial_cnode = node->cast(); - auto partial_input = partial_cnode->inputs(); - if (func_graph == GetValueNode(partial_input[kPartialFuncGraphPos])) { - if (pos + kPartialInputStartPos >= partial_input.size()) { - return inputs[pos + kPartialInputStartPos - partial_input.size()]; - } else { - return partial_input[pos + kPartialInputStartPos]; - } - } - } - } - } else { - MS_LOG(EXCEPTION) << "Invalid input node:" << AnfAlgo::GetNodeDebugString(inputs[0]); - } - } else { - // The first input of funcgraph is a value node of funcgraph. - if (pos + kCallInputStartPos >= inputs.size()) { - MS_LOG(EXCEPTION) << "Invalid input pos:" << pos << " node:" << AnfAlgo::GetNodeDebugString(call_node); - } - return inputs[pos + kCallInputStartPos]; - } - return nullptr; -} - -std::vector ControlNodeParser::FetchFuncGraphOutput(const FuncGraphPtr &func_graph, - std::vector *call_nodes) { - std::vector outputs; - const auto &output = func_graph->output(); - const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0, false, {prim::kPrimTupleGetItem}); - if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) { - return outputs; - } - if (!IsCallNode(real_output.first)) { - outputs.push_back(real_output.first); - return outputs; - } - - (*call_nodes).push_back(real_output.first); - const auto &call_cnode = real_output.first->cast(); - std::vector func_graphs = FetchFuncGraphbyCallNode(call_cnode); - for (const auto &graph : func_graphs) { - auto single_outputs = FetchFuncGraphOutput(graph, call_nodes); - outputs.insert(outputs.end(), single_outputs.begin(), single_outputs.end()); - } - return outputs; -} - -std::pair ControlNodeParser::FetchBackendNodeByFrontNode( - const AnfNodePtr &front_node, const std::unordered_map> &front_to_front_parameter, - const std::unordered_map> &front_to_backend_parameter, - std::set *invalid_node) { - // Check whether the front_node has been looked for. - if ((*invalid_node).find(front_node) != (*invalid_node).end()) { - return std::pair(); - } - (*invalid_node).insert(front_node); - - const auto front_to_backend_iter = front_to_backend_parameter.find(front_node); - if (front_to_backend_iter != front_to_backend_parameter.end()) { - return front_to_backend_iter->second; - } - - const auto &front_to_front_iter = front_to_front_parameter.find(front_node); - if (front_to_front_iter == front_to_front_parameter.end()) { - return std::pair(); - } - for (const auto &next_node : front_to_front_iter->second) { - auto banckend_node = - FetchBackendNodeByFrontNode(next_node, front_to_front_parameter, front_to_backend_parameter, invalid_node); - if (banckend_node.first != nullptr) { - return banckend_node; - } - } - return std::pair(); -} - void ControlNodeParser::FetchFrontToFrontParameterMap( const std::vector &control_nodes, std::unordered_map> *front_to_front_parameter) { @@ -546,7 +768,8 @@ void ControlNodeParser::FetchFrontToFrontParameterMap( // call_inputs will be the input of these funcgraphs. const auto &tuple_node = switch_inputs[kSwitchLayerBranchPos]->cast(); const auto &tuple_inputs = tuple_node->inputs(); - for (const auto &input : tuple_inputs) { + for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { + const auto &input = tuple_inputs[i]; if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimPartial)) { partial_input_parse(input, call_inputs); } else { @@ -609,14 +832,105 @@ std::vector ControlNodeParser::FetchControlNodeParameter(const std:: return parameters; } -std::vector ControlNodeParser::FetchAllBranchOutputs(const FuncGraphPtr &func_graph) { - std::vector call_nodes; - return FetchFuncGraphOutput(func_graph, &call_nodes); +void ControlNodeParser::FetchFuncGraphCallNum(const std::vector &control_nodes) { + for (const auto &control_node : control_nodes) { + if (IsCallNode(control_node)) { + const auto &func_graphs = FetchFuncGraphbyCallNode(control_node); + for (const auto &func_graph : func_graphs) { + MS_EXCEPTION_IF_NULL(func_graph); + if (func_graph->output()->isa()) { + continue; + } + + if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) { + func_graph_to_call_num_[func_graph] = 1; + } else { + func_graph_to_call_num_[func_graph]++; + } + } + } + } } -void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vector &graphs, - const std::vector &device_contexts, - const std::vector &control_nodes) { +void ControlNodeParser::FetchCallInputKernelGraph(const std::vector &graphs, + const std::vector &device_contexts) { + for (size_t i = 0; i < graphs.size(); ++i) { + const auto &graph = graphs[i]; + const auto &device_context = device_contexts[i]; + + const auto inputs = graph->parameters(); + for (const auto &input : inputs) { + const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input); + if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) { + call_input_kernel_graphs_[graph] = device_context; + break; + } + } + } +} + +void ControlNodeParser::CreateBranchIDForFuncGraph(const std::vector &control_nodes) { + int branch_id = 0; + + for (const auto &control_node : control_nodes) { + // Root funcgraph does not need to create a gather actor. + if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) { + const auto &cnode = control_node->cast(); + const auto &inputs = cnode->inputs(); + // If the output of funcgraph is a value node, no need to create gather actor. + if (inputs[kReturnInputPos]->isa()) { + continue; + } + + auto func_graph = control_node->func_graph(); + func_graph_to_branch_id_[func_graph] = branch_id++; + } + } +} + +std::vector FetchInputParameterbyControlNode(const AnfNodePtr &node, std::set *switch_nodes, + std::set *call_nodes) { + std::vector parameters; + + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { + if ((*switch_nodes).find(node) != (*switch_nodes).end()) { + return parameters; + } + (*switch_nodes).insert(node); + + const auto &cnode = node->cast(); + const auto &inputs = cnode->inputs(); + if (inputs.size() != kSwitchInputNum) { + MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(node); + } + + for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) { + if (inputs[i]->isa()) { + parameters.emplace_back(inputs[i]); + } else if (IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) { + const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes); + parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end()); + } + } + } else if (IsCallNode(node)) { + if ((*call_nodes).find(node) != (*call_nodes).end()) { + return parameters; + } + (*call_nodes).insert(node); + + const auto &func_graphs = FetchFuncGraphbyCallNode(node); + for (const auto &func_graph : func_graphs) { + if (func_graph->output()->isa()) { + parameters.emplace_back(func_graph->output()); + } + } + } + return parameters; +} + +void ControlNodeParser::FetchFrontToBackendParameter(const std::vector &graphs, + const std::vector &device_contexts, + const std::vector &control_nodes) { if (graphs.size() != device_contexts.size()) { MS_LOG(EXCEPTION) << "Graph num is not equal to device context num."; } @@ -627,6 +941,7 @@ void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vectorparameters()) { auto front_node = graph->GetFrontAnfByBackendAnf(parameter); + if (front_node != nullptr && front_node->isa() && front_to_backend_parameters_.find(front_node) == front_to_backend_parameters_.end()) { front_to_backend_parameters_[front_node] = {parameter, device_context}; @@ -634,25 +949,57 @@ void ControlNodeParser::FetchFrontToBackendParameterMap(const std::vectorparameters()) { + const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(parameter); + + if (internal_front_node.first != nullptr) { + std::set call_nodes; + std::set switch_nodes; + const auto &front_paramters = + FetchInputParameterbyControlNode(internal_front_node.first, &switch_nodes, &call_nodes); + for (const auto &front_paramter : front_paramters) { + if (front_to_backend_parameters_.find(front_paramter) == front_to_backend_parameters_.end()) { + front_to_backend_parameters_[front_paramter] = {parameter, device_context}; + } + } + } + } + } + // Fetch the mapping relationship between front parameters and backend parameters in the control nodes. First - // fetch the mapping relationship of the frontparameter. When the input of the call node or the partial node + // fetch the mapping relationship of the front parameter. When the input of the call node or the partial node // is a parameter node, it means that the parameter is directly transmitted. If a parameter does not have a // corresponding backend node, then recursively find whether the front parameter corresponding to the parameter // has one. - std::unordered_map> front_to_front_parameter; - FetchFrontToFrontParameterMap(control_nodes, &front_to_front_parameter); + std::unordered_map> real_to_formal_front_parameters; + FetchFrontToFrontParameterMap(control_nodes, &real_to_formal_front_parameters); - for (const auto &front_pair : front_to_front_parameter) { + std::unordered_map> formal_to_real_front_parameters; + for (const auto real_to_formal_front_parameter : real_to_formal_front_parameters) { + for (const auto formal_parameter : real_to_formal_front_parameter.second) { + formal_to_real_front_parameters[formal_parameter].emplace_back(real_to_formal_front_parameter.first); + } + } + + for (const auto &front_pair : real_to_formal_front_parameters) { std::set invalid_node; - const auto &backend_node = FetchBackendNodeByFrontNode(front_pair.first, front_to_front_parameter, - front_to_backend_parameters_, &invalid_node); + const auto &backend_node = + FetchBackendNodeByFrontNode(front_pair.first, real_to_formal_front_parameters, formal_to_real_front_parameters, + front_to_backend_parameters_, &invalid_node); if (backend_node.first != nullptr) { - front_to_backend_parameters_[front_pair.first] = backend_node; + if (front_to_backend_parameters_.find(front_pair.first) == front_to_backend_parameters_.end()) { + front_to_backend_parameters_[front_pair.first] = backend_node; + } } } } -void ControlNodeParser::FetchHostParameterToWeightMap(const std::vector &control_nodes) { +void ControlNodeParser::FetchHostParameterToWeight(const std::vector &control_nodes) { std::unordered_map> front_to_front_parameter; FetchFrontToFrontParameterMap(control_nodes, &front_to_front_parameter); @@ -664,7 +1011,7 @@ void ControlNodeParser::FetchHostParameterToWeightMap(const std::vector &control_nodes) { +void ControlNodeParser::FetchFuncGraphToParameter(const std::vector &control_nodes) { for (const auto &control_node : control_nodes) { const auto &cnode = control_node->cast(); const auto &inputs = cnode->inputs(); @@ -698,5 +1045,157 @@ void ControlNodeParser::FetchFuncGraphToParameterMap(const std::vector &graphs, + const std::vector &device_contexts) { + for (size_t i = 0; i < graphs.size(); ++i) { + const auto &graph = graphs[i]; + const auto &device_context = device_contexts[i]; + MS_EXCEPTION_IF_NULL(graph); + auto execution_order = graph->execution_order(); + for (auto &kernel : execution_order) { + if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) { + auto front_node = graph->GetFrontAnfByBackendAnf(kernel); + if (front_node != nullptr) { + front_to_backend_kernels_[front_node] = {kernel, device_context}; + } + } + } + } +} + +std::vector ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, + std::set *call_nodes, + std::set *switch_nodes) { + std::vector backend_outputs; + if (front_output->isa()) { + backend_outputs.push_back({front_output, 0}); + } else if (front_output->isa()) { + // Output is a parameter. + const auto iter = formal_to_real_parameters_.find(front_output); + + if (iter != formal_to_real_parameters_.end()) { + backend_outputs.insert(backend_outputs.end(), iter->second.begin(), iter->second.end()); + } else { + MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output); + } + } else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimSwitch)) { + // Output is a switch. + const auto &switch_outputs = FetchOutputBySwitchNode(front_output, call_nodes, switch_nodes); + + for (const auto &switch_output : switch_outputs) { + const auto outputs = FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes); + backend_outputs.insert(backend_outputs.end(), outputs.begin(), outputs.end()); + } + } else if (IsCallNode(front_output)) { + // Output is a call. + const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes); + + for (const auto &call_output : call_outputs) { + const auto outputs = FetchBackendOutputByFrontOutput(call_output, call_nodes, switch_nodes); + backend_outputs.insert(backend_outputs.end(), outputs.begin(), outputs.end()); + } + } else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimMakeTuple)) { + // Output is a make tuple. + const auto &cnode = front_output->cast(); + const auto &inputs = cnode->inputs(); + + for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { + const auto outputs = FetchBackendOutputByFrontOutput(inputs[i], call_nodes, switch_nodes); + backend_outputs.insert(backend_outputs.end(), outputs.begin(), outputs.end()); + } + } else if (front_output->isa()) { + // Output is a kernel. + const auto iter = front_to_backend_kernels_.find(front_output); + + if (iter != front_to_backend_kernels_.end()) { + const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(iter->second.first, 0); + backend_outputs.emplace_back(output_with_index); + } else { + MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output); + } + } else { + MS_LOG(EXCEPTION) << "Invalid front node:" << AnfAlgo::GetNodeDebugString(front_output); + } + + return backend_outputs; +} + +void ControlNodeParser::FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, + const AnfNodePtr &formal_parameter) { + if (real_parameter->isa()) { + // Input node is a parameter from host data source actor. + std::set invalid_inputs; + std::vector front_inputs = + FetchInputNodeByParameter(real_parameter, root_graph_parameters_, &invalid_inputs, func_graph_to_parameters_); + + for (const auto &front_input : front_inputs) { + const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0); + + if (node_with_index.first->isa()) { + const auto &iter = front_to_backend_parameters_.find(real_parameter); + if (iter == front_to_backend_parameters_.end()) { + MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first); + continue; + } + formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, 0}); + } else { + const auto iter = front_to_backend_kernels_.find(node_with_index.first); + if (iter == front_to_backend_kernels_.end()) { + MS_LOG(EXCEPTION) << "Cannot find actor of front node:" << AnfAlgo::GetNodeDebugString(node_with_index.first); + } + formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, node_with_index.second}); + } + } + } else if (real_parameter->isa()) { + formal_to_real_parameters_[formal_parameter].push_back({real_parameter, 0}); + } else if (IsCallNode(real_parameter)) { + const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter); + for (const auto func_graph : func_graphs) { + FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter); + } + } else { + // Input node is a cnode. + const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(real_parameter, 0); + const auto iter = front_to_backend_kernels_.find(node_with_index.first); + if (iter == front_to_backend_kernels_.end()) { + MS_LOG(EXCEPTION) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first); + } + formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, node_with_index.second}); + } +} + +void ControlNodeParser::FetchBackendInputNode() { + for (const auto &func_graph_to_parameters : func_graph_to_parameters_) { + const auto &func_graph = func_graph_to_parameters.first; + std::vector graph_inputs; + for (const auto &input : func_graph->get_inputs()) { + // Monad input would not send to gather actor. + if (HasAbstractMonad(input) || (input->isa() && HasAbstractRef(input))) { + continue; + } + graph_inputs.emplace_back(input); + } + + // Collect all backend input node to gather, There are two situations: + // 1. The parameter from the host data source. + // 2. Output the kernel actor. + for (const auto parameters : func_graph_to_parameters.second) { + if (parameters.size() != graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Parameters num is invalid, current:" << parameters.size() + << " need:" << graph_inputs.size() << " func_graph:" << func_graph->ToString(); + } + + for (size_t i = 0; i < parameters.size(); ++i) { + FetchBackendInputNodebyFrontNode(parameters[i], graph_inputs[i]); + } + } + } + + for (const auto front_to_backend_parameters : front_to_backend_parameters_) { + formal_to_real_parameters_[front_to_backend_parameters.first].push_back( + {front_to_backend_parameters.second.first, 0}); + } +} } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index 330697949ad..900f4218d59 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -47,22 +47,32 @@ using NodeWithDeviceContext = std::vector // 2. First input of node is a funcgraph value node. bool IsCallNode(const AnfNodePtr &node); +// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph, +// it is determined whether it is a weight. +bool HasAbstractRef(const AnfNodePtr &node); + +// Recursive interface, get the funcgraph which the node belongs, if the node has a front node, return the funcgraph +// which the front node belongs, if not, find the funcgraph which the input of the node belongs. FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node); +// Recursive interface, get the number of output nodes of funcgraph called by call node. +size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector *call_nodes); + // Get front node by backend node. AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node); +// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the +// corresponding cnode. +AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph); + // Get the funcgraph to which the node belongs. FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node); // Find all funcgraphs that the call node will call. -std::vector FetchFuncGraphbyCallNode(const CNodePtr &node); +std::vector FetchFuncGraphbyCallNode(const AnfNodePtr &node); -// Fetch all backend input nodes by parameter for gather actor. -std::vector FetchInputNodeByParameter(const AnfNodePtr ¶meter, - const std::vector &host_ds_parameters, - std::vector *invalid_inputs, - const FuncGraphToParameter &graph_to_real_parameters); +// Recursive interface, get all input of make tuple node. +std::vector FetchInputsByMakeTuple(const AnfNodePtr &node); // ControlNodeParser is used to parse control nodes, and get the edges between nodes. class ControlNodeParser { @@ -77,6 +87,26 @@ class ControlNodeParser { // multiple branch outputs, there will be multiple output nodes. std::vector FetchAllBranchOutputs(const FuncGraphPtr &func_graph); + // Get all possible input nodes of the output node. When the switch actor is the output, it need to send the node + // which device address belongs, so switch actor need to get all the possible nodes. + std::vector FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output); + + // Get the device context corresponding to the value node. + DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node); + + // Get the branch id corresponding to funcgraph. + int GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph); + + // Get the number of calls to funcgraph + size_t GetCallNumByFuncGraph(const FuncGraphPtr &func_graph); + + // Get all possible input nodes of the output node. When the gather actor is the output, it need to send the node + // which device address belongs, so gather actor need to get all the possible nodes. + std::vector GetBackendInputByParameter(const AnfNodePtr ¶meter); + + // Check whether there is a call node in the front input nodes of the kernel graph. + bool IsCallInputKernelGraph(const KernelGraphPtr &graph); + private: friend class GraphScheduler; @@ -84,47 +114,30 @@ class ControlNodeParser { // value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for // them separately during initialization. // The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to. - void FetchFrontValueNode(const std::vector &graphs, + void FetchFrontValueNode(const std::vector &control_nodes, const std::vector &graphs, const std::vector &device_contexts); - + // Create branch id for all subgraphs in the control flow. + void CreateBranchIDForFuncGraph(const std::vector &control_nodes); // Find all value nodes in the switch recursively. - void FetchValueNodeInSwitchNode(const AnfNodePtr &switch_node, std::vector *value_nodes); - + void FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node, std::vector *value_nodes); // Fetch all the relationships between front parameters and backend parameters.The front parameters // include two parts: // 1. The parameter from kernel graph. // 2. The parameter from control nodes. - void FetchFrontToBackendParameterMap(const std::vector &graphs, - const std::vector &device_contexts, - const std::vector &control_nodes); - + void FetchFrontToBackendParameter(const std::vector &graphs, + const std::vector &device_contexts, + const std::vector &control_nodes); + // Get the relationship between the front and backend of the executable kernel in all kernel graphs. + void FetchFrontToBackendKernel(const std::vector &graphs, + const std::vector &device_contexts); // Get inputs of control node which come from the host actor. These inputs generally come from the partial // nodes and call nodes of the root funcgraph. std::vector FetchControlNodeParameter(const std::vector &control_nodes); - // Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node, // and the input of the call node is the input parameter of the corresponding funcgraph. - void FetchFuncGraphToParameterMap(const std::vector &control_nodes); - + void FetchFuncGraphToParameter(const std::vector &control_nodes); // Get all the front weight parameters related to the weight in the host parameter. - void FetchHostParameterToWeightMap(const std::vector &control_nodes); - - // Get the pos input of call node to funcgraph. - AnfNodePtr GetCallNodeInputByPos(const AnfNodePtr &call_node, const FuncGraphPtr &func_graph, const size_t pos); - - // Find the output of the funcgraph, if the output is a call node, return the output of the funcgraph - // called by the call node. - std::vector FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std::vector *call_nodes); - - // Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding - // backend parameter, then recursively find the backend parameters of other front parameters corresponding to the - // front_node. - std::pair FetchBackendNodeByFrontNode( - const AnfNodePtr &front_node, - const std::unordered_map> &front_to_front_parameter, - const std::unordered_map> &front_to_backend_parameter, - std::set *invalid_node); - + void FetchHostParameterToWeight(const std::vector &control_nodes); // The relationship between front parameters indicates that the parameter is directly used as the input of the // funcgraph. There are two situations: // 1. The parameter is used as the input of the call node, @@ -132,12 +145,44 @@ class ControlNodeParser { // subsequent call node. void FetchFrontToFrontParameterMap(const std::vector &control_nodes, std::unordered_map> *front_to_front_parameter); + // Get the number of calls to all subgraphs in the whole funcgraph. + void FetchFuncGraphCallNum(const std::vector &control_nodes); + // Get all the kernel graphs where the input node has a call node. + void FetchCallInputKernelGraph(const std::vector &graphs, + const std::vector &device_contexts); + // Get the relationship of all real and formal parameters in the whole funcgraph. + void FetchBackendInputNode(); + + // Get all possible input node of real parameter. + void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter); + // Recursive interface, get all Backend node by front_output. + std::vector FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, + std::set *call_nodes, + std::set *switch_nodes); // The front to backend parameters is used to build and link the host data source actor in the control flow scenario. FrontToBackendNodeWithContext front_to_backend_parameters_; + + // The relationship between all real parameters and formal parameters in the entire func_graph. + // In control flow, the control actor will be the output actor. Since the actor needs to send the node to the output + // actor, it is necessary to save all the real parameters corresponding to the formal parameters in the control actor. + // When the control actor receives the device address, it can find the corresponding input node. + std::unordered_map> formal_to_real_parameters_; + + // Relationship between the front and backend of the executable kernel in all kernel graphs. + FrontToBackendNodeWithContext front_to_backend_kernels_; + // The funcgraph to parameters map records the input parameters of funcgraph and is used to initialize // the input node of gather. FuncGraphToParameter func_graph_to_parameters_; + + // Branch id of funcgraph. + // In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to + // different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch + // id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id + // to its output switch actor. + std::unordered_map func_graph_to_branch_id_; + // host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph. // When initializing the weights, all related weights need to be recorded as the same device tensor. HostParameterToWeight host_parameter_to_weights_; @@ -148,6 +193,13 @@ class ControlNodeParser { std::vector front_output_nodes_; // Parameters of control node which come from the host actor. std::vector control_node_parameters_; + // The number of calls to func_graph. + std::unordered_map func_graph_to_call_num_; + // The kernel graph of call exists in the front-end input node. + std::unordered_map call_input_kernel_graphs_; + // In the scene of funcgrarph recursive call, general input and call input are passed recursively, so a gather actor + // is created for kernel graph which has a call input. + std::vector root_graph_parameters_; }; using ControlNodeParserPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index 36e2371444e..30b3c311c60 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -1058,10 +1058,7 @@ std::vector GraphScheduler::BuildGatherActor(const GraphCompiler auto gather_actor = std::make_shared(actor_name, parameters, loop_count_actor->GetAID(), output_actor->GetAID()); - gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.origin_parameters_order_, - graph_compiler_info.control_node_parser_->front_to_backend_parameters_, - graph_compiler_info.control_node_parser_->func_graph_to_parameters_, - front_to_backend_kernel); + gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_); InsertActor(gather_actor.get()); gather_actors.emplace_back(gather_actor); }