From 8d1b5781e7a62178b3b197f5c2c7f49245b0a1b3 Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Mon, 28 Jun 2021 21:05:15 +0800 Subject: [PATCH] Fetch backend node for switch actor. --- .../runtime/framework/actor/switch_actor.cc | 9 +- .../runtime/framework/actor/switch_actor.h | 3 +- .../runtime/framework/control_node_parser.cc | 225 +++++++++++++----- .../runtime/framework/control_node_parser.h | 43 ++-- .../runtime/framework/graph_scheduler.cc | 9 +- 5 files changed, 204 insertions(+), 85 deletions(-) diff --git a/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc b/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc index ec41691118e..cc50c883cfb 100644 --- a/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc @@ -199,7 +199,8 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b const auto &node = node_with_index.first; // Add weight and value node. - if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && HasAbstractRef(node)) || node->isa()) { + if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa() && HasAbstractRef(node)) || + node->isa()) { const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index); if (iter != input_nodes_.end()) { branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin()); @@ -212,7 +213,7 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b } // Output of updatestate node is U, need to be skipped. - if (HasAbstractRef(node)) { + if (node->isa() && HasAbstractRef(node)) { return; } @@ -447,7 +448,7 @@ void SwitchActor::SendMemoryFreeReq(OpContext *context) { void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) { for (size_t i = 0; i < input_nodes_.size(); ++i) { const auto &input_node = input_nodes_[i].first; - if (!HasAbstractRef(input_node)) { + if (!(input_node->isa() && HasAbstractRef(input_node))) { backend_parameters_[i] = parser->FetchBackendInputNodeByFrontNode(input_node); continue; } @@ -456,7 +457,7 @@ void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) { if (backend_weight == nullptr) { MS_LOG(EXCEPTION) << "Cannot find backend node for weight node:" << AnfAlgo::GetNodeDebugString(input_node); } - backend_parameters_[i].push_back({backend_weight, 0}); + backend_parameters_[i].insert({backend_weight, 0}); } } } // namespace runtime diff --git a/mindspore/ccsrc/runtime/framework/actor/switch_actor.h b/mindspore/ccsrc/runtime/framework/actor/switch_actor.h index 99a9ce7cfd7..7f562e99010 100644 --- a/mindspore/ccsrc/runtime/framework/actor/switch_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/switch_actor.h @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -138,7 +139,7 @@ class SwitchActor : public SwitchActorBase { // When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor, // so all the nodes that may send the device tensor to switch actor are recorded. - std::vector> backend_parameters_; + std::vector> backend_parameters_; std::vector> branch_total_inputs_; std::vector branch_func_graph_; diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index 05628c13fbf..f537e5a7c85 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -17,12 +17,14 @@ #include "runtime/framework/control_node_parser.h" #include "runtime/framework/actor/switch_actor.h" #include "runtime/framework/actor/gather_actor.h" +#include "abstract/utils.h" #include "ir/tensor.h" namespace mindspore { namespace runtime { namespace { +using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; // Fetch all the weight parameters related to node. It runs like this: // if we have a map like {{a, {b, c}}, {b, {d, e}}}, final we will get {{a, {b, c, d, e}}, {b, {c, d}}}. void FetchWeightbyHostParameter(const AnfNodePtr &node, std::vector *dest_nodes, @@ -167,6 +169,30 @@ void CreateDeviceTensorForValueNode(const AnfNodePtr &front_node, const AnfNodeP AnfAlgo::SetOutputAddr(address, 0, front_node.get()); } +// Create a device tensor for front parameter. +// When the condition input of the switch and switchlayer or the output of a subgraph is a parameter, there is no +// corresponding backend node for this parameter, so a device tensor needs to be created for it. +void CreateDeviceTensorForFrontParameter(const AnfNodePtr &node, const DeviceContext *device_context) { + MS_EXCEPTION_IF_NULL(device_context); + + TypeId type_id = AnfAlgo::GetOutputInferDataType(node, 0); + + if (node->kernel_info() == nullptr) { + auto kernel_info = std::make_shared(); + std::shared_ptr builder = std::make_shared(); + builder->SetOutputsFormat({kOpFormat_DEFAULT}); + builder->SetOutputsDeviceType({type_id}); + kernel_info->set_select_kernel_build_info(builder->Build()); + node->set_kernel_info(kernel_info); + } + size_t size = AnfAlgo::GetOutputTensorMemSize(node, 0); + + // Create device tensor. + device::DeviceAddressPtr address = device_context->CreateDeviceAddress(nullptr, size, kOpFormat_DEFAULT, type_id); + MS_EXCEPTION_IF_NULL(address); + AnfAlgo::SetOutputAddr(address, 0, node.get()); +} + // Find the corresponding backend parameter for the front_node. If the front_node does not have the corresponding // backend parameter, then recursively find the backend parameters of other front parameters corresponding to the // front_node. @@ -547,41 +573,60 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) { void ControlNodeParser::Parse(const std::vector &control_nodes, const std::vector &graphs, const std::vector &device_contexts, const FuncGraphPtr &root_graph) { + if (graphs.size() != device_contexts.size()) { + MS_LOG(EXCEPTION) << "Graph num is not equal to device context, graph:" << graphs.size() + << " device context num:" << device_contexts.size(); + } + if (graphs.empty()) { + return; + } + root_func_graph_ = root_graph; root_graph_parameters_ = root_graph->parameters(); CreateBranchIDForFuncGraph(control_nodes); - FetchFrontToBackendParameter(graphs, device_contexts, control_nodes); + RealToFormalNode real_to_formal_front_parameters; + FetchFrontToFrontParameter(control_nodes, &real_to_formal_front_parameters); + + RealToFormalNode formal_to_real_front_parameters; + for (const auto real_to_formal_front_parameter : real_to_formal_front_parameters) { + for (const auto formal_parameter : real_to_formal_front_parameter.second) { + formal_to_real_front_parameters[formal_parameter].emplace_back(real_to_formal_front_parameter.first); + } + } + + FetchFrontToBackendParameter(graphs, device_contexts, control_nodes, real_to_formal_front_parameters, + formal_to_real_front_parameters); FetchFuncGraphToParameter(control_nodes); - FetchHostParameterToWeight(control_nodes); + FetchHostParameterToWeight(real_to_formal_front_parameters); FetchFrontValueNode(control_nodes, graphs, device_contexts); FetchFrontToBackendKernel(graphs, device_contexts); - control_node_parameters_ = FetchControlNodeParameter(control_nodes); + control_node_parameters_ = FetchControlNodeParameter(control_nodes, device_contexts[0]); FetchFuncGraphCallNum(control_nodes); FetchCallInputKernelGraph(graphs, device_contexts); - FetchBackendInputNode(); - - front_output_nodes_ = FetchAllBranchOutputs(root_graph); + FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters); } std::vector ControlNodeParser::GetBackendInputByParameter(const AnfNodePtr ¶meter) { return formal_to_real_parameters_[parameter]; } -std::vector ControlNodeParser::FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output) { +std::set ControlNodeParser::FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output) { std::set call_nodes; std::set switch_nodes; - return FetchBackendOutputByFrontOutput(front_output, &call_nodes, &switch_nodes); + std::set results; + FetchBackendOutputByFrontOutput(front_output, &call_nodes, &switch_nodes, &results); + return results; } int ControlNodeParser::GetBranchIDByFuncGraph(const FuncGraphPtr &func_graph) { @@ -762,7 +807,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr } } -void ControlNodeParser::FetchFrontToFrontParameterMap( +void ControlNodeParser::FetchFrontToFrontParameter( const std::vector &control_nodes, std::unordered_map> *front_to_front_parameter) { // Function used to collect the input of call node. @@ -844,7 +889,8 @@ void ControlNodeParser::FetchFrontToFrontParameterMap( } } -std::vector ControlNodeParser::FetchControlNodeParameter(const std::vector &control_nodes) { +std::vector ControlNodeParser::FetchControlNodeParameter(const std::vector &control_nodes, + DeviceContext *device_context) { std::vector parameters; for (const auto &control_node : control_nodes) { @@ -864,6 +910,29 @@ std::vector ControlNodeParser::FetchControlNodeParameter(const std:: parameters.emplace_back(inputs[i]); } } + } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch)) { + if (inputs.size() != kSwitchInputNum) { + MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node); + } + if (inputs[kSwitchCondPos]->isa()) { + parameters.emplace_back(inputs[kSwitchCondPos]); + } + } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) { + if (inputs.size() != kSwitchLayerInputNum) { + MS_LOG(EXCEPTION) << "Invalid switch node:" << AnfAlgo::GetNodeDebugString(control_node); + } + if (inputs[kSwitchLayerCondPos]->isa()) { + parameters.emplace_back(inputs[kSwitchLayerCondPos]); + } + } + } + + for (const auto ¶meter : parameters) { + auto backend_iter = front_to_backend_parameters_.find(parameter); + if (backend_iter == front_to_backend_parameters_.end()) { + CreateDeviceTensorForFrontParameter(parameter, device_context); + front_to_backend_parameters_[parameter] = {parameter, device_context}; + front_parameters_.push_back({parameter, device_context}); } } @@ -896,7 +965,7 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vectorparameters(); + const auto inputs = graph->input_nodes(); for (const auto &input : inputs) { const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input); if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) { @@ -983,7 +1052,8 @@ std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &graph) continue; } - if (HasAbstractRef(AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, 0).first) || + const auto real_front_node = AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, 0).first; + if ((real_front_node->isa() && HasAbstractRef(real_front_node)) || HasAbstractMonad(front_node_with_index.first)) { continue; } @@ -1002,7 +1072,9 @@ std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &graph) void ControlNodeParser::FetchFrontToBackendParameter(const std::vector &graphs, const std::vector &device_contexts, - const std::vector &control_nodes) { + const std::vector &control_nodes, + const RealToFormalNode &real_to_formal_front_parameters, + const RealToFormalNode &formal_to_real_front_parameters) { if (graphs.size() != device_contexts.size()) { MS_LOG(EXCEPTION) << "Graph num is not equal to device context num."; } @@ -1026,7 +1098,7 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vectorparameters()) { + for (const auto ¶meter : graph->input_nodes()) { const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(parameter); if (internal_front_node.first != nullptr) { @@ -1043,21 +1115,6 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vector> real_to_formal_front_parameters; - FetchFrontToFrontParameterMap(control_nodes, &real_to_formal_front_parameters); - - std::unordered_map> formal_to_real_front_parameters; - for (const auto real_to_formal_front_parameter : real_to_formal_front_parameters) { - for (const auto formal_parameter : real_to_formal_front_parameter.second) { - formal_to_real_front_parameters[formal_parameter].emplace_back(real_to_formal_front_parameter.first); - } - } - for (const auto &front_pair : real_to_formal_front_parameters) { std::set invalid_node; const auto &backend_node = @@ -1071,14 +1128,10 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vector &control_nodes) { - std::unordered_map> front_to_front_parameter; - - FetchFrontToFrontParameterMap(control_nodes, &front_to_front_parameter); - - for (const auto &pair : front_to_front_parameter) { +void ControlNodeParser::FetchHostParameterToWeight(const RealToFormalNode &front_to_front_parameters) { + for (const auto &pair : front_to_front_parameters) { std::vector dest_nodes; - FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameter); + FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameters); host_parameter_to_weights_[pair.first] = dest_nodes; } } @@ -1136,18 +1189,20 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, - std::set *call_nodes, - std::set *switch_nodes) { - std::vector backend_outputs; +void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, + std::set *call_nodes, + std::set *switch_nodes, + std::set *results) { if (front_output->isa()) { - backend_outputs.push_back({front_output, 0}); + (*results).insert({front_output, 0}); } else if (front_output->isa()) { // Output is a parameter. const auto iter = formal_to_real_parameters_.find(front_output); if (iter != formal_to_real_parameters_.end()) { - backend_outputs.insert(backend_outputs.end(), iter->second.begin(), iter->second.end()); + for (const auto &node : iter->second) { + (*results).insert(node); + } } else { MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output); } @@ -1156,16 +1211,14 @@ std::vector ControlNodeParser::FetchBackendOutputByFrontOutput( const auto &switch_outputs = FetchOutputBySwitchNode(front_output, call_nodes, switch_nodes); for (const auto &switch_output : switch_outputs) { - const auto outputs = FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes); - backend_outputs.insert(backend_outputs.end(), outputs.begin(), outputs.end()); + FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results); } } else if (IsCallNode(front_output)) { // Output is a call. const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes); for (const auto &call_output : call_outputs) { - const auto outputs = FetchBackendOutputByFrontOutput(call_output, call_nodes, switch_nodes); - backend_outputs.insert(backend_outputs.end(), outputs.begin(), outputs.end()); + FetchBackendOutputByFrontOutput(call_output, call_nodes, switch_nodes, results); } } else if (AnfAlgo::CheckPrimitiveType(front_output, prim::kPrimMakeTuple)) { // Output is a make tuple. @@ -1173,8 +1226,7 @@ std::vector ControlNodeParser::FetchBackendOutputByFrontOutput( const auto &inputs = cnode->inputs(); for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { - const auto outputs = FetchBackendOutputByFrontOutput(inputs[i], call_nodes, switch_nodes); - backend_outputs.insert(backend_outputs.end(), outputs.begin(), outputs.end()); + FetchBackendOutputByFrontOutput(inputs[i], call_nodes, switch_nodes, results); } } else if (front_output->isa()) { // Output is a kernel. @@ -1182,19 +1234,18 @@ std::vector ControlNodeParser::FetchBackendOutputByFrontOutput( if (iter != front_to_backend_kernels_.end()) { const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(iter->second.first, 0); - backend_outputs.emplace_back(output_with_index); + (*results).insert(output_with_index); } else { MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output); } } else { MS_LOG(EXCEPTION) << "Invalid front node:" << AnfAlgo::GetNodeDebugString(front_output); } - - return backend_outputs; } -void ControlNodeParser::FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, - const AnfNodePtr &formal_parameter) { +void ControlNodeParser::FetchBackendInputNodebyFrontNode( + const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter, + const FrontToBackendNodeWithContext &front_to_backend_parameters) { if (real_parameter->isa()) { // Input node is a parameter from host data source actor. std::set invalid_inputs; @@ -1205,8 +1256,8 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_ const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0); if (node_with_index.first->isa()) { - const auto &iter = front_to_backend_parameters_.find(real_parameter); - if (iter == front_to_backend_parameters_.end()) { + const auto &iter = front_to_backend_parameters.find(real_parameter); + if (iter == front_to_backend_parameters.end()) { MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first); continue; } @@ -1224,7 +1275,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_ } else if (IsCallNode(real_parameter)) { const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter); for (const auto func_graph : func_graphs) { - FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter); + FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters); } } else { // Input node is a cnode. @@ -1237,7 +1288,56 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_ } } -void ControlNodeParser::FetchBackendInputNode() { +void ControlNodeParser::FetchBackendParameterNode(const std::vector &graphs, + const std::vector &device_contexts, + const RealToFormalNode &real_to_formal_front_parameters, + const RealToFormalNode &formal_to_real_front_parameters, + FrontToBackendNodeWithContext *front_to_backend_parameters) { + for (size_t i = 0; i < graphs.size(); ++i) { + const auto &graph = graphs[i]; + const auto &device_context = device_contexts[i]; + if (graph->GetFuncGraph() != root_func_graph_) { + continue; + } + for (const auto ¶meter : graph->input_nodes()) { + auto front_node = graph->GetFrontAnfByBackendAnf(parameter); + + if (front_node != nullptr && front_node->isa() && + (*front_to_backend_parameters).find(front_node) == (*front_to_backend_parameters).end()) { + (*front_to_backend_parameters)[front_node] = {parameter, device_context}; + } + } + } + for (const auto &control_node_parameter : control_node_parameters_) { + const auto &iter = front_to_backend_parameters_.find(control_node_parameter); + if (iter == front_to_backend_parameters_.end()) { + MS_LOG(EXCEPTION) << "Cannot find backend node for control node parameter:" + << AnfAlgo::GetNodeDebugString(control_node_parameter); + } + (*front_to_backend_parameters)[control_node_parameter] = iter->second; + } + + for (const auto &front_pair : formal_to_real_front_parameters) { + std::set invalid_node; + const auto &backend_node = + FetchBackendNodeByFrontNode(front_pair.first, real_to_formal_front_parameters, formal_to_real_front_parameters, + (*front_to_backend_parameters), &invalid_node); + if (backend_node.first != nullptr) { + if ((*front_to_backend_parameters).find(front_pair.first) == (*front_to_backend_parameters).end()) { + (*front_to_backend_parameters)[front_pair.first] = backend_node; + } + } + } +} + +void ControlNodeParser::FetchBackendInputNode(const std::vector &graphs, + const std::vector &device_contexts, + const RealToFormalNode &real_to_formal_front_parameters, + const RealToFormalNode &formal_to_real_front_parameters) { + FrontToBackendNodeWithContext front_to_backend_parameters; + FetchBackendParameterNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters, + &front_to_backend_parameters); + for (const auto &func_graph_to_parameters : func_graph_to_parameters_) { const auto &func_graph = func_graph_to_parameters.first; std::vector graph_inputs; @@ -1259,14 +1359,15 @@ void ControlNodeParser::FetchBackendInputNode() { } for (size_t i = 0; i < parameters.size(); ++i) { - FetchBackendInputNodebyFrontNode(parameters[i], graph_inputs[i]); + FetchBackendInputNodebyFrontNode(parameters[i], graph_inputs[i], front_to_backend_parameters); } } } - - for (const auto front_to_backend_parameters : front_to_backend_parameters_) { - formal_to_real_parameters_[front_to_backend_parameters.first].push_back( - {front_to_backend_parameters.second.first, 0}); + for (const auto parameter_pair : front_to_backend_parameters) { + formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0}); + } + for (const auto parameter_pair : front_to_backend_parameters_) { + formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0}); } } } // namespace runtime diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index 5f824c132a8..79f5b1f2f0e 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -41,6 +41,7 @@ using FrontToBackendNodeWithContext = std::unordered_map>>; using HostParameterToWeight = std::unordered_map>; using NodeWithDeviceContext = std::vector>; +using RealToFormalNode = std::unordered_map>; // Check whether node is a call node, there are two types of call nodes: // 1. First input of node is a cnode. @@ -89,7 +90,7 @@ class ControlNodeParser { // Get all possible input nodes of the output node. When the switch actor is the output, it need to send the node // which device address belongs, so switch actor need to get all the possible nodes. - std::vector FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output); + std::set FetchBackendInputNodeByFrontNode(const AnfNodePtr &front_output); // Get the device context corresponding to the value node. DeviceContext *GetFrontValueNodeDeviceContext(const AnfNodePtr &value_node); @@ -135,39 +136,50 @@ class ControlNodeParser { // 2. The parameter from control nodes. void FetchFrontToBackendParameter(const std::vector &graphs, const std::vector &device_contexts, - const std::vector &control_nodes); + const std::vector &control_nodes, + const RealToFormalNode &real_to_formal_front_parameters, + const RealToFormalNode &formal_to_real_front_parameters); // Get the relationship between the front and backend of the executable kernel in all kernel graphs. void FetchFrontToBackendKernel(const std::vector &graphs, const std::vector &device_contexts); // Get inputs of control node which come from the host actor. These inputs generally come from the partial // nodes and call nodes of the root funcgraph. - std::vector FetchControlNodeParameter(const std::vector &control_nodes); + std::vector FetchControlNodeParameter(const std::vector &control_nodes, + DeviceContext *device_context); // Get all the input parameters of funcgraph. The call of funcgraph is realized through the call node, // and the input of the call node is the input parameter of the corresponding funcgraph. void FetchFuncGraphToParameter(const std::vector &control_nodes); // Get all the front weight parameters related to the weight in the host parameter. - void FetchHostParameterToWeight(const std::vector &control_nodes); + void FetchHostParameterToWeight(const RealToFormalNode &real_to_formal_front_parameters); // The relationship between front parameters indicates that the parameter is directly used as the input of the // funcgraph. There are two situations: // 1. The parameter is used as the input of the call node, // 2. The parameter is used as the input of the partial and will be input to the funcgraph of the partial in the // subsequent call node. - void FetchFrontToFrontParameterMap(const std::vector &control_nodes, - std::unordered_map> *front_to_front_parameter); + void FetchFrontToFrontParameter(const std::vector &control_nodes, + std::unordered_map> *front_to_front_parameter); // Get the number of calls to all subgraphs in the whole funcgraph. void FetchFuncGraphCallNum(const std::vector &control_nodes); // Get all the kernel graphs where the input node has a call node. void FetchCallInputKernelGraph(const std::vector &graphs, const std::vector &device_contexts); + // Get the relationship of all real and formal nodes in the whole funcgraph. + void FetchBackendInputNode(const std::vector &graphs, + const std::vector &device_contexts, + const RealToFormalNode &real_to_formal_front_parameters, + const RealToFormalNode &formal_to_real_front_parameters); // Get the relationship of all real and formal parameters in the whole funcgraph. - void FetchBackendInputNode(); - + void FetchBackendParameterNode(const std::vector &graphs, + const std::vector &device_contexts, + const RealToFormalNode &real_to_formal_front_parameters, + const RealToFormalNode &formal_to_real_front_parameters, + FrontToBackendNodeWithContext *front_to_backend_parameters); // Get all possible input node of real parameter. - void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter); + void FetchBackendInputNodebyFrontNode(const AnfNodePtr &real_parameter, const AnfNodePtr &formal_parameter, + const FrontToBackendNodeWithContext &front_to_backend_parameters); // Recursive interface, get all Backend node by front_output. - std::vector FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, - std::set *call_nodes, - std::set *switch_nodes); + void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set *call_nodes, + std::set *switch_nodes, std::set *results); // The front to backend parameters is used to build and link the host data source actor in the control flow scenario. FrontToBackendNodeWithContext front_to_backend_parameters_; @@ -195,11 +207,14 @@ class ControlNodeParser { // host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph. // When initializing the weights, all related weights need to be recorded as the same device tensor. HostParameterToWeight host_parameter_to_weights_; + // The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the // input of the control node. NodeWithDeviceContext front_value_nodes_; - // The front output_node is used to link the output actor in multi-branch output scenario. - std::vector front_output_nodes_; + // The front value node saves all parameters that are not in the kernel graph. These nodes are generally the + // output of subgraph, or the switch condition node. + NodeWithDeviceContext front_parameters_; + // Parameters of control node which come from the host actor. std::vector control_node_parameters_; // The number of calls to func_graph. diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index 0a6fe2ca7c6..6f7815a8a42 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -1107,7 +1107,7 @@ std::vector GraphScheduler::BuildGatherActor(const GraphCompiler // Collect the parameters. std::vector parameters; for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { - if (HasAbstractMonad(inputs[i]) || HasAbstractRef(inputs[i])) { + if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa() && HasAbstractRef(inputs[i]))) { continue; } parameters.emplace_back(inputs[i]); @@ -2115,8 +2115,9 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c const auto &backend_node = backend_iter->second.first; auto iter = from_actor->data_node_position_map_.find(input_node); if (iter == from_actor->data_node_position_map_.end()) { - MS_LOG(EXCEPTION) << "Cannot find data node in data source actor, node:" - << AnfAlgo::GetNodeDebugString(backend_node); + MS_LOG(EXCEPTION) << "Cannot find data node in data source actor, backend node:" + << AnfAlgo::GetNodeDebugString(backend_node) + << " front node:" << AnfAlgo::GetNodeDebugString(input_node); } auto op_arrow = std::make_shared(iter->second, to_actor->GetAID(), to_index); @@ -2133,7 +2134,7 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_ const auto &inputs = actor->input_nodes_; for (size_t i = 0; i < inputs.size(); ++i) { auto input = inputs[i]; - if (input.first->isa() || HasAbstractRef(input.first)) { + if (input.first->isa() || (input.first->isa() && HasAbstractRef(input.first))) { continue; }