diff --git a/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc b/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc index a875505e71f..efa0bc6bc61 100644 --- a/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc @@ -42,10 +42,10 @@ void GatherActor::Init() { } } -size_t GatherActor::FetchDataNodePosition(const AnfNodePtr &data_node) const { +size_t GatherActor::FetchDataNodePosition(const KernelWithIndex &data_node) const { const auto &iter = find(data_nodes_.begin(), data_nodes_.end(), data_node); if (iter == data_nodes_.end()) { - MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node) + MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node.first) << " index:" << data_node.second << " is not exist in gather actor:" << GetAID(); } return iter - data_nodes_.begin(); @@ -114,7 +114,7 @@ void GatherActor::SendOutput(OpContext *context) const { for (const auto &result_arrow : output_result_arrows_) { MS_EXCEPTION_IF_NULL(result_arrow); size_t from_index = result_arrow->from_output_index_; - const auto &front_node = data_nodes_[from_index]; + const auto &front_node = data_nodes_[from_index].first; for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) { if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() == input_device_tensors_[from_index]) { diff --git a/mindspore/ccsrc/runtime/framework/actor/gather_actor.h b/mindspore/ccsrc/runtime/framework/actor/gather_actor.h index 578093aa25b..de431eee227 100644 --- a/mindspore/ccsrc/runtime/framework/actor/gather_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/gather_actor.h @@ -47,7 +47,7 @@ constexpr size_t kReturnInputPos = 1; // collected at the entrance of the kernel graph. class GatherActor : public OpActor { public: - GatherActor(const std::string &name, const std::vector ¶meters, const bool need_branch_id_input, + GatherActor(const std::string &name, const std::vector ¶meters, const bool need_branch_id_input, const AID switch_aid, const AID gather_aid, const int branch_id) : OpActor(name), data_nodes_(parameters), @@ -60,7 +60,7 @@ class GatherActor : public OpActor { ~GatherActor() override = default; // Get the index of the parameter, the data_node needs to be the front node. - size_t FetchDataNodePosition(const AnfNodePtr &data_node) const; + size_t FetchDataNodePosition(const KernelWithIndex &data_node) const; // The gather actor run when receive the input data. void RunOpData(OpData *input_data, OpContext *context) override; @@ -107,7 +107,7 @@ class GatherActor : public OpActor { std::vector output_branch_arrows_; // Parameters of sub funcgraph, which is the front node. - std::vector data_nodes_; + std::vector data_nodes_; std::vector device_contexts_; // Pair points to the dependent device tensor store, anfNode is the key of the device tensor store. std::vector> device_tensor_store_keys_; diff --git a/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc b/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc index 97d21a0e629..923cc44df33 100644 --- a/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc @@ -78,7 +78,9 @@ void KernelActor::RunOpData(OpData *input_data, OpContextsequential_num_; input_op_datas_[sequential_num].emplace_back(input_data); if (input_data->data_ == nullptr) { - MS_LOG(EXCEPTION) << "Input data of actor:" << GetAID() << " num:" << input_data->index_ << " is empty"; + std::string error_info = + "Input data of actor:" + GetAID().Name() + " num:" + std::to_string(input_data->index_) + " is empty"; + SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); } // When all the inputs are collected, then allocate memory and callback launch. if (CheckLaunchCondition(context)) { diff --git a/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc b/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc index d4e6a0ad0f5..3359537cc8d 100644 --- a/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc @@ -72,20 +72,20 @@ void SwitchActor::CollectBranchId(const int branch_id, OpContext * input_branch_ids_[sequential_num].push(branch_id); } -void SwitchActor::Initialize(const ControlNodeParserPtr &parser) { +void SwitchActor::ParseInput(const ControlNodeParserPtr &parser) { std::vector inputs = node_->inputs(); if (IsPrimitive(inputs[0], prim::kPrimSwitch)) { - InitSwitch(); + ParseSwitchInput(); } else if (IsPrimitive(inputs[0], prim::kPrimReturn)) { - InitReturn(parser); + ParseReturnInput(parser); } else { - InitSwitchLayer(); + ParseSwitchLayerInput(); } backend_parameters_.resize(input_nodes_.size()); } -void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) { +void SwitchActor::ParsePartialInput(const AnfNodePtr &node, const size_t branch_id) { if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { CNodePtr cnode = node->cast(); @@ -93,20 +93,50 @@ void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) { // [0] ValueNode kPartial. // [1] ValueNode. // [2..] Inputs. - const auto &node_inputs = cnode->inputs(); - if (node_inputs.size() <= kPartialFuncGraphPos) { + auto partial_inputs = cnode->inputs(); + if (partial_inputs.size() <= kPartialFuncGraphPos) { MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(cnode); } - const auto &func_graph = GetValueNode(node_inputs[kPartialFuncGraphPos]); + auto func_graph = GetValueNode(partial_inputs[kPartialFuncGraphPos]); if (func_graph->output()->isa()) { AddInput(func_graph->output(), branch_id); + return; + } else if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) { + // If the funcgraph called by the partial returns a partial node, the switch actor should call the funcgraph + // of the sub partial. Similarly, the input node should also be the input of the sub partial. + is_mulit_call_ = true; + CNodePtr sub_partial = func_graph->output()->cast(); + const auto &sub_partial_inputs = sub_partial->inputs(); + if (sub_partial_inputs.size() <= kPartialFuncGraphPos) { + MS_LOG(EXCEPTION) << "Invalid Partial node:" << AnfAlgo::GetNodeDebugString(sub_partial); + } + const auto &sub_func_graph = GetValueNode(sub_partial_inputs[kPartialFuncGraphPos]); + + if (sub_func_graph->output()->isa()) { + AddInput(sub_func_graph->output(), branch_id); + return; + } + + branch_func_graph_[branch_id] = sub_func_graph; + const auto &sub_parameters = func_graph->parameters(); + + // Record the input that comes with the sub partial node. + for (size_t i = kPartialInputStartPos; i < sub_partial_inputs.size(); ++i) { + const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[i], 0).first; + const auto &iter = find(sub_parameters.begin(), sub_parameters.end(), real_partial_input); + if ((iter != sub_parameters.end()) && + ((iter - sub_parameters.begin()) < SizeToInt(partial_inputs.size() - kPartialInputStartPos))) { + AddInput(partial_inputs[iter - sub_parameters.begin() + kPartialInputStartPos], branch_id); + } + } + return; } branch_func_graph_[branch_id] = func_graph; - for (size_t j = kPartialInputStartPos; j < node_inputs.size(); ++j) { - AddInput(node_inputs[j], branch_id); + for (size_t j = kPartialInputStartPos; j < partial_inputs.size(); ++j) { + AddInput(partial_inputs[j], branch_id); } } else { AddInput(node, branch_id); @@ -122,15 +152,21 @@ void SwitchActor::InitVectorSize(const size_t num) { output_branch_branch_arrows_.resize(num); } -void SwitchActor::InitReturn(const ControlNodeParserPtr &parser) { +void SwitchActor::ParseReturnInput(const ControlNodeParserPtr &parser) { const auto &func_graph = node_->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); const auto &call_num = parser->GetCallNumByFuncGraph(func_graph); InitVectorSize(call_num); + + // If the return is a partial node or funcgraph, this subgraph will not be initialized and no input is required. + if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial) || + (func_graph->output()->isa() && IsValueNode(func_graph->output()))) { + return; + } AddCommonInput(func_graph->output()); } -void SwitchActor::InitSwitch() { +void SwitchActor::ParseSwitchInput() { // The inputs of the switch node: // [0] ValueNode kSwitch. // [1] switch condition. @@ -147,11 +183,11 @@ void SwitchActor::InitSwitch() { input_nodes_.push_back(cond_node); input_datas_num_++; // Init the two branches of switch node. - InitPartial(inputs[kSwitchFalseBranchPos], static_cast(false)); - InitPartial(inputs[kSwitchTrueBranchPos], static_cast(true)); + ParsePartialInput(inputs[kSwitchFalseBranchPos], static_cast(false)); + ParsePartialInput(inputs[kSwitchTrueBranchPos], static_cast(true)); } -void SwitchActor::InitSwitchLayer() { +void SwitchActor::ParseSwitchLayerInput() { // The inputs of the switch node: // [0] ValueNode kSwitchLayer. // [1] switchLayer index. @@ -170,11 +206,30 @@ void SwitchActor::InitSwitchLayer() { InitVectorSize(branch_nodes.size() - 1); // Parse all branches. - for (size_t i = 1; i < branch_nodes.size(); ++i) { + for (size_t i = kMakeTupleInputStartPos; i < branch_nodes.size(); ++i) { if (AnfAlgo::CheckPrimitiveType(branch_nodes[i], prim::kPrimPartial)) { - InitPartial(branch_nodes[i], i - 1); + ParsePartialInput(branch_nodes[i], i - kMakeTupleInputStartPos); } else if (branch_nodes[i]->isa()) { - branch_func_graph_[i - 1] = GetValueNode(branch_nodes[i]); + const auto &func_graph = GetValueNode(branch_nodes[i]); + const auto output = func_graph->output(); + + // The switch layer node has a second-order call connected to call. When the called funcgraph returns a partial + // node or funcgraph, the switch actor needs to call the funcgraph directly. + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) { + is_mulit_call_ = true; + branch_func_graph_[i - kMakeTupleInputStartPos] = + GetValueNode(output->cast()->input(kPartialFuncGraphPos)); + } else if (output->isa() && IsValueNode(output)) { + is_mulit_call_ = true; + const auto &sub_func_graph = GetValueNode(output); + if (sub_func_graph->output()->isa()) { + AddInput(sub_func_graph->output(), i - kMakeTupleInputStartPos); + continue; + } + branch_func_graph_[i - kMakeTupleInputStartPos] = GetValueNode(output); + } else { + branch_func_graph_[i - kMakeTupleInputStartPos] = func_graph; + } } } } @@ -198,8 +253,12 @@ size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const { void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t branch) { const auto &node = node_with_index.first; - // Add weight and value node. - if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && node->isa() && HasAbstractRef(node)) || + // The value node and weight node need to be placed in the device store. The switch actor has three inputs: + // 1) The input of the switch is the value node. + // 2) There is a weight node or value node in the return of the sub funcgraph. + // 3) When the switch actor is a second-order call, it does not distinguish between weight and parameter. + if (((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) || is_mulit_call_) && node->isa() && + HasAbstractRef(node)) || node->isa()) { const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index); if (iter != input_nodes_.end()) { @@ -243,7 +302,6 @@ void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) { } else if (IsCallNode(real_input.first)) { std::vector call_nodes; const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes); - if (call_output_num <= 0) { MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first); } @@ -259,25 +317,26 @@ size_t SwitchActor::GetIndex(OpContext *context) { if (need_branch_id_input_) { if (input_branch_ids_.find(context->sequential_num_) == input_branch_ids_.end() || input_branch_ids_[context->sequential_num_].empty()) { - MS_LOG(EXCEPTION) << "Invalid branch id for actor:" << GetAID(); + MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name(); } size_t branch_id = input_branch_ids_[context->sequential_num_].top(); input_branch_ids_[context->sequential_num_].pop(); if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) { - MS_LOG(EXCEPTION) << "Invalid branch id for switch actor:" << GetAID() << " branch id:" << branch_id; + MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() + + " branch id:" + std::to_string(branch_id); } return branch_id_to_index_[branch_id]; } DeviceTensor *device_tensor = input_device_tensors_[0]; if (device_tensor == nullptr) { - MS_LOG(EXCEPTION) << "Index of switch actor is empty:" << GetAID(); + MS_LOG(ERROR) << "Index of switch actor is empty:" + GetAID().Name(); } auto inputs = node_->inputs(); TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0); size_t size = abstract::TypeIdSize(type_id); if (size > sizeof(int64_t)) { - MS_LOG(EXCEPTION) << "Index must be Int type."; + MS_LOG(ERROR) << "Index must be Int type."; } int64_t index = 0; @@ -293,7 +352,7 @@ size_t SwitchActor::GetIndex(OpContext *context) { bool cond = (static_cast(static_cast(buf)))[0]; index = static_cast(cond ? 1 : 0); } else { - MS_LOG(EXCEPTION) << "Index must be Int type."; + MS_LOG(ERROR) << "Index must be Int type."; } // SwitchLayer node support negative index range [-size, -1]. @@ -352,7 +411,7 @@ void SwitchActor::FetchInputDeviceTensor(OpContext *context) { DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType()); if (device_tensor == nullptr) { std::string error_info = - GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() + + GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() + ", device type:" + std::to_string(static_cast(device_context_->GetDeviceAddressType())); SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); } @@ -370,7 +429,8 @@ void SwitchActor::SendOutput(OpContext *context) { MS_EXCEPTION_IF_NULL(context); auto index = GetIndex(context); if (index >= output_branch_arrows_.size()) { - MS_LOG(EXCEPTION) << "Switch actor invalid index:" << index; + std::string error_info = "Switch actor:" + GetAID().Name() + " invalid index:" + std::to_string(index); + SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); } // Must be the execution order: send branch id --> send result --> send data --> send control, avoid the illegal @@ -389,8 +449,10 @@ void SwitchActor::SendOutput(OpContext *context) { auto &result_arrow = output_branch_result_arrow[i]; MS_EXCEPTION_IF_NULL(result_arrow); if (result_arrow->from_output_index_ >= SizeToInt(branch_inputs_pos_[index].size())) { - MS_LOG(EXCEPTION) << "Invalid from index in switch actor, from index:" << result_arrow->from_output_index_ - << " total:" << branch_inputs_pos_[index].size() << " actor:" << GetAID(); + std::string error_info = + "Invalid from index in switch actor, from index:" + std::to_string(result_arrow->from_output_index_) + + " total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name(); + SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); } size_t from_index = branch_inputs_pos_[index][result_arrow->from_output_index_]; @@ -410,10 +472,12 @@ void SwitchActor::SendOutput(OpContext *context) { } } if (!is_send) { - MS_LOG(EXCEPTION) << "Failed to get backend node of switch actor output, actor:" << GetAID() - << " branch:" << index << " index:" << result_arrow->from_output_index_ << " output pos" - << branch_inputs_pos_[index][result_arrow->from_output_index_] << " output index" - << result_arrow->to_input_index_; + std::string error_info = "Failed to get backend node of switch actor output, actor:" + GetAID().Name() + + " branch:" + std::to_string(index) + + " index:" + std::to_string(result_arrow->from_output_index_) + " output pos" + + std::to_string(branch_inputs_pos_[index][result_arrow->from_output_index_]) + + " output index" + std::to_string(result_arrow->to_input_index_); + SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); } } @@ -426,7 +490,6 @@ void SwitchActor::SendOutput(OpContext *context) { MS_EXCEPTION_IF_NULL(data_arrow); MS_EXCEPTION_IF_NULL(data); data->data_ = input_device_tensors_[data_arrow->from_output_index_]; - Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context); } diff --git a/mindspore/ccsrc/runtime/framework/actor/switch_actor.h b/mindspore/ccsrc/runtime/framework/actor/switch_actor.h index 7f562e99010..5d86d8dd4f7 100644 --- a/mindspore/ccsrc/runtime/framework/actor/switch_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/switch_actor.h @@ -76,22 +76,23 @@ class SwitchActor : public SwitchActorBase { void RunOpControl(AID *input_control, OpContext *context); // The switch actor run when receive the input branch id. void CollectBranchId(const int branch_id, OpContext *context); - // Initialize the input and output information of the switch actor According to node_. - void Initialize(const ControlNodeParserPtr &parser); + // Parse the input node information of the switch actor according to node_. + void ParseInput(const ControlNodeParserPtr &parser); // Add input for all branches. void AddCommonInput(const AnfNodePtr &node); + void AddSingleInput(const AnfNodePtr &node, size_t branch) { AddInput(node, branch); } // Fetch the input position of the data node. size_t FetchDataNodePosition(const AnfNodePtr &data_node) const; private: friend class GraphScheduler; - void InitPartial(const AnfNodePtr &node, const size_t branch_id); - void InitSwitch(); - void InitSwitchLayer(); + void ParsePartialInput(const AnfNodePtr &node, const size_t branch_id); + void ParseSwitchInput(); + void ParseSwitchLayerInput(); // In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is // initialized with the return node of the subgraph. - void InitReturn(const ControlNodeParserPtr &parser); + void ParseReturnInput(const ControlNodeParserPtr &parser); // Initialize the size of the vector members. void InitVectorSize(const size_t num); // Get index from DeviceTensor. @@ -170,6 +171,11 @@ class SwitchActor : public SwitchActorBase { // The output_data_ corresponds to the output_data_arrows_ one by one. std::vector>> output_data_; + + // Used to indicate that in the control flow, when the input of the call node is a call node, the switch actor + // corresponding to the switch node called by the sub call node. At this time, the funcgraph of the input of + // the switch actor will return to a partial node or funcgraph. + bool is_mulit_call_{false}; }; using SwitchActorPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index cbae2bfc5fa..9121b17c072 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace runtime { - +constexpr size_t kSingleCallDepth = 1; namespace { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; // Fetch all the weight parameters related to node. It runs like this: @@ -339,25 +339,23 @@ std::vector FetchOutputByCallNode(const AnfNodePtr &call_node, std:: const auto func_graphs = FetchFuncGraphbyCallNode(call_node); for (const auto func_graph : func_graphs) { - if (func_graph->output()->isa()) { - outputs.push_back(func_graph->output()); - } else { - std::vector sub_call_nodes; - const std::vector graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes); - for (const auto &graph_output : graph_outputs) { - if (graph_output->isa()) { - outputs.push_back(graph_output); - } else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) { - const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes); - outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); - } else if (IsCallNode(graph_output)) { - const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes); - outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); - } else if (graph_output->isa()) { - outputs.emplace_back(graph_output); - } else { - MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output); - } + std::vector sub_call_nodes; + const std::vector graph_outputs = FetchFuncGraphOutput(func_graph, &sub_call_nodes); + for (const auto &graph_output : graph_outputs) { + if (graph_output->isa()) { + outputs.push_back(graph_output); + } else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) { + const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes); + outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); + } else if (IsCallNode(graph_output)) { + const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes); + outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); + } else if (graph_output->isa()) { + outputs.emplace_back(graph_output); + } else if (graph_output->isa()) { + outputs.push_back(graph_output); + } else { + MS_LOG(EXCEPTION) << "Invalid front output:" << AnfAlgo::GetNodeDebugString(graph_output); } } } @@ -452,6 +450,70 @@ std::vector FetchParameterByControlNode(const std::vectorcast(); + const auto &inputs = cnode->inputs(); + return kSingleCallDepth + FetchCallDepth(inputs[0]); +} + +// Get the final subgraph called by fungraph through the depth of calls. +FuncGraphPtr FetchFuncGraphByCallDepth(const FuncGraphPtr &func_graph, const size_t call_depth) { + if (call_depth <= kSingleCallDepth) { + return func_graph; + } + + const auto &output = func_graph->output(); + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) { + const auto &cnode = output->cast(); + const auto &inputs = cnode->inputs(); + if (inputs.size() < kPartialInputStartPos) { + MS_LOG(EXCEPTION) << "Invalid partial node:" << AnfAlgo::GetNodeDebugString(output); + } + const auto &called_func_graph = GetValueNode(inputs[kPartialFuncGraphPos]); + return FetchFuncGraphByCallDepth(called_func_graph, call_depth - kSingleCallDepth); + } else if (output->isa() && IsValueNode(output)) { + return FetchFuncGraphByCallDepth(GetValueNode(output), call_depth - kSingleCallDepth); + } else { + MS_LOG(EXCEPTION) << "Invalid output for call depth:" << call_depth << " funcgraph:" << func_graph->ToString() + << " output node:" << AnfAlgo::GetNodeDebugString(output); + } +} + +// Get funcgraph from node, the interface only accepts partial node and funcgraph value node. +FuncGraphPtr FetchFuncGraphInNode(const auto &node) { + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { + const auto &func_graph = GetFuncGraphFromPartial(node); + + if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) { + return FetchFuncGraphInNode(func_graph->output()); + } else if (IsValueNode(func_graph->output())) { + // When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called. + return FetchFuncGraphInNode(func_graph->output()); + } + + return func_graph; + } else if (IsValueNode(node)) { + const auto &func_graph = GetValueNode(node); + + if (AnfAlgo::CheckPrimitiveType(func_graph->output(), prim::kPrimPartial)) { + // When the output of funcgraph is a funcgraph, it needs to return the funcgraph that is finally called. + return FetchFuncGraphInNode(func_graph->output()); + } else if (IsValueNode(func_graph->output())) { + // When the output of funcgraph is a partial node, it needs to return the funcgraph that is finally called. + return FetchFuncGraphInNode(func_graph->output()); + } + + return func_graph; + } + + return nullptr; +} } // namespace // Return true if the node has Ref abstract. @@ -472,24 +534,56 @@ bool IsCallNode(const AnfNodePtr &node) { return inputs[0]->isa() || (inputs[0]->isa() && IsValueNode(inputs[0])); } -std::vector FetchAllRealInputNodeByParameter(const AnfNodePtr &node) { - std::vector parameters; - const auto real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first; +bool IsSubCallNode(const AnfNodePtr &node) { + if (!node->isa()) { + return false; + } + const auto inputs = node->cast()->inputs(); + + if (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { + return false; + } + + const auto &switch_layer_inputs = inputs[0]->cast()->inputs(); + const auto tuple_inputs = switch_layer_inputs[kSwitchLayerBranchPos]->cast()->inputs(); + if (tuple_inputs.size() <= kMakeTupleInputStartPos) { + return false; + } + + // Check whether the funcgraph called by the call node returns funcgraph or partial node. + FuncGraphPtr func_graph = nullptr; + if (AnfAlgo::CheckPrimitiveType(tuple_inputs[kMakeTupleInputStartPos], prim::kPrimPartial)) { + const auto &func_graph_node = tuple_inputs[kMakeTupleInputStartPos]->cast()->input(kPartialFuncGraphPos); + func_graph = GetValueNode(func_graph_node); + } else if (tuple_inputs[kMakeTupleInputStartPos]->isa() && + IsValueNode(tuple_inputs[kMakeTupleInputStartPos])) { + func_graph = GetValueNode(tuple_inputs[kMakeTupleInputStartPos]); + } + + const auto &output = func_graph->output(); + return AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial) || + (output->isa() && IsValueNode(output)); +} + +std::vector FetchAllRealInputNodeByParameter(const KernelWithIndex &node) { + std::vector parameters; + const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second); + const auto &real_node = real_node_with_index.first; if (real_node->isa()) { if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) { - parameters.emplace_back(real_node); + parameters.emplace_back(real_node_with_index); } } else if (HasAbstractMonad(real_node)) { return parameters; } else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) { const auto &inputs = real_node->cast()->inputs(); for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { - const auto &sub_parameters = FetchAllRealInputNodeByParameter(inputs[i]); + const auto &sub_parameters = FetchAllRealInputNodeByParameter({inputs[i], 0}); parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end()); } } else { - parameters.emplace_back(real_node); + parameters.emplace_back(real_node_with_index); } return parameters; } @@ -514,13 +608,15 @@ std::vector FetchFuncGraphbyCallNode(const AnfNodePtr &node) { AnfAlgo::CheckPrimitiveType(cnode_inputs[kSwitchLayerBranchPos], prim::kPrimMakeTuple)) { const auto &tuple_inputs = cnode_inputs[kSwitchLayerBranchPos]->cast()->inputs(); + // Fetch all funcgraphs in make tuple node. for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { - if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) { - func_graphs.emplace_back(GetFuncGraphFromPartial(tuple_inputs[i])); - } else if (IsValueNode(tuple_inputs[i])) { - func_graphs.emplace_back(GetValueNode(tuple_inputs[i])); + const auto func_graph = FetchFuncGraphInNode(tuple_inputs[i]); + if (func_graph != nullptr) { + func_graphs.emplace_back(func_graph); } } + } else if (IsCallNode(cnode)) { + return FetchFuncGraphbyCallNode(cnode); } else { MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString(); } @@ -563,7 +659,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector break; } total_num += call_output_num; - } else { + } else if (!HasAbstractMonad(inputs[i])) { ++total_num; } } @@ -612,16 +708,16 @@ AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node) { return kernel_graph->GetFrontAnfByBackendAnf(backend_node); } -AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) { +KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) { const auto &front_node = graph->GetFrontAnfByBackendAnf(backend_node); if (front_node != nullptr) { - return front_node; + return {front_node, 0}; } const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node); if (front_node_with_index.first == nullptr) { MS_LOG(EXCEPTION) << "Invalid parameter of kernel graph, parameter:" << AnfAlgo::GetNodeDebugString(backend_node); } - return front_node_with_index.first; + return front_node_with_index; } FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) { @@ -665,6 +761,8 @@ void ControlNodeParser::Parse(const std::vector &control_nodes, cons FetchHostParameterToWeight(real_to_formal_front_parameters); + FetchCallInputKernelGraph(graphs, device_contexts); + FetchFrontValueNode(control_nodes, graphs, device_contexts); FetchFrontToBackendKernel(graphs, device_contexts); @@ -757,7 +855,7 @@ AnfNodePtr ControlNodeParser::FetchBackendNodebyWeightNode(const AnfNodePtr &nod for (const auto &host_parameter_to_weight : host_parameter_to_weights_) { for (const auto &front_weight : host_parameter_to_weight.second) { if (front_weight == node) { - const auto &iter = front_to_backend_parameters_.find(front_weight); + const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first); if (iter != front_to_backend_parameters_.end()) { return iter->second.first; } @@ -868,6 +966,20 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector &contr } } } + + // When funcgraph called by call node returns to the value node, device addresses should be created for these + // value nodes. + for (const auto &call_node_to_backend_parameter : call_node_to_backend_parameters_) { + const auto func_graphs = FetchFuncGraphbyCallNode(call_node_to_backend_parameter.first.first); + for (const auto &func_graph : func_graphs) { + const auto &output = func_graph->output(); + if (output->isa() && GetFrontValueNodeDeviceContext(output) == nullptr) { + const auto &device_context = call_node_to_backend_parameter.second.second; + CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context); + front_value_nodes_.push_back({output, device_context}); + } + } + } } void ControlNodeParser::FetchFrontToFrontParameter( @@ -940,14 +1052,17 @@ void ControlNodeParser::FetchFrontToFrontParameter( } } else if (inputs[0]->isa()) { // Call node which the first input node is a switch or switchlayer node. - if ((!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch)) && - (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer))) { + if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch) || + AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { + std::vector call_inputs; + call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end()); + switch_input_parse(inputs[0], call_inputs); + } else if (IsCallNode(inputs[0])) { + continue; + } else { MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:" << AnfAlgo::GetNodeDebugString(inputs[0]); } - std::vector call_inputs; - call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end()); - switch_input_parse(inputs[0], call_inputs); } } } @@ -992,6 +1107,7 @@ void ControlNodeParser::FetchFuncGraphCallNum(const std::vector &con for (const auto &control_node : control_nodes) { if (IsCallNode(control_node)) { const auto &func_graphs = FetchFuncGraphbyCallNode(control_node); + for (const auto &func_graph : func_graphs) { MS_EXCEPTION_IF_NULL(func_graph); if (func_graph->output()->isa()) { @@ -1019,7 +1135,7 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vectorGetFrontNodeByInternalParameter(input); if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) { call_input_kernel_graphs_[graph] = device_context; - break; + call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context}; } } } @@ -1084,13 +1200,14 @@ std::vector FetchInputParameterbyControlNode(const AnfNodePtr &node, return parameters; } -std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &graph) { - std::vector parameters; +std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &graph) { + std::vector parameters; const auto &graph_parameters = graph->input_nodes(); for (const auto &graph_parameter : graph_parameters) { const auto &external_front_node = graph->GetFrontAnfByBackendAnf(graph_parameter); - const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(graph_parameter).first; + const auto &internal_front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter); + const auto &internal_front_node = internal_front_node_with_index.first; if (external_front_node == nullptr && internal_front_node == nullptr) { MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :" @@ -1098,9 +1215,9 @@ std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &graph) continue; } - const auto &front_node = (external_front_node != nullptr) ? external_front_node : internal_front_node; - const auto real_front_node = AnfAlgo::VisitKernelWithReturnType(front_node, 0).first; - const auto &sub_parameters = FetchAllRealInputNodeByParameter(real_front_node); + const auto &front_node_with_index = + ((external_front_node != nullptr) ? KernelWithIndex(external_front_node, 0) : internal_front_node_with_index); + const auto &sub_parameters = FetchAllRealInputNodeByParameter(front_node_with_index); parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end()); } @@ -1191,6 +1308,8 @@ void ControlNodeParser::FetchFuncGraphToParameter(const std::vector } else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { // Switchlayer node. FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_); + } else if (IsCallNode(inputs[0])) { + continue; } else { MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString(); } @@ -1232,10 +1351,6 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vectorgraph_output_map(); for (const auto &output_pair : graph_output_map) { front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context}; - MS_LOG(DEBUG) << "Add front to backend kernel, front:" << AnfAlgo::GetNodeDebugString(output_pair.second.first) - << "index:" << output_pair.second.second << " addr:" << output_pair.second.first - << " second:" << AnfAlgo::GetNodeDebugString(output_pair.first.first) - << "index:" << output_pair.first.second << " addr:" << output_pair.first.first; } } } @@ -1246,6 +1361,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ std::set *results) { if (front_output->isa()) { (*results).insert({front_output, 0}); + const auto &iter = formal_to_real_parameters_.find(front_output); if (iter != formal_to_real_parameters_.end()) { for (const auto &node : iter->second) { @@ -1405,6 +1521,15 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector } } + for (const auto &host_parameter_to_weight : host_parameter_to_weights_) { + for (const auto &front_weight : host_parameter_to_weight.second) { + const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first); + if (iter != front_to_backend_parameters_.end()) { + formal_to_real_parameters_[front_weight].push_back({iter->second.first, 0}); + } + } + } + for (const auto &func_graph_to_parameters : func_graph_to_parameters_) { const auto &func_graph = func_graph_to_parameters.first; std::vector graph_inputs; @@ -1453,7 +1578,6 @@ void ControlNodeParser::FetchAutoMonadNode(const std::vector &contro const auto &iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(node, 0)); if (iter != front_to_backend_kernels_.end()) { kernel_to_call_nodes_[iter->second.first.first] = control_node; - MS_LOG(DEBUG) << "Add auto monad control arrow for node:" << AnfAlgo::GetNodeDebugString(node); } } } diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index e3820a5631f..731cb2ace44 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -50,6 +50,9 @@ using RealToFormalNode = std::unordered_map> // 2. First input of node is a funcgraph value node. bool IsCallNode(const AnfNodePtr &node); +// Check if the call node is the input of another call node. +bool IsSubCallNode(const AnfNodePtr &node); + // Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph, // it is determined whether it is a weight. bool HasAbstractRef(const AnfNodePtr &node); @@ -66,7 +69,7 @@ AnfNodePtr GetFrontNodeByBackendNode(const AnfNodePtr &backend_node); // Get the front node corresponding to the backend node, if the front node is not a parameter node, return the // corresponding cnode. -AnfNodePtr GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph); +KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraphPtr &graph); // Get the funcgraph to which the node belongs. FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node); @@ -75,7 +78,7 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node); std::vector FetchFuncGraphbyCallNode(const AnfNodePtr &node); // Get parameters in kernel graph. -std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &graph); +std::vector FetchParameterbyKernelGraph(const KernelGraphPtr &graph); // ControlNodeParser is used to parse control nodes, and get the edges between nodes. class ControlNodeParser { @@ -205,6 +208,9 @@ class ControlNodeParser { // the input node of gather. FuncGraphToParameter func_graph_to_parameters_; + // The relationship between the valuenode inputs of the call node and the backend parameter + std::map> call_node_to_backend_parameters_; + // Branch id of funcgraph. // In control flow, funcgraph will be called in multiple places, and the output of funcgraph needs to return to // different places. Therefore, a branch id is created for each funcgraph. When funcgraph is called, the branch diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index 1f36c0f91fa..f029432a72c 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -1180,7 +1180,7 @@ std::vector GraphScheduler::BuildSwitchActor(const GraphCompiler const auto &actor_name = control_node->DebugString(); auto switch_actor = std::make_shared(actor_name, graph_compiler_info.device_contexts_[0], control_node->cast(), branch_id, false); - switch_actor->Initialize(graph_compiler_info.control_node_parser_); + switch_actor->ParseInput(graph_compiler_info.control_node_parser_); // Fetch all the input nodes of switch actor. switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_); @@ -1197,7 +1197,7 @@ std::vector GraphScheduler::BuildSwitchActor(const GraphCompiler const auto &actor_name = return_node->DebugString(); auto switch_actor = std::make_shared(actor_name, graph_compiler_info.device_contexts_[0], return_node->cast(), kInvalidBranchID, true); - switch_actor->Initialize(graph_compiler_info.control_node_parser_); + switch_actor->ParseInput(graph_compiler_info.control_node_parser_); // Fetch all the input nodes of switch actor. switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_); @@ -1235,7 +1235,6 @@ std::vector GraphScheduler::BuildGatherActor(const GraphCompiler const auto &cnode = control_node->cast(); const auto &inputs = cnode->inputs(); const auto &return_node = func_graph->get_return(); - const auto &output_switch_aid = FetchActor(return_node->DebugString())->GetAID(); if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) { // Root funcgraph does not need to create a gather actor. @@ -1245,21 +1244,26 @@ std::vector GraphScheduler::BuildGatherActor(const GraphCompiler } // If the output of funcgraph is a value node, no need to create gather actor. - if (inputs[kReturnInputPos]->isa()) { + if (inputs[kReturnInputPos]->isa() || + AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) { continue; } auto actor_name = func_graph->ToString(); - std::vector parameters; + std::vector parameters; for (const auto ¶meter : func_graph->get_inputs()) { if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) { continue; } - parameters.emplace_back(parameter); + parameters.push_back({parameter, 0}); } const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph); + const auto &output_switch_actor = FetchActor(return_node->DebugString()); + MS_EXCEPTION_IF_NULL(output_switch_actor); + const auto &output_switch_aid = output_switch_actor->GetAID(); + auto gather_actor = std::make_shared(actor_name, parameters, true, output_switch_aid, AID(), branch_id); gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_); @@ -1275,12 +1279,12 @@ std::vector GraphScheduler::BuildGatherActor(const GraphCompiler if (inputs[0]->isa() && IsValueNode(inputs[0])) { // Collect the parameters. - std::vector parameters; + std::vector parameters; for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa() && HasAbstractRef(inputs[i]))) { continue; } - parameters.emplace_back(inputs[i]); + parameters.push_back({inputs[i], 0}); } auto func_graph = control_node->func_graph(); @@ -1322,9 +1326,12 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf auto front_node = GetFrontNodeByBackendNode(from_kernel); if (from_kernel->isa() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) { - if (HasAbstractRef(from_kernel)) { - const auto devcie_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph); - to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, devcie_tensor_store_key.get()); + const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph); + const auto &real_front_node_with_index = + AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second); + if (HasAbstractRef(real_front_node_with_index.first)) { + to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, + real_front_node_with_index.first.get()); return; } @@ -1332,9 +1339,8 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf const auto actor_name = graph->ToString(); auto actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); - const auto &real_front_node = GetFrontNodeByKernelGraph(from_kernel, graph); - LinkDataArrowForGatherActor(dynamic_cast(actor), real_front_node, to_actor, - to_kernel_with_input_idx.second); + LinkDataArrowForGatherActor(dynamic_cast(actor), to_actor, real_front_node_with_index, + to_kernel_with_input_idx); return; } @@ -1355,8 +1361,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node.get()); return; } - - LinkDataArrowForGatherActor(from_actor, front_node, to_actor, to_kernel_with_input_idx.second); + LinkDataArrowForGatherActor(from_actor, to_actor, {front_node, 0}, to_kernel_with_input_idx); } else if (IsHostQueueDSActor(from_kernel, graph, tensor, graph_compiler_info.origin_parameters_order_, graph_compiler_info.strategy_)) { // Link the data arrows of host queue data source actor. @@ -2053,25 +2058,84 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector &control_nodes) { + for (const auto &node : control_nodes) { CNodePtr cnode = node->cast(); const auto &from_func_graph = node->func_graph(); auto inputs = cnode->inputs(); // Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor. if (inputs[0]->isa()) { - auto actor = FetchActor(inputs[0]->DebugString()); - MS_EXCEPTION_IF_NULL(actor); - auto switch_actor = dynamic_cast(actor); - for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { - if (HasAbstractMonad(inputs[i])) { - continue; + // Add the input of call node to switch actor. + if (IsCallNode(inputs[0])) { + const auto &sub_call_cnode = inputs[0]->cast(); + const auto &sub_inputs = sub_call_cnode->inputs(); + + if (AnfAlgo::CheckPrimitiveType(sub_inputs[0], prim::kPrimSwitchLayer)) { + auto actor = FetchActor(sub_inputs[0]->DebugString()); + MS_EXCEPTION_IF_NULL(actor); + auto switch_actor = dynamic_cast(actor); + + for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { + switch_actor->AddCommonInput(inputs[i]); + } + } + } else if (IsSubCallNode(cnode)) { + // Add the input of sub call node to switch actor. + auto actor = FetchActor(inputs[0]->DebugString()); + MS_EXCEPTION_IF_NULL(actor); + auto switch_actor = dynamic_cast(actor); + + const auto &tuple_node = inputs[0]->cast()->input(kSwitchLayerBranchPos); + const auto &tuple_inputs = tuple_node->cast()->inputs(); + + FuncGraphPtr func_graph = nullptr; + for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { + int pre_real_parameter_num = 0; + if (AnfAlgo::CheckPrimitiveType(tuple_inputs[i], prim::kPrimPartial)) { + pre_real_parameter_num = (tuple_inputs[i]->cast()->inputs().size() - kPartialInputStartPos); + func_graph = GetValueNode(tuple_inputs[i]->cast()->input(kPartialFuncGraphPos)); + } else { + func_graph = GetValueNode(tuple_inputs[i]); + } + const auto parameters = func_graph->parameters(); + const auto &output = func_graph->output(); + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimPartial)) { + const auto &sub_partial_inputs = output->cast()->inputs(); + + // Check whether the input node of the sub call node needs to be added to the switch actor. Only when + // the final return is a partial node and the partial node needs this input, the input node is added + // to the switch actor/ + for (size_t j = kPartialInputStartPos; j < sub_partial_inputs.size(); ++j) { + const auto &real_partial_input = AnfAlgo::VisitKernelWithReturnType(sub_partial_inputs[j], 0).first; + const auto &iter = find(parameters.begin(), parameters.end(), real_partial_input); + + if ((iter != parameters.end()) && (iter - parameters.begin() >= pre_real_parameter_num) && + (iter - parameters.begin() < + SizeToInt(pre_real_parameter_num + inputs.size() - kCallInputStartPos))) { + size_t pos = iter - parameters.begin() - pre_real_parameter_num + kCallInputStartPos; + switch_actor->AddSingleInput(inputs[pos], i - 1); + } + } + } + } + } else { + auto actor = FetchActor(inputs[0]->DebugString()); + MS_EXCEPTION_IF_NULL(actor); + auto switch_actor = dynamic_cast(actor); + for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { + if (HasAbstractMonad(inputs[i])) { + continue; + } + switch_actor->AddCommonInput(inputs[i]); } - switch_actor->AddCommonInput(inputs[i]); } } } +} + +void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set) { + PrepareInputNodeForSwitchActor(graph_compiler_info.control_nodes_); for (const auto &node : graph_compiler_info.control_nodes_) { CNodePtr cnode = node->cast(); @@ -2142,11 +2206,10 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi MS_EXCEPTION_IF_NULL(actor); auto gather_actor = dynamic_cast(actor); - for (const auto &input_node : gather_actor->data_nodes_) { + for (const auto &input_with_index : gather_actor->data_nodes_) { const auto &from_func_graph = kernel_graph->GetFuncGraph(); - const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0); LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor, - gather_actor->FetchDataNodePosition(input_node)); + gather_actor->FetchDataNodePosition(input_with_index)); } } LinkBranchArrowForSwitchActor(graph_compiler_info, actor_set); @@ -2162,15 +2225,16 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set); } -void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node, - KernelActor *to_actor, const size_t to_index) { +void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor, + const KernelWithIndex &front_node_with_index, + const KernelWithIndex &to_node_with_index) { MS_EXCEPTION_IF_NULL(from_actor); MS_EXCEPTION_IF_NULL(to_actor); - MS_EXCEPTION_IF_NULL(front_node); + MS_EXCEPTION_IF_NULL(front_node_with_index.first); - auto position = from_actor->FetchDataNodePosition(front_node); + auto position = from_actor->FetchDataNodePosition(front_node_with_index); - auto op_arrow = std::make_shared(position, to_actor->GetAID(), to_index); + auto op_arrow = std::make_shared(position, to_actor->GetAID(), to_node_with_index.second); from_actor->output_data_arrows_.emplace_back(op_arrow); to_actor->input_datas_num_++; } @@ -2181,13 +2245,18 @@ void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_w // Fetch all the funcgraph that call node would call. const auto cnode = call_node_with_index.first->cast(); std::vector func_graphs = FetchFuncGraphbyCallNode(cnode); + const auto &call_inputs = cnode->inputs(); + auto switch_node = call_inputs[0]; + if (IsCallNode(switch_node)) { + switch_node = call_inputs[0]->cast()->input(0); + } // Collect the output of each funcgraph. for (const auto &func_graph : func_graphs) { if (func_graph->output()->isa()) { - const auto &call_inputs = cnode->inputs(); - if (AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch)) { - const auto &actor_name = call_inputs[0]->DebugString(); + if (AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitch) || + AnfAlgo::CheckPrimitiveType(switch_node, prim::kPrimSwitchLayer)) { + const auto &actor_name = switch_node->DebugString(); const auto &actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto switch_actor = dynamic_cast(actor); @@ -2254,7 +2323,9 @@ void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const } for (size_t i = start_branch; i < max_branch; ++i) { if (from_actor->branch_inputs_pos_[i].size() <= from_index) { - MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i; + MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i + << " from index:" << from_index << " output size:" << from_actor->branch_inputs_pos_[i].size() + << " to actor:" << to_actor->GetAID() << " to index:" << to_index; } auto op_arrow = std::make_shared(from_actor->branch_inputs_pos_[i][from_index], to_actor->GetAID(), to_index); @@ -2277,7 +2348,7 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c } else if (IsGatherActor(input_node, actor_name_to_actor_)) { // The actor input is a parameter in gather actor. auto from_actor = dynamic_cast(actor_name_to_actor_[input_node->func_graph()->ToString()]); - auto position = from_actor->FetchDataNodePosition(input_node); + auto position = from_actor->FetchDataNodePosition({input_node, 0}); auto op_arrow = std::make_shared(position, to_actor->GetAID(), to_index); from_actor->output_data_arrows_.emplace_back(op_arrow); } else if (IsSwitchActor(input_node)) { @@ -2338,7 +2409,8 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false); UpdateRefCount(device_tensor.get(), true); } else { - MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node); + MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node) + << " to actor:" << to_actor->GetAID(); } } @@ -2559,65 +2631,6 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap } } -void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info, - const ActorSet *actor_set) { - MS_EXCEPTION_IF_NULL(actor_set); - OutputActor *to_actor = actor_set->output_actor_.get(); - MS_EXCEPTION_IF_NULL(to_actor); - - for (const auto &func_graph_to_branch_id : graph_compiler_info.control_node_parser_->func_graph_to_branch_id_) { - if (func_graph_to_branch_id.second == kMainBranchID) { - continue; - } - - const auto &func_graph = func_graph_to_branch_id.first; - auto actor = FetchActor(func_graph->ToString()); - MS_EXCEPTION_IF_NULL(actor); - auto gather_actor = dynamic_cast(actor); - - for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) { - const auto front_node = gather_actor->data_nodes_[i]; - auto origin_output_with_index = KernelWithIndex(front_node, 0); - const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index); - if (iter == graph_compiler_info.origin_outputs_order_.end()) { - continue; - } - - for (auto &output_position : iter->second) { - auto op_arrow = std::make_shared(i, to_actor->GetAID(), output_position); - gather_actor->output_result_arrows_.emplace_back(op_arrow); - const auto &backend_nodes = gather_actor->front_to_backend_parameter_[front_node]; - if (backend_nodes.empty()) { - MS_LOG(EXCEPTION) << "No backend node for data node:" << AnfAlgo::GetNodeDebugString(front_node); - } - - const auto &backend_node = backend_nodes[0].first; - if (backend_node->isa()) { - std::string actor_name = graph_compiler_info.name_ + "_HostDSActor"; - auto ds_op_actor = FetchActor(actor_name); - MS_EXCEPTION_IF_NULL(ds_op_actor); - auto host_ds_actor = dynamic_cast(ds_op_actor); - MS_EXCEPTION_IF_NULL(host_ds_actor); - - const auto &data_nodes = host_ds_actor->data_nodes_; - const auto &node_iter = find(data_nodes.begin(), data_nodes.end(), backend_node); - if (node_iter == data_nodes.end()) { - MS_LOG(EXCEPTION) << "Cannot find backend node in host data source actor, node:" - << AnfAlgo::GetNodeDebugString(backend_node); - } - to_actor->device_contexts_[output_position] = host_ds_actor->device_contexts_[node_iter - data_nodes.begin()]; - } else { - auto actor_base = FetchActor(backend_node->fullname_with_scope()); - MS_EXCEPTION_IF_NULL(actor_base); - auto kernel_actor = dynamic_cast(actor_base); - MS_EXCEPTION_IF_NULL(kernel_actor); - to_actor->device_contexts_[output_position] = kernel_actor->device_context_; - } - } - } - } -} - bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const { MS_EXCEPTION_IF_NULL(actor_set); // Check the data source actors. @@ -3098,7 +3111,7 @@ void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &of ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n"; for (const auto &node : actor->data_nodes_) { - ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node) << '\n'; + ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << "\tindex:" << node.second << '\n'; } ofs << "\t\tactor front to backend node:\n"; diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.h b/mindspore/ccsrc/runtime/framework/graph_scheduler.h index f2313ee0a14..57d39fe8fe0 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.h +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.h @@ -233,8 +233,9 @@ class GraphScheduler { // 4. The processing of control flow linking. void LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set); - void LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node, KernelActor *to_actor, - const size_t to_index); + void LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor, + const KernelWithIndex &front_node_with_index, + const KernelWithIndex &to_node_with_index); void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *actor); // Connect the input of the actor. void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node, @@ -263,6 +264,9 @@ class GraphScheduler { const ControlNodeParserPtr &control_node_parser, const std::vector &origin_parameters, const std::vector &tensors, std::vector *host_tensors); + // Add input for switch actor. Since part of the input of funcgraph is on call node, these inputs need to be added + // to switch actor. + void PrepareInputNodeForSwitchActor(const std::vector &control_nodes); // The processing of actors link dynamically. // Analyze necessary input data of current actor, generate and cache op arrow