diff --git a/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc b/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc
index 61662dafded..48074ec9c09 100644
--- a/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc
+++ b/mindspore/ccsrc/runtime/framework/actor/gather_actor.cc
@@ -16,6 +16,7 @@
 #include "runtime/framework/actor/gather_actor.h"
 #include "runtime/framework/actor/output_actor.h"
+#include "runtime/framework/actor/switch_actor.h"
 #include "runtime/framework/actor/memory_manager_actor.h"
 #include "runtime/framework/actor/loop_count_actor.h"
 #include "mindrt/include/async/async.h"
@@ -44,7 +45,7 @@ void GatherActor::Init() {
 size_t GatherActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
   const auto &iter = find(data_nodes_.begin(), data_nodes_.end(), data_node);
   if (iter == data_nodes_.end()) {
-    MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope()
+    MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
                       << " is not exist in gather actor:" << GetAID();
   }
   return iter - data_nodes_.begin();
@@ -52,9 +53,8 @@ size_t GatherActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
 void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
-  auto sequential_num = context->sequential_num_;
-  input_op_datas_[sequential_num].emplace_back(input_data);
+  auto &sequential_num = context->sequential_num_;
+  input_data_[sequential_num][input_data->index_].push(input_data->data_);
 
   if (CheckLaunchCondition(context)) {
     FetchInputDeviceTensor(context);
@@ -63,6 +63,29 @@ void GatherActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
+void GatherActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto &sequential_num = context->sequential_num_;
+  input_op_controls_[sequential_num].emplace_back(input_control);
+
+  if (CheckLaunchCondition(context)) {
+    FetchInputDeviceTensor(context);
+    EraseInput(context);
+    SendOutput(context);
+  }
+}
+
+void GatherActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto &sequential_num = context->sequential_num_;
+  input_branch_ids_[sequential_num] = branch_id;
+  if (CheckLaunchCondition(context)) {
+    FetchInputDeviceTensor(context);
+    EraseInput(context);
+    SendOutput(context);
+  }
+}
+
 void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const ControlNodeParserPtr &parser) {
   for (const auto &input : func_graph->get_inputs()) {
     // Monad input would not send to gather actor.
@@ -76,20 +99,20 @@ void GatherActor::FetchBackendInputNode(const FuncGraphPtr &func_graph, const Co
 void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
   MS_EXCEPTION_IF_NULL(context);
-
-  // Branch arrow and result arrow must be executed before the data arrow and control arrow, otherwise the output
-  // actor may receive the loop count message first and cause the output to be abnormal.
-  if (branch_id_ > kInvalidBranchID) {
-    Async(loop_count_aid_, &LoopCountActor::CollectBranchId, branch_id_, context);
-    Async(output_aid_, &OutputActor::CollectBranchId, branch_id_, context);
+  // Send output branch id.
+  if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), switch_aid_) != output_branch_arrows_.end()) {
+    int branch_id = input_branch_id_;
+    Async(switch_aid_, &SwitchActor::CollectBranchId, branch_id, context);
+  }
+  if (find(output_branch_arrows_.begin(), output_branch_arrows_.end(), gather_aid_) != output_branch_arrows_.end()) {
+    Async(gather_aid_, &GatherActor::CollectBranchId, local_branch_id_, context);
   }
-  // Send graph output result.
+  // Send output result.
   for (const auto &result_arrow : output_result_arrows_) {
     MS_EXCEPTION_IF_NULL(result_arrow);
     size_t from_index = result_arrow->from_output_index_;
     const auto &front_node = data_nodes_[from_index];
-
     for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) {
       if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
           input_device_tensors_[from_index]) {
@@ -115,15 +138,28 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const {
 void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
-
-  auto data_iter = input_op_datas_.find(context->sequential_num_);
-  if (data_iter != input_op_datas_.end()) {
+  auto data_iter = input_data_.find(context->sequential_num_);
+  if (data_iter != input_data_.end()) {
     for (auto &input_data : data_iter->second) {
-      MS_EXCEPTION_IF_NULL(input_data);
-      input_device_tensors_[input_data->index_] = input_data->data_;
+      input_device_tensors_[input_data.first] = input_data.second.top();
+      input_data.second.pop();
     }
   }
 
+  for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
+    const auto &device_context = device_contexts_[device_tensor_store_key.first];
+    MS_EXCEPTION_IF_NULL(device_context);
+    auto device_tensor =
+      DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context->GetDeviceAddressType());
+    if (device_tensor == nullptr) {
+      std::string error_info =
+        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
+        ", device type:" + std::to_string(static_cast<int>(device_context->GetDeviceAddressType()));
+      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+    }
+    input_device_tensors_[device_tensor_store_key.first] = device_tensor;
+  }
+
   for (size_t i = 0; i < output_data_by_output_index_.size(); ++i) {
     const auto &data = input_device_tensors_[i];
     for (auto &output_data : output_data_by_output_index_[i]) {
@@ -131,20 +167,31 @@ void GatherActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
       output_data->data_ = data;
     }
   }
+
+  if (need_branch_id_input_) {
+    input_branch_id_ = input_branch_ids_[context->sequential_num_];
+  }
 }
 
 bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
   MS_EXCEPTION_IF_NULL(context);
+
+  // Fetch input data.
   if (input_datas_num_ != 0) {
-    auto data_iter = input_op_datas_.find(context->sequential_num_);
-    if (data_iter == input_op_datas_.end()) {
+    auto data_iter = input_data_.find(context->sequential_num_);
+    if (data_iter == input_data_.end()) {
       return false;
     }
-    if (data_iter->second.size() != input_datas_num_) {
+    if (data_iter->second.size() + device_tensor_store_keys_.size() != input_datas_num_) {
+      return false;
+    }
+    if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
+                    [](const auto &input_stack) { return input_stack.second.empty(); })) {
       return false;
     }
   }
 
+  // Fetch input control.
   if (input_controls_num_ != 0) {
     auto control_iter = input_op_controls_.find(context->sequential_num_);
     if (control_iter == input_op_controls_.end()) {
@@ -154,19 +201,32 @@ bool GatherActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
       return false;
     }
   }
+
+  // Fetch input branch id.
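// --- Editor's illustrative sketch (not part of the patch) ------------------
// The three-part launch condition above, modeled standalone: every expected
// data slot has a non-empty stack, enough controls arrived, and (if required)
// a branch id arrived. All names below are hypothetical simplifications of
// the actor framework, not MindSpore API.
#include <cstddef>
#include <stack>
#include <unordered_map>

struct StepInputs {
  std::unordered_map<size_t, std::stack<int>> data;  // input index -> value stack
  size_t control_count = 0;
  bool has_branch_id = false;
};

bool ReadyToLaunch(const StepInputs &in, size_t datas_num, size_t controls_num, bool need_branch_id) {
  if (datas_num != 0) {
    if (in.data.size() != datas_num) {
      return false;
    }
    for (const auto &kv : in.data) {
      if (kv.second.empty()) {  // slot already consumed by a nested (recursive) launch
        return false;
      }
    }
  }
  if (controls_num != 0 && in.control_count < controls_num) {
    return false;
  }
  return !(need_branch_id && !in.has_branch_id);
}
// --- end sketch -------------------------------------------------------------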
+ if (need_branch_id_input_) { + auto branch_id_iter = input_branch_ids_.find(context->sequential_num_); + if (branch_id_iter == input_branch_ids_.end()) { + return false; + } + } return true; } void GatherActor::EraseInput(OpContext *context) { MS_EXCEPTION_IF_NULL(context); - if (input_datas_num_ != 0) { - auto ret = input_op_datas_.erase(context->sequential_num_); + + // Erase input data. + auto data_iter = input_data_.find(context->sequential_num_); + if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(), + [](const auto &input_data) { return input_data.second.empty(); })) { + auto ret = input_data_.erase(context->sequential_num_); if (ret == 0) { std::string error_info = "Erase input data failed: " + GetAID().Name(); SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); } } + // Erase input control. if (input_controls_num_ != 0) { auto ret = input_op_controls_.erase(context->sequential_num_); if (ret == 0) { @@ -174,6 +234,15 @@ void GatherActor::EraseInput(OpContext *context) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); } } + + // Erase input branch id. + if (need_branch_id_input_) { + auto ret = input_branch_ids_.erase(context->sequential_num_); + if (ret == 0) { + std::string error_info = "Erase input branch id failed: " + GetAID().Name(); + SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); + } + } } } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/actor/gather_actor.h b/mindspore/ccsrc/runtime/framework/actor/gather_actor.h index c2aa0bd057a..578093aa25b 100644 --- a/mindspore/ccsrc/runtime/framework/actor/gather_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/gather_actor.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include "runtime/framework/device_tensor_store.h" @@ -36,20 +37,37 @@ namespace runtime { constexpr size_t kReturnInputPos = 1; -// Gather actor is the entrance of sub funcgraph. Graph input is sent to it and sent to other actors by gather actor. +// Gather actor is used in three places: +// 1. Entrance of sub funcgraph +// 2. call node which input0 is a funcgraph +// 3. There is some call nodes in the inputs of kernel graph. +// Gather actor will be used in the control flow. When the subgraph is called, the real parameters need to be put +// together and sent to the subgraph. At the same time, the entry of the subgraph needs to accept input data. +// Special in recursion, general inputs and call inputs of the kernel graph are used in stack mode, it needs to be +// collected at the entrance of the kernel graph. class GatherActor : public OpActor { public: - GatherActor(const std::string &name, const std::vector ¶meters, const AID loop_count_aid, - const AID output_aid) - : OpActor(name), data_nodes_(parameters), loop_count_aid_(loop_count_aid), output_aid_(output_aid) {} + GatherActor(const std::string &name, const std::vector ¶meters, const bool need_branch_id_input, + const AID switch_aid, const AID gather_aid, const int branch_id) + : OpActor(name), + data_nodes_(parameters), + need_branch_id_input_(need_branch_id_input), + switch_aid_(switch_aid), + gather_aid_(gather_aid), + local_branch_id_(branch_id) { + device_contexts_.resize(parameters.size()); + } ~GatherActor() override = default; // Get the index of the parameter, the data_node needs to be the front node. size_t FetchDataNodePosition(const AnfNodePtr &data_node) const; - // The kernel actor run when receive the input data. 
+  // The gather actor runs when receiving the input data.
   void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
-
+  // The gather actor runs when receiving the input control.
+  void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
+  // The gather actor runs when receiving the input branch id.
+  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
   void Init() override;
 
  private:
@@ -66,13 +84,33 @@ class GatherActor : public OpActor<DeviceTensor> {
   // The device tensors for launch.
   std::vector<DeviceTensor *> input_device_tensors_;
+  // The branch id for the current step.
+  int input_branch_id_;
 
-  DeviceContext *device_contexts_;
+  // Input data.
+  std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_;
+  // Input branch ids are used to record the branch ids received from gather actors.
+  // In control flow, a sub funcgraph may be called in multiple places, and the output must be returned to different
+  // places. Therefore, the output of each subgraph will be connected to a switch actor, and the caller will send
+  // its branch id to the gather actor of the subgraph. Then the branch id will be sent by the gather actor to the
+  // switch actor connected to the output.
+  std::unordered_map<uuids::uuid *, int> input_branch_ids_;
+  // Output data.
+  // Cache unique output data by output index to modify the output data effectively.
+  std::vector<std::vector<std::unique_ptr<OpData<DeviceTensor>>>> output_data_by_output_index_;
+  // The output_data_ corresponds to the output_data_arrows_ one by one.
+  std::vector<OpData<DeviceTensor> *> output_data_;
+
+  // Output arrows.
   std::vector<DataArrowPtr> output_result_arrows_;
+  std::vector<AID> output_branch_arrows_;
 
   // Parameters of sub funcgraph, which is the front node.
   std::vector<AnfNodePtr> data_nodes_;
+  std::vector<DeviceContext *> device_contexts_;
+  // Pair points to the dependent device tensor store, anfNode is the key of the device tensor store.
+  std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
 
   // When the output is a parameter of the subgraph, the gather actor needs to send the anfnode to the output actor,
   // so all the nodes that may send the device tensor to gather actor are recorded. When the anfnode needs to be sent
@@ -83,18 +121,19 @@ class GatherActor : public OpActor<DeviceTensor> {
   size_t input_datas_num_{0};
   // The dependent input controls number.
   size_t input_controls_num_{0};
+  // Whether it needs to accept the branch id. When the gather actor is the input of the subgraph, it needs to
+  // receive the branch id sent by the subgraph caller, which will be true at this time.
+  bool need_branch_id_input_;
 
-  const AID loop_count_aid_;
-  const AID output_aid_;
+  // Actor ids that the branch id needs to be sent to.
+  // When the actor corresponds to a call node, the branch id needs to be sent to the input gather actor and output
+  // switch actor of the called funcgraph. When the actor is the entrance of the funcgraph, the gather actor id is
+  // empty, and the branch id just needs to be sent to its output switch actor.
+  const AID switch_aid_;
+  const AID gather_aid_;
 
-  // Cache unique output data by output index to modify the output data effectively.
-  std::vector<std::vector<std::unique_ptr<OpData<DeviceTensor>>>> output_data_by_output_index_;
-  // The output_data_ corresponds to the output_data_arrows_ one by one.
-  std::vector<OpData<DeviceTensor> *> output_data_;
-
-  // When the result of the graph is sent to the output actor, the gather actor of the graph needs
-  // to send branch_id to the output actor to determine the corresponding weight.
-  int branch_id_{kInvalidBranchID};
+  // The branch id corresponding to the funcgraph to which the gather actor belongs.
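// --- Editor's illustrative sketch (not part of the patch) ------------------
// How the branch id is meant to route a subgraph's return value back to the
// right caller, modeled with plain maps. The names and ids are hypothetical.
#include <cassert>
#include <cstddef>
#include <unordered_map>

int main() {
  // The output switch actor of a subgraph keeps branch_id -> output arrow index
  // (branch_id_to_index_ in the patch).
  std::unordered_map<int, size_t> branch_id_to_index;
  branch_id_to_index[7] = 0;  // caller A was registered as output branch 0
  branch_id_to_index[9] = 1;  // caller B was registered as output branch 1

  // Caller B's gather actor forwards its local branch id (9) with the call;
  // on return, the switch actor picks the arrow set of that caller.
  int received_branch_id = 9;
  size_t output_index = branch_id_to_index.at(received_branch_id);
  assert(output_index == 1);
  return 0;
}
// --- end sketch -------------------------------------------------------------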
+  int local_branch_id_;
 };
 
 using GatherActorPtr = std::shared_ptr<GatherActor>;
diff --git a/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc b/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
index 4258639575d..cfe8542e275 100644
--- a/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
+++ b/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
@@ -77,6 +77,9 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
   auto &sequential_num = context->sequential_num_;
   input_op_datas_[sequential_num].emplace_back(input_data);
+  if (input_data->data_ == nullptr) {
+    MS_LOG(EXCEPTION) << "Input data of actor:" << GetAID() << " num:" << input_data->index_ << " is empty";
+  }
 
   // When all the inputs are collected, then allocate memory and callback launch.
   if (CheckLaunchCondition(context)) {
     // Infer kernel shape and update abstract info for dynamic shape kernel.
@@ -245,7 +248,7 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
       DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
     if (device_tensor == nullptr) {
       std::string error_info =
-        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
+        GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->DebugString() +
         ", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
     }
diff --git a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc
index 9ea533d944a..cfc5b03626c 100644
--- a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc
+++ b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc
@@ -86,16 +86,6 @@ void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *c
   MS_EXCEPTION_IF_NULL(context);
   auto sequential_num = context->sequential_num_;
   input_op_controls_[sequential_num].emplace_back(input_control);
-
-  if (CheckLoopCountIncreaseCondition(context)) {
-    IncreaseLoopCount(context);
-  }
-}
-
-void LoopCountActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
-  MS_EXCEPTION_IF_NULL(context);
-  branch_id_ = branch_id;
-
   if (CheckLoopCountIncreaseCondition(context)) {
     IncreaseLoopCount(context);
   }
@@ -138,7 +128,6 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *context) {
   if (recorder_aid_ != nullptr) {
     Async(*recorder_aid_, &RecorderActor::RecordOnStepEnd, context);
   }
-
   SendMemoryAllocReq(context);
 }
@@ -180,14 +169,8 @@ void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
 bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
   auto sequential_num = context->sequential_num_;
-  if (branch_id_ == kInvalidBranchID) {
-    return false;
-  }
-  if (branch_id_ >= SizeToInt(branch_id_to_input_controls_num_.size())) {
-    MS_LOG(ERROR) << "Branch id is invalid, id:" << branch_id_;
-  }
 
-  return input_op_controls_[sequential_num].size() == branch_id_to_input_controls_num_[branch_id_];
+  return input_op_controls_[sequential_num].size() == input_controls_num_;
 }
 }  // namespace runtime
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h
index e839ea80630..cdb896126bb 100644
--- a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h
+++ b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h
@@ -40,11 +40,10 @@ class LoopCountActor : public DebugAwareActor {
       loop_count_(loop_count),
       current_count_(0),
       total_running_count_(0),
+      input_controls_num_(0),
       memory_manager_aid_(memory_manager_aid),
       debug_aid_(debug_aid),
-      recorder_aid_(recorder_aid) {
-    branch_id_to_input_controls_num_[kMainBranchID] = 0;
-  }
+      recorder_aid_(recorder_aid) {}
 
   ~LoopCountActor() override = default;
@@ -63,11 +62,6 @@ class LoopCountActor : public DebugAwareActor {
   // The callback after debug finished.
   void OnDebugFinish(OpContext<DeviceTensor> *context) override;
 
-  // In control flow, there are multi-branch output situations. In this case, the gather actor will be numbered
-  // branch id, and the branch id will be sent to the loop count actor during operation. The interface is used
-  // to receive the branch id message.
-  void CollectBranchId(const int branch_id_, OpContext<DeviceTensor> *context);
-
  private:
   friend class GraphScheduler;
@@ -84,7 +78,7 @@ class LoopCountActor : public DebugAwareActor {
   // The dependent input controls number.
   // In the multi-branch output scenario of the control flow, the control of each branch needs to be recorded
   // separately with the branch id as the key. When the output has only one branch, the branch id is 0.
-  std::unordered_map<int, size_t> branch_id_to_input_controls_num_;
+  size_t input_controls_num_;
 
   // The output controls contain the data source actors and the no input kernel actors and output actor.
   std::vector<AID> data_source_aids_;
@@ -98,10 +92,6 @@ class LoopCountActor : public DebugAwareActor {
   // The id of recorder actor. Send message to it for clearing recorder info before loop count actor exits.
   const AID *recorder_aid_;
 
-  // When the result of the graph is sent to the output actor, the gather actor of the graph needs
-  // to send branch_id to the output actor to determine the corresponding weight.
-  int branch_id_{kMainBranchID};
-
   // The nodes need continuous memory, which must allocate in the begin of step running. The first bool of pair
   // expresses the inputs of node need continuous memory, the second bool of pair expresses the outputs of node need
   // continuous memory.
diff --git a/mindspore/ccsrc/runtime/framework/actor/output_actor.cc b/mindspore/ccsrc/runtime/framework/actor/output_actor.cc
index 23278636ff5..865469bd7cc 100644
--- a/mindspore/ccsrc/runtime/framework/actor/output_actor.cc
+++ b/mindspore/ccsrc/runtime/framework/actor/output_actor.cc
@@ -44,28 +44,25 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index,
 void OutputActor::Init() {
   // Set the number of actor running dependent messages.
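// --- Editor's illustrative sketch (not part of the patch) ------------------
// The simplification below: with the per-branch map gone, the output actor in
// step mode just waits for one message per output that is not backed by the
// device tensor store. A tiny worked example with hypothetical numbers:
#include <cassert>
#include <cstddef>

int main() {
  size_t outputs_num = 5;                  // total graph outputs
  size_t device_tensor_store_outputs = 2;  // weights/value nodes fetched from the store, never sent
  int running_dependent_msg_num = static_cast<int>(outputs_num - device_tensor_store_outputs);
  assert(running_dependent_msg_num == 3);  // only three CollectOutput messages are expected
  return 0;
}
// --- end sketch -------------------------------------------------------------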
-  if ((!need_loop_count_) && (device_tensor_store_keys_.size() == 1)) {
-    running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_[kMainBranchID].size());
+  if ((!need_loop_count_)) {
+    running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
   }
 }
 
 void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
-  if (branch_id_ == kInvalidBranchID) {
-    MS_LOG(EXCEPTION) << "Invalid branch id for output actor.";
-  }
+
   current_count_ = loop_count;
   if (loop_count_ == current_count_) {
-    if (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() != outputs_num_) {
-      std::string error_info =
-        "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
-        ", the current outputs num: " + std::to_string(current_outputs_num_) +
-        ", the device tensor store num: " + std::to_string(device_tensor_store_keys_[branch_id_].size());
+    if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
+      std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
+                               ", the current outputs num: " + std::to_string(current_outputs_num_) +
+                               ", the device tensor store num: " + std::to_string(device_tensor_store_keys_.size());
       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
     }
 
     // Because device tensor store can't send data, so fetch the output result of device tensor store in running end.
-    for (const auto &device_tensor_store_key : device_tensor_store_keys_[branch_id_]) {
+    for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
       if (device_tensor_store_key.first >= outputs_.size()) {
         SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
       }
@@ -108,16 +105,10 @@ void OutputActor::UpdateOutputDeviceAddress() {
   output_nodes_.resize(outputs_num_);
 }
 
-void OutputActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
-  MS_EXCEPTION_IF_NULL(context);
-  branch_id_ = branch_id;
-}
-
 void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
                                 OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(output_node);
   MS_EXCEPTION_IF_NULL(context);
-
   // Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
   if (loop_count_ - current_count_ != 1) {
     return;
@@ -132,7 +123,7 @@ void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_ind
   // Save the output nodes to clear the device tensor in the running end.
   output_nodes_[output_position] = KernelWithIndex(output_node, output_index);
   // There is no loop count actor in step mode, need trigger call CollectLoopCount to replace old output device tensors.
-  if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_[branch_id_].size() == outputs_num_)) {
+  if (!need_loop_count_ && (current_outputs_num_ + device_tensor_store_keys_.size() == outputs_num_)) {
     CollectLoopCount(++current_count_, context);
   }
 }
diff --git a/mindspore/ccsrc/runtime/framework/actor/output_actor.h b/mindspore/ccsrc/runtime/framework/actor/output_actor.h
index 0980cffce45..e1e2d96121e 100644
--- a/mindspore/ccsrc/runtime/framework/actor/output_actor.h
+++ b/mindspore/ccsrc/runtime/framework/actor/output_actor.h
@@ -46,12 +46,10 @@ class OutputActor : public OpActor<DeviceTensor> {
       outputs_num_(outputs_num),
       current_outputs_num_(0),
       need_loop_count_(need_loop_count),
-      branch_id_(kMainBranchID),
       running_dependent_msg_num_(1) {
     outputs_.resize(outputs_num);
     output_nodes_.resize(outputs_num);
     device_contexts_.resize(outputs_num);
-    device_tensor_store_keys_[kMainBranchID] = std::vector<std::pair<size_t, AnfNode *>>();
   }
 
   ~OutputActor() override = default;
@@ -65,8 +63,6 @@ class OutputActor : public OpActor<DeviceTensor> {
   void CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
                      OpContext<DeviceTensor> *context);
 
-  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
-
   // The graph output need be set new device address every step or loop, to avoid that the device address
   // context of tensor be rewritten in the next step or next loop.
   void UpdateOutputDeviceAddress();
@@ -88,16 +84,11 @@ class OutputActor : public OpActor<DeviceTensor> {
   size_t outputs_num_;
   size_t current_outputs_num_;
   bool need_loop_count_;
-  int branch_id_;
 
   // The dependent messages number of actor running.
   int running_dependent_msg_num_;
 
-  // Pair<index, anfNode> points to the dependent device tensor store, branch_id is the output branch id.
-  // In general, the branch id is 0, which means there is only one output branch in the actor set. When there are
-  // multiple possible output branches in the actor set, different branch ids correspond to their own related nodes.
-  // The index is the position of node in the output, node is the key of the device tensor store.
-  std::unordered_map<int, std::vector<std::pair<size_t, AnfNode *>>> device_tensor_store_keys_;
+  std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
 };
 
 using OutputActorPtr = std::shared_ptr<OutputActor>;
diff --git a/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc b/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc
index b1f5d23efa5..ec41691118e 100644
--- a/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc
+++ b/mindspore/ccsrc/runtime/framework/actor/switch_actor.cc
@@ -16,6 +16,7 @@
 #include "runtime/framework/actor/switch_actor.h"
 #include "runtime/framework/actor/output_actor.h"
+#include "runtime/framework/actor/gather_actor.h"
 #include "runtime/framework/actor/memory_manager_actor.h"
 #include "mindrt/include/async/async.h"
 #include "abstract/utils.h"
@@ -39,10 +40,10 @@ void SwitchActor::Init() {
 void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
-  auto sequential_num = context->sequential_num_;
-  input_op_datas_[sequential_num].emplace_back(input_data);
+  const auto &sequential_num = context->sequential_num_;
+  auto &input_datas = input_data_[sequential_num];
+  input_datas[input_data->index_].push(input_data->data_);
 
-  // When all the inputs are collected, then allocate memory and callback launch.
   if (CheckLaunchCondition(context)) {
     FetchInputDeviceTensor(context);
     EraseInput(context);
     SendOutput(context);
@@ -50,14 +51,38 @@ void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
+void SwitchActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto &sequential_num = context->sequential_num_;
+  if (input_controls_[sequential_num].find(input_control) == input_controls_[sequential_num].end()) {
+    input_controls_[sequential_num][input_control] = 0;
+  }
+  input_controls_[sequential_num][input_control]++;
+
+  if (CheckLaunchCondition(context)) {
+    FetchInputDeviceTensor(context);
+    EraseInput(context);
+    SendOutput(context);
+  }
+}
+
+void SwitchActor::CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto &sequential_num = context->sequential_num_;
+  input_branch_ids_[sequential_num].push(branch_id);
+}
+
+void SwitchActor::Initialize(const ControlNodeParserPtr &parser) {
   std::vector<AnfNodePtr> inputs = node_->inputs();
 
   if (IsPrimitive(inputs[0], prim::kPrimSwitch)) {
     InitSwitch();
+  } else if (IsPrimitive(inputs[0], prim::kPrimReturn)) {
+    InitReturn(parser);
   } else {
     InitSwitchLayer();
   }
+  backend_parameters_.resize(input_nodes_.size());
 }
 
 void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
@@ -88,6 +113,23 @@ void SwitchActor::InitPartial(const AnfNodePtr &node, const size_t branch_id) {
   }
 }
 
+void SwitchActor::InitVectorSize(const size_t num) {
+  branch_inputs_pos_.resize(num);
+  branch_func_graph_.resize(num);
+  output_branch_arrows_.resize(num);
+  output_branch_result_arrows_.resize(num);
+  output_branch_control_arrows_.resize(num);
+  output_branch_branch_arrows_.resize(num);
+}
+
+void SwitchActor::InitReturn(const ControlNodeParserPtr &parser) {
+  const auto &func_graph = node_->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  const auto &call_num = parser->GetCallNumByFuncGraph(func_graph);
+  InitVectorSize(call_num);
+  AddCommonInput(func_graph->output());
+}
+
 void SwitchActor::InitSwitch() {
   // The inputs of the switch node:
   // [0] ValueNode<Primitive> kSwitch.
   // [1] switch condition.
     MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is not equal 4";
   }
 
-  branch_total_inputs_.resize(kSwitchPartialNum);
-  branch_inputs_pos_.resize(kSwitchPartialNum);
-  branch_func_graph_.resize(kSwitchPartialNum);
-  output_branch_arrows_.resize(kSwitchPartialNum);
-  output_branch_result_arrows_.resize(kSwitchPartialNum);
-  output_branch_control_arrows_.resize(kSwitchPartialNum);
+  InitVectorSize(kSwitchPartialNum);
 
-  input_nodes_.push_back(inputs[kSwitchCondPos]);
+  const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0);
+  input_nodes_.push_back(cond_node);
   input_datas_num_++;
   // Init the two branches of switch node.
   InitPartial(inputs[kSwitchFalseBranchPos], static_cast<size_t>(false));
   InitPartial(inputs[kSwitchTrueBranchPos], static_cast<size_t>(true));
 }
 
 void SwitchActor::InitSwitchLayer() {
     MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitchLayer->name() << " is not equal 3";
   }
 
-  input_nodes_.push_back(inputs[kSwitchLayerCondPos]);
+  const auto cond_node = AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchLayerCondPos], 0);
+  input_nodes_.push_back(cond_node);
+  input_datas_num_++;
 
   // The second input of SwitchLayer is maketuple node, which includes all branches.
   auto branch_nodes = inputs[kSwitchLayerBranchPos]->cast<CNodePtr>()->inputs();
-  branch_total_inputs_.resize(branch_nodes.size() - 1);
-  branch_inputs_pos_.resize(branch_nodes.size() - 1);
-  branch_func_graph_.resize(branch_nodes.size() - 1);
-  output_branch_arrows_.resize(branch_nodes.size() - 1);
-  output_branch_result_arrows_.resize(branch_nodes.size() - 1);
-  output_branch_control_arrows_.resize(branch_nodes.size() - 1);
+  InitVectorSize(branch_nodes.size() - 1);
 
   // Parse all branches.
   for (size_t i = 1; i < branch_nodes.size(); ++i) {
@@ -151,41 +186,92 @@ void SwitchActor::AddCommonInput(const AnfNodePtr &node) {
 }
 
 size_t SwitchActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
-  const auto &iter = find(input_nodes_.begin(), input_nodes_.end(), data_node);
+  const auto data_node_with_index = AnfAlgo::VisitKernelWithReturnType(data_node, 0);
+  const auto &iter = find(input_nodes_.begin(), input_nodes_.end(), data_node_with_index);
   if (iter == input_nodes_.end()) {
-    MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope()
-                      << " is not exist in gather actor:" << GetAID();
+    MS_LOG(EXCEPTION) << "Data node: " << AnfAlgo::GetNodeDebugString(data_node)
+                      << " does not exist in switch actor:" << GetAID();
   }
   return iter - input_nodes_.begin();
 }
 
-void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
-  branch_total_inputs_[branch].push_back(node);
+void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t branch) {
+  const auto &node = node_with_index.first;
 
-  if (node->isa<ValueNode>() && (!HasAbstractMonad(node))) {
+  // Add weight and value node.
+  if ((AnfAlgo::CheckPrimitiveType(node_, prim::kPrimReturn) && HasAbstractRef(node)) || node->isa<ValueNode>()) {
+    const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
+    if (iter != input_nodes_.end()) {
+      branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
+      return;
+    }
     device_tensor_store_keys_.push_back({input_nodes_.size(), node.get()});
     branch_inputs_pos_[branch].push_back(input_nodes_.size());
-    input_nodes_.push_back(node);
+    input_nodes_.push_back(node_with_index);
     return;
   }
 
-  // Switch actor only receives parameter, updatestate node output is U, need to be skipped.
-  if (IsPersistentDeviceTensor(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
+  // The output of an updatestate node is U and needs to be skipped.
+  if (HasAbstractRef(node)) {
     return;
   }
 
-  auto iter = find(input_nodes_.begin(), input_nodes_.end(), node);
+  // Add parameter.
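// --- Editor's illustrative sketch (not part of the patch) ------------------
// The find-or-append pattern used by AddInput: each distinct input occupies
// one slot in input_nodes_, and every branch records positions into that
// shared vector, so an input used by several branches is received only once.
// Minimal standalone model; names are hypothetical.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

size_t AddInput(std::vector<std::string> *input_nodes, std::vector<size_t> *branch_pos, const std::string &node) {
  auto iter = std::find(input_nodes->begin(), input_nodes->end(), node);
  size_t pos = (iter == input_nodes->end())
                 ? (input_nodes->push_back(node), input_nodes->size() - 1)
                 : static_cast<size_t>(iter - input_nodes->begin());
  branch_pos->push_back(pos);
  return pos;
}

int main() {
  std::vector<std::string> input_nodes;
  std::vector<size_t> true_branch, false_branch;
  AddInput(&input_nodes, &true_branch, "x");
  AddInput(&input_nodes, &false_branch, "x");  // shared input lands in the same slot
  AddInput(&input_nodes, &false_branch, "y");
  assert(input_nodes.size() == 2 && true_branch[0] == false_branch[0]);
  return 0;
}
// --- end sketch -------------------------------------------------------------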
+  auto iter = find(input_nodes_.begin(), input_nodes_.end(), node_with_index);
   if (iter == input_nodes_.end()) {
     branch_inputs_pos_[branch].push_back(input_nodes_.size());
-    input_nodes_.push_back(node);
+    input_nodes_.push_back(node_with_index);
     ++input_datas_num_;
   } else {
     branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin());
   }
 }
 
-size_t SwitchActor::GetIndex() {
+void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) {
+  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState) || HasAbstractMonad(node)) {
+    return;
+  }
+
+  const auto &real_input = AnfAlgo::VisitKernelWithReturnType(node, 0);
+
+  if (AnfAlgo::CheckPrimitiveType(real_input.first, prim::kPrimMakeTuple)) {
+    const auto &inputs = real_input.first->cast<CNodePtr>()->inputs();
+    for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) {
+      AddInput(inputs[i], branch);
+    }
+  } else if (IsCallNode(real_input.first)) {
+    std::vector<AnfNodePtr> call_nodes;
+    const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes);
+    if (call_output_num <= 0) {
+      MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first);
+    }
+    for (size_t i = 0; i < call_output_num; ++i) {
+      AddInput({real_input.first, i}, branch);
+    }
+  } else {
+    AddInput(real_input, branch);
+  }
+}
+
+size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) {
+  if (need_branch_id_input_) {
+    if (input_branch_ids_.find(context->sequential_num_) == input_branch_ids_.end() ||
+        input_branch_ids_[context->sequential_num_].empty()) {
+      MS_LOG(EXCEPTION) << "Invalid branch id for actor:" << GetAID();
+    }
+    size_t branch_id = input_branch_ids_[context->sequential_num_].top();
+    input_branch_ids_[context->sequential_num_].pop();
+    if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) {
+      MS_LOG(EXCEPTION) << "Invalid branch id for switch actor:" << GetAID() << " branch id:" << branch_id;
+    }
+    return branch_id_to_index_[branch_id];
+  }
+
+  DeviceTensor *device_tensor = input_device_tensors_[0];
+  if (device_tensor == nullptr) {
+    MS_LOG(EXCEPTION) << "Index of switch actor is empty:" << GetAID();
+  }
   auto inputs = node_->inputs();
   TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0);
   size_t size = abstract::TypeIdSize(type_id);
@@ -219,28 +305,46 @@ size_t SwitchActor::GetIndex() {
 bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
   MS_EXCEPTION_IF_NULL(context);
   if (input_datas_num_ != 0) {
-    auto data_iter = input_op_datas_.find(context->sequential_num_);
-    if (data_iter == input_op_datas_.end()) {
+    auto data_iter = input_data_.find(context->sequential_num_);
+    if (data_iter == input_data_.end()) {
       return false;
     }
     if (data_iter->second.size() != input_datas_num_) {
       return false;
     }
+    if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
+                    [](const auto &input_stack) { return input_stack.second.empty(); })) {
+      return false;
+    }
   }
+
+  if (input_controls_num_ != 0) {
+    auto data_iter = input_controls_.find(context->sequential_num_);
+    if (data_iter == input_controls_.end()) {
+      return false;
+    }
+    if (data_iter->second.size() != input_controls_num_) {
+      return false;
+    }
+    if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
+                    [](const auto &input_stack) { return input_stack.second == 0; })) {
+      return false;
+    }
+  }
+
   return true;
 }
 
 void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
   input_device_tensors_.resize(input_nodes_.size());
-  auto data_iter = input_op_datas_.find(context->sequential_num_);
-  if (data_iter != input_op_datas_.end()) {
+  auto data_iter = input_data_.find(context->sequential_num_);
+  if (data_iter != input_data_.end()) {
     for (auto &input_data : data_iter->second) {
-      MS_EXCEPTION_IF_NULL(input_data);
-      input_device_tensors_[input_data->index_] = input_data->data_;
+      input_device_tensors_[input_data.first] = input_data.second.top();
+      input_data.second.pop();
     }
   }
-  data_iter->second.clear();
 
   for (const auto &device_tensor_store_key : device_tensor_store_keys_) {
     auto device_tensor =
@@ -253,15 +357,28 @@ void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
     }
     input_device_tensors_[device_tensor_store_key.first] = device_tensor;
   }
+
+  auto control_iter = input_controls_.find(context->sequential_num_);
+  if (control_iter != input_controls_.end()) {
+    for_each(control_iter->second.begin(), control_iter->second.end(),
+             [](auto &input_control) { input_control.second--; });
+  }
 }
 
 void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
-  auto index = GetIndex();
+  auto index = GetIndex(context);
   if (index >= output_branch_arrows_.size()) {
     MS_LOG(EXCEPTION) << "Switch actor invalid index:" << index;
   }
 
+  if (local_branch_id_ >= 0) {
+    const auto &branch_arrows = output_branch_branch_arrows_[index];
+    for (const auto &branch_arrow : branch_arrows) {
+      Async(branch_arrow, &GatherActor::CollectBranchId, local_branch_id_, context);
+    }
+  }
+
   auto &output_branch_arrow = output_branch_arrows_[index];
   auto &output_data = output_data_[index];
   for (size_t i = 0; i < output_branch_arrow.size(); ++i) {
     auto &data_arrow = output_branch_arrow[i];
     auto &data = output_data[i];
     MS_EXCEPTION_IF_NULL(data_arrow);
     MS_EXCEPTION_IF_NULL(data);
     data->data_ = input_device_tensors_[data_arrow->from_output_index_];
+
     Async(data_arrow->to_op_id_, &OpActor::RunOpData, data.get(), context);
   }
@@ -279,7 +397,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
     auto &result_arrow = output_branch_result_arrow[i];
     MS_EXCEPTION_IF_NULL(result_arrow);
     size_t from_index = result_arrow->from_output_index_;
-    for (const auto &backend_node : front_to_backend_parameter_[from_index]) {
+    for (const auto &backend_node : backend_parameters_[from_index]) {
       if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second).get() ==
           input_device_tensors_[from_index]) {
         Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, backend_node.second,
@@ -298,8 +416,10 @@ void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
   MS_EXCEPTION_IF_NULL(context);
-  if (input_datas_num_ != 0) {
-    auto ret = input_op_datas_.erase(context->sequential_num_);
+  auto data_iter = input_data_.find(context->sequential_num_);
+  if (data_iter != input_data_.end() && std::all_of(data_iter->second.begin(), data_iter->second.end(),
+                                                    [](const auto &input_data) { return input_data.second.empty(); })) {
+    auto ret = input_data_.erase(context->sequential_num_);
     if (ret == 0) {
       std::string error_info = "Erase input data failed: " + GetAID().Name();
       SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
@@ -307,10 +427,15 @@ void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
   }
 
   if (input_controls_num_ != 0) {
-    auto ret = input_op_controls_.erase(context->sequential_num_);
-    if (ret == 0) {
-      std::string error_info = "Erase input controls failed: " + GetAID().Name();
-      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+    auto control_iter = input_controls_.find(context->sequential_num_);
+    if (control_iter != input_controls_.end() &&
+        std::all_of(control_iter->second.begin(), control_iter->second.end(),
+                    [](const auto &input_control) { return input_control.second == 0; })) {
+      auto ret = input_controls_.erase(context->sequential_num_);
+      if (ret == 0) {
+        std::string error_info = "Erase input control failed: " + GetAID().Name();
+        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
+      }
+    }
   }
 }
@@ -319,37 +444,20 @@ void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
   Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
 }
 
-void SwitchActor::FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
-                                 const FrontToBackendNodeWithContext &front_to_backend_parameters,
-                                 const std::unordered_map<AnfNodePtr, KernelWithIndex> &front_to_backend_kernel) {
-  front_to_backend_parameter_.resize(input_nodes_.size());
-
+void SwitchActor::FetchInputNode(const ControlNodeParserPtr &parser) {
   for (size_t i = 0; i < input_nodes_.size(); ++i) {
-    const auto &input_node = input_nodes_[i];
-    if (input_node->isa<ValueNode>()) {
-      front_to_backend_parameter_[i].push_back({input_node, 0});
-    } else if (input_node->isa<Parameter>()) {
-      if (front_to_backend_parameters.find(input_node) != front_to_backend_parameters.end()) {
-        const auto backend_node = front_to_backend_parameters.at(input_node).first;
-        front_to_backend_parameter_[i].push_back({backend_node, 0});
-      }
-    } else if (input_node->isa<CNode>()) {
-      if (IsCallNode(input_node)) {
-        const auto func_graphs = FetchFuncGraphbyCallNode(input_node->cast<CNodePtr>());
-        for (const auto func_graph : func_graphs) {
-          if (func_graph->output()->isa<ValueNode>()) {
-            front_to_backend_parameter_[i].push_back({func_graph->output(), 0});
-          }
-        }
-      } else {
-        const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
-        if (front_to_backend_kernel.find(input_node) != front_to_backend_kernel.end()) {
-          front_to_backend_parameter_[i].emplace_back(kernel_with_index);
-        }
-      }
+    const auto &input_node = input_nodes_[i].first;
+    if (!HasAbstractRef(input_node)) {
+      backend_parameters_[i] = parser->FetchBackendInputNodeByFrontNode(input_node);
+      continue;
     }
+
+    const auto &backend_weight = parser->FetchBackendNodebyWeightNode(input_node);
+    if (backend_weight == nullptr) {
+      MS_LOG(EXCEPTION) << "Cannot find backend node for weight node:" << AnfAlgo::GetNodeDebugString(input_node);
+    }
+    backend_parameters_[i].push_back({backend_weight, 0});
   }
 }
-
 }  // namespace runtime
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/framework/actor/switch_actor.h b/mindspore/ccsrc/runtime/framework/actor/switch_actor.h
index 046a0c046fc..99a9ce7cfd7 100644
--- a/mindspore/ccsrc/runtime/framework/actor/switch_actor.h
+++ b/mindspore/ccsrc/runtime/framework/actor/switch_actor.h
@@ -58,16 +58,25 @@ constexpr size_t kMakeTupleInputStartPos = 1;
 //  5. Free Memory
 class SwitchActor : public SwitchActorBase<DeviceTensor> {
  public:
-  SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node)
-      : SwitchActorBase(name), device_context_(device_context), node_(node) {}
+  SwitchActor(const std::string &name, DeviceContext *device_context, const CNodePtr &node, const int branch_id,
+              const bool need_branch_id_input)
+      : SwitchActorBase(name),
+        device_context_(device_context),
+        node_(node),
+        local_branch_id_(branch_id),
+        need_branch_id_input_(need_branch_id_input) {}
   ~SwitchActor() override = default;
 
   void Init() override;
 
   // The switch actor run when receive the input data.
   void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context);
+  // The switch actor runs when receiving the input control.
+  void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
+  // The switch actor runs when receiving the input branch id.
+  void CollectBranchId(const int branch_id, OpContext<DeviceTensor> *context);
 
   // Initialize the input and output information of the switch actor According to node_.
-  void Initialize();
+  void Initialize(const ControlNodeParserPtr &parser);
 
   // Add input for all branches.
   void AddCommonInput(const AnfNodePtr &node);
   // Fetch the input position of the data node.
@@ -79,11 +88,16 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
   void InitPartial(const AnfNodePtr &node, const size_t branch_id);
   void InitSwitch();
   void InitSwitchLayer();
-
+  // In control flow, the output of each subgraph is connected to a switch actor, and the switch actor is
+  // initialized with the return node of the subgraph.
+  void InitReturn(const ControlNodeParserPtr &parser);
+  // Initialize the size of the vector members.
+  void InitVectorSize(const size_t num);
   // Get index from DeviceTensor.
-  size_t GetIndex();
+  size_t GetIndex(OpContext<DeviceTensor> *context);
   // Add input for the branch.
   void AddInput(const AnfNodePtr &node, size_t branch);
+  void AddInput(const KernelWithIndex node_with_index, const size_t branch);
 
   // Check whether satisfy the condition for send outputs.
   bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
@@ -95,12 +109,10 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
   void SendMemoryFreeReq(OpContext<DeviceTensor> *context);
 
   // Collect all the backend inputs of switch actor.
-  void FetchInputNode(const std::vector<AnfNodePtr> &origin_parameters_order,
-                      const FrontToBackendNodeWithContext &front_to_backend_parameters,
-                      const std::unordered_map<AnfNodePtr, KernelWithIndex> &front_to_backend_kernel);
-  // All inputs of the switch actor, excluding weight and tensor.
+  void FetchInputNode(const ControlNodeParserPtr &parser);
+  // All inputs of the switch actor, including weight and tensor.
   // Used to receive input data, the first input is the condition of switch.
-  std::vector<AnfNodePtr> input_nodes_;
+  std::vector<KernelWithIndex> input_nodes_;
 
   // The position of the branch output in the input_nodes_.
   std::vector<std::vector<size_t>> branch_inputs_pos_;
@@ -126,9 +138,9 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
   // When the output is a value node from switch actor, the actor needs to send the anfnode to the output actor,
   // so all the nodes that may send the device tensor to switch actor are recorded.
-  std::vector<std::vector<KernelWithIndex>> front_to_backend_parameter_;
   std::vector<std::vector<KernelWithIndex>> backend_parameters_;
   std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
+
   std::vector<FuncGraphPtr> branch_func_graph_;
 
   std::unordered_map<int, size_t> branch_id_to_index_;
@@ -148,8 +160,12 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
   // The dependent input controls number.
   size_t input_controls_num_{0};
   CNodePtr node_;
+
+  // The branch id corresponding to the funcgraph to which the switch actor belongs.
   int local_branch_id_;
-  size_t input_branch_id_num_;
+  // Whether it needs to accept the branch id. When the switch actor is the output of the subgraph, it needs to
+  // receive the branch id sent by the gather actor of the subgraph, which will be true at this time.
+  bool need_branch_id_input_;
 
   // The output_data_ corresponds to the output_data_arrows_ one by one.
   std::vector<std::vector<std::unique_ptr<OpData<DeviceTensor>>>> output_data_;
diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc
index b0690b840d0..05628c13fbf 100644
--- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc
+++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc
@@ -547,6 +547,8 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node) {
 void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs,
                               const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph) {
+  root_func_graph_ = root_graph;
+
   root_graph_parameters_ = root_graph->parameters();
 
   CreateBranchIDForFuncGraph(control_nodes);
@@ -598,6 +600,27 @@ bool ControlNodeParser::IsCallInputKernelGraph(const KernelGraphPtr &graph) {
   return true;
 }
 
+bool ControlNodeParser::IsKernelInRootFuncGraph(const AnfNodePtr &kernel) {
+  if (kernel == nullptr) {
+    return true;
+  }
+
+  const auto &graph = kernel->func_graph();
+  if (kernel != nullptr && graph != nullptr) {
+    const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
+    if (kernel_graph == nullptr) {
+      return true;
+    }
+
+    const auto func_graph = kernel_graph->GetFuncGraph();
+    if (func_graph != nullptr && func_graph != root_func_graph_) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 size_t ControlNodeParser::GetCallNumByFuncGraph(const FuncGraphPtr &func_graph) {
   if (func_graph_to_call_num_.find(func_graph) == func_graph_to_call_num_.end()) {
     MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString();
@@ -622,6 +645,21 @@ DeviceContext *ControlNodeParser::GetFrontValueNodeDeviceContext(const AnfNodePt
   return nullptr;
 }
 
+AnfNodePtr ControlNodeParser::FetchBackendNodebyWeightNode(const AnfNodePtr &node) {
+  for (const auto &host_parameter_to_weight : host_parameter_to_weights_) {
+    for (const auto &front_weight : host_parameter_to_weight.second) {
+      if (front_weight == node) {
+        const auto &iter = front_to_backend_parameters_.find(front_weight);
+        if (iter != front_to_backend_parameters_.end()) {
+          return iter->second.first;
+        }
+      }
+    }
+  }
+
+  return nullptr;
+}
+
 void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node,
                                                    std::vector<AnfNodePtr> *value_nodes) {
   const auto &cnode = switch_node->cast<CNodePtr>();
@@ -928,6 +966,40 @@ std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node,
   return parameters;
 }
 
+std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph) {
+  std::vector<AnfNodePtr> parameters;
+  const auto &graph_parameters = graph->input_nodes();
+
+  for (const auto &graph_parameter : graph_parameters) {
+    const auto &front_node = graph->GetFrontAnfByBackendAnf(graph_parameter);
+    if (front_node != nullptr) {
+      parameters.emplace_back(front_node);
+      continue;
+    }
+
+    const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(graph_parameter);
+    if (front_node_with_index.first == nullptr) {
+      MS_LOG(WARNING) << "Invalid parameter of kernel graph, parameter :"
+                      << AnfAlgo::GetNodeDebugString(graph_parameter);
+      continue;
+    }
+
+    if (HasAbstractRef(AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, 0).first) ||
+        HasAbstractMonad(front_node_with_index.first)) {
+      continue;
+    }
+
+    if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimMakeTuple)) {
+      const auto &sub_parameters = FetchInputsByMakeTuple(front_node_with_index.first);
+      parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
+      continue;
+    }
+
+    parameters.emplace_back(front_node_with_index.first);
+  }
+
+  return parameters;
+}
+
 void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs,
                                                      const std::vector<DeviceContext *> &device_contexts,
                                                      const std::vector<AnfNodePtr> &control_nodes) {
@@ -939,7 +1011,7 @@
-    for (const auto &parameter : graph->parameters()) {
+    for (const auto &parameter : graph->input_nodes()) {
       auto front_node = graph->GetFrontAnfByBackendAnf(parameter);
 
       if (front_node != nullptr && front_node->isa<Parameter>() &&
diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h
index 900f4218d59..5f824c132a8 100644
--- a/mindspore/ccsrc/runtime/framework/control_node_parser.h
+++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h
@@ -71,8 +71,8 @@ FuncGraphPtr GetFuncgraphByBackendNode(const AnfNodePtr &backend_node);
 // Find all funcgraphs that the call node will call.
 std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node);
 
-// Recursive interface, get all input of make tuple node.
-std::vector<AnfNodePtr> FetchInputsByMakeTuple(const AnfNodePtr &node);
+// Get parameters in kernel graph.
+std::vector<AnfNodePtr> FetchParameterbyKernelGraph(const KernelGraphPtr &graph);
 
 // ControlNodeParser is used to parse control nodes, and get the edges between nodes.
 class ControlNodeParser {
@@ -107,6 +107,15 @@ class ControlNodeParser {
   // Check whether there is a call node in the front input nodes of the kernel graph.
   bool IsCallInputKernelGraph(const KernelGraphPtr &graph);
 
+  // Check whether the kernel actor belongs to the root graph.
+  // In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output should
+  // be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be sent to the
+  // output switch actor.
+  bool IsKernelInRootFuncGraph(const AnfNodePtr &kernel);
+
+  // Get the backend node corresponding to the weight node in the subgraph.
+  AnfNodePtr FetchBackendNodebyWeightNode(const AnfNodePtr &node);
+
  private:
   friend class GraphScheduler;
@@ -195,10 +204,12 @@ class ControlNodeParser {
   std::vector<AnfNodePtr> control_node_parameters_;
   // The number of calls to func_graph.
   std::unordered_map<FuncGraphPtr, size_t> func_graph_to_call_num_;
-  // The kernel graph of call exists in the front-end input node.
-  std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
+  // The kernel graph of call exists in the front input node.
   // In the scene of funcgraph recursive call, general input and call input are passed recursively, so a gather actor
   // is created for kernel graph which has a call input.
+  std::unordered_map<KernelGraphPtr, DeviceContext *> call_input_kernel_graphs_;
+  // Root funcgraph and its parameters.
+  FuncGraphPtr root_func_graph_;
   std::vector<AnfNodePtr> root_graph_parameters_;
 };
diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
index a1ebf1e3a3e..0a6fe2ca7c6 100644
--- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
+++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
@@ -762,7 +762,8 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
   actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set, graph_compiler_info.strategy_);
 
   // Link the control arrows of loop count actor, which depends on the no input kernel actors.
-  LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set);
+  LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
+                                    graph_compiler_info.control_node_parser_);
 
   // Link the output result arrows for output actors.
   LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
@@ -1000,28 +1001,58 @@ std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompiler
     front_to_backend_kernel[pair.first] = pair.second->kernel_;
   }
 
+  // Build switch actors by switch nodes and switchlayer nodes.
   for (const auto &control_node : graph_compiler_info.control_nodes_) {
     if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
         AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
-      auto actor_name = control_node->fullname_with_scope();
+      const auto func_graph = control_node->func_graph();
+      const auto branch_id = graph_compiler_info.control_node_parser_->GetBranchIDByFuncGraph(func_graph);
+      const auto &actor_name = control_node->DebugString();
       auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
-                                                        control_node->cast<CNodePtr>());
-      switch_actor->Initialize();
+                                                        control_node->cast<CNodePtr>(), branch_id, false);
+      switch_actor->Initialize(graph_compiler_info.control_node_parser_);
 
       // Fetch all the input nodes of switch actor.
-      switch_actor->FetchInputNode(graph_compiler_info.origin_parameters_order_,
-                                   graph_compiler_info.control_node_parser_->front_to_backend_parameters_,
-                                   front_to_backend_kernel);
+      switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
       InsertActor(switch_actor.get());
       switch_actors.emplace_back(switch_actor);
     }
   }
+
+  // Build switch actors by return nodes.
+  const auto func_graphs_to_call_num = graph_compiler_info.control_node_parser_->func_graph_to_call_num_;
+  for (const auto &func_graph_to_call_num : func_graphs_to_call_num) {
+    const auto &return_node = func_graph_to_call_num.first->get_return();
+    MS_EXCEPTION_IF_NULL(return_node);
+    const auto &actor_name = return_node->DebugString();
+    auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
+                                                      return_node->cast<CNodePtr>(), kInvalidBranchID, true);
+    switch_actor->Initialize(graph_compiler_info.control_node_parser_);
+
+    // Fetch all the input nodes of switch actor.
+    switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
+    InsertActor(switch_actor.get());
+    switch_actors.emplace_back(switch_actor);
+  }
+
   return switch_actors;
 }
 
 std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
   std::vector<GatherActorPtr> gather_actors;
+  const auto &loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
+  const auto &loop_count_actor = FetchActor(loop_count_actor_name);
+  if (loop_count_actor == nullptr) {
+    return gather_actors;
+  }
+
+  const auto &output_actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
+  const auto &output_actor = FetchActor(output_actor_name);
+  MS_EXCEPTION_IF_NULL(output_actor);
+
+  const auto parser = graph_compiler_info.control_node_parser_;
+
   bool is_main_return = true;
   // Each funcgraph has a return node, get the funcgraph from the return node, and create a gather actor.
   std::unordered_map<AnfNodePtr, CNodePtr> front_to_backend_kernel;
 
   for (const auto &control_node : graph_compiler_info.control_nodes_) {
-    // Root funcgraph does not need to create a gather actor.
+    const auto &func_graph = control_node->func_graph();
+    const auto &cnode = control_node->cast<CNodePtr>();
+    const auto &inputs = cnode->inputs();
+    const auto &return_node = func_graph->get_return();
+    const auto &output_switch_aid = FetchActor(return_node->DebugString())->GetAID();
+
+    if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
+      // Root funcgraph does not need to create a gather actor.
       if (is_main_return) {
         is_main_return = false;
         continue;
       }
 
-      const auto &cnode = control_node->cast<CNodePtr>();
-      const auto inputs = cnode->inputs();
       // If the output of funcgraph is a value node, no need to create gather actor.
       if (inputs[kReturnInputPos]->isa<ValueNode>()) {
         continue;
       }
 
-      auto func_graph = control_node->func_graph();
       auto actor_name = func_graph->ToString();
       std::vector<AnfNodePtr> parameters;
       for (const auto &parameter : func_graph->get_inputs()) {
-        if (!HasAbstractMonad(parameter)) {
-          parameters.emplace_back(parameter);
+        if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
+          continue;
         }
+        parameters.emplace_back(parameter);
       }
 
-      const auto &loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
-      const auto &loop_count_actor = FetchActor(loop_count_actor_name);
-      MS_EXCEPTION_IF_NULL(loop_count_actor);
-      const auto &output_actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
-      const auto &output_actor = FetchActor(output_actor_name);
-      MS_EXCEPTION_IF_NULL(output_actor);
+      const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
 
       auto gather_actor =
-        std::make_shared<GatherActor>(actor_name, parameters, loop_count_actor->GetAID(), output_actor->GetAID());
+        std::make_shared<GatherActor>(actor_name, parameters, true, output_switch_aid, AID(), branch_id);
       gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
 
       InsertActor(gather_actor.get());
       gather_actors.emplace_back(gather_actor);
     }
   }
 
+  // Create gather actors for call nodes whose input0 is a funcgraph.
+  for (const auto &control_node : graph_compiler_info.control_nodes_) {
+    const auto &cnode = control_node->cast<CNodePtr>();
+    const auto &inputs = cnode->inputs();
+
+    if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
+      // Collect the parameters.
+      std::vector<AnfNodePtr> parameters;
+      for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
+        if (HasAbstractMonad(inputs[i]) || HasAbstractRef(inputs[i])) {
+          continue;
+        }
+        parameters.emplace_back(inputs[i]);
+      }
+
+      auto func_graph = control_node->func_graph();
+      auto actor_name = control_node->DebugString();
+      const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
+      const auto &to_func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
+      const auto &to_actor = FetchActor(to_func_graph->ToString());
+      auto gather_actor =
+        std::make_shared<GatherActor>(actor_name, parameters, false, AID(), to_actor->GetAID(), branch_id);
+      gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
+
+      InsertActor(gather_actor.get());
+      gather_actors.emplace_back(gather_actor);
+    }
+  }
+
+  // Create gather actors for kernel graphs which have a call input.
+ // Create a gather actor for each kernel graph that has a call input. + const auto &graph_with_device_contexts = graph_compiler_info.control_node_parser_->call_input_kernel_graphs_; + for (const auto &graph_with_device_context : graph_with_device_contexts) { + const auto &graph = graph_with_device_context.first; + const auto &parameters = FetchParameterbyKernelGraph(graph); + + auto actor_name = graph->ToString(); + auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, false, AID(), AID(), kInvalidBranchID); + InsertActor(gather_actor.get()); + gather_actors.emplace_back(gather_actor); + } + return gather_actors; } @@ -1079,6 +1150,18 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf auto from_kernel = from_kernel_with_output_idx.first; auto front_node = GetFrontNodeByBackendNode(from_kernel); + + if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) { + // When there is a call input in the kernel graph, all the inputs of the kernel graph need to be sent by the gather actor. + const auto actor_name = graph->ToString(); + auto actor = FetchActor(actor_name); + MS_EXCEPTION_IF_NULL(actor); + const auto &real_front_node = GetFrontNodeByKernelGraph(from_kernel, graph); + LinkDataArrowForGatherActor(dynamic_cast<GatherActor *>(actor), real_front_node, to_actor, + to_kernel_with_input_idx.second); + return; + } + if (IsDeviceQueueDSActor(from_kernel)) { // Link the data arrows of device queue data source actor. std::string actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id()); @@ -1092,7 +1175,12 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf } auto actor_name = func_graph->ToString(); const auto &from_actor = dynamic_cast<GatherActor *>(FetchActor(actor_name)); - LinkDataArrowForGatherActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx); + if (HasAbstractRef(from_kernel)) { + to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node.get()); + return; + } + + LinkDataArrowForGatherActor(from_actor, front_node, to_actor, to_kernel_with_input_idx.second); } else if (IsHostQueueDSActor(from_kernel, graph, tensor, graph_compiler_info.origin_parameters_order_, graph_compiler_info.strategy_)) { // Link the data arrows of host queue data source actor.
@@ -1135,9 +1223,10 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_output_node.get()); return; } + if (graph_output_to_actor_.count(front_output_with_index) == 0 && (!IsSwitchActor(front_output_node))) { - MS_LOG(EXCEPTION) << "Can't find actor by front node:" << front_output_node->fullname_with_scope() - << ", internal parameter:" << internal_parameter->fullname_with_scope(); + MS_LOG(EXCEPTION) << "Can't find actor by front node:" << AnfAlgo::GetNodeDebugString(front_output_node) + << ", internal parameter:" << AnfAlgo::GetNodeDebugString(internal_parameter); } auto actor_pair = graph_output_to_actor_[front_output_with_index]; MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString() @@ -1151,11 +1240,12 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second); LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx); } else if (IsSwitchActor(front_output_node)) { - const auto &actor_name = front_output_node->fullname_with_scope(); + const auto &actor_name = front_output_node->DebugString(); const auto &actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto switch_actor = dynamic_cast<SwitchActor *>(actor); - LinkDataArrowForSwitchActor(switch_actor, to_actor, to_kernel_with_input_idx.second); + LinkDataArrowForSwitchActor(switch_actor, 0, to_actor, to_kernel_with_input_idx.second); + to_actor->input_datas_num_++; } else if (IsKernelActor(front_output_node)) { auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first); auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second); @@ -1494,7 +1584,8 @@ void GraphScheduler::LinkControlArrowByCommunicationNode(const KernelGraphPtr &g } } -void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set) { +void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set, + const ControlNodeParserPtr &parser) { MS_EXCEPTION_IF_NULL(actor_set); // There is no loop count actor in step mode. if (loop_count_actor == nullptr) { return; } // Collect the actors which have no output. std::vector no_output_actors; for (auto &kernel_actor : actor_set->kernel_actors_) { - if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) { + // The control arrows of no-output kernels in a subgraph need to be connected to the corresponding output switch actor. + if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) && + parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) { MS_EXCEPTION_IF_NULL(kernel_actor->kernel_); MS_LOG(INFO) << kernel_actor->kernel_->fullname_with_scope() << " is not real used by other nodes."; no_output_actors.emplace_back(kernel_actor.get());
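The extra IsKernelInRootFuncGraph condition above splits no-output kernels into two groups: only root-graph kernels keep reporting to the loop count actor, while subgraph kernels are wired to their funcgraph's output switch actor in LinkControlArrowForGatherActor later in this diff. A toy decision table of that routing rule, with invented names:

  #include <iostream>
  #include <string>

  // Toy model of the control-arrow routing for kernels without consumers.
  // The function name and string results are illustrative only.
  std::string ControlArrowTarget(bool has_any_output, bool in_root_func_graph) {
    if (has_any_output) {
      return "none";  // a consumed kernel needs no extra control arrow
    }
    return in_root_func_graph ? "loop_count_actor" : "output_switch_actor";
  }

  int main() {
    std::cout << ControlArrowTarget(false, true) << "\n";   // loop_count_actor
    std::cout << ControlArrowTarget(false, false) << "\n";  // output_switch_actor
  }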
@@ -1523,7 +1616,7 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun // No output actor --> loop count actor. for (auto &no_output_actor : no_output_actors) { no_output_actor->output_control_arrows_.emplace_back(loop_count_actor->GetAID()); - loop_count_actor->branch_id_to_input_controls_num_[kMainBranchID]++; + loop_count_actor->input_controls_num_++; } // Loop count actor --> data source actor. @@ -1553,7 +1646,7 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor, MS_EXCEPTION_IF_NULL(graph); ++number; auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output()); - std::set<std::pair<int, std::vector<size_t>>> unique_output_positions; + std::set<std::vector<size_t>> unique_output_positions; std::set<KernelWithIndex> unique_outputs; for (const auto &output : outputs) { unique_outputs.insert(output); } @@ -1571,15 +1664,14 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor, continue; } unique_output_positions.insert(iter->second); - for (auto &output_position : iter->second.second) { + for (auto &output_position : iter->second) { to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[number - 1]; // The device tensor of graph out need be taken over by host tensor, so set the max reference count. UpdateRefCount(output_with_index.first, output_with_index.second, true); // The graph output is from device tensor store. if (IsPersistentDeviceTensor(output_with_index.first)) { - to_actor->device_tensor_store_keys_[iter->second.first].emplace_back(output_position, - output_with_index.first); + to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first); continue; } @@ -1631,11 +1723,45 @@ void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo const ActorSet *actor_set) { const auto &to_actor = actor_set->output_actor_; const auto &loop_count_actor = actor_set->loop_count_actor_; - const auto &switch_actors = actor_set->switch_actors_; if (to_actor == nullptr || loop_count_actor == nullptr) { return; } + // When there is a call node in the output, the output will be sent to the output actor by the switch actor of + // the funcgraph called by the call node. + const auto &outputs = graph_compiler_info.origin_outputs_order_; + for (const auto &output : outputs) { + const auto &output_node = output.first.first; + const auto &output_index = output.first.second; + const auto output_poses = output.second; + + if (IsCallNode(output_node)) { + const auto &func_graphs = FetchFuncGraphbyCallNode(output_node); + for (const auto func_graph : func_graphs) { + const auto &actor_name = func_graph->get_return()->DebugString(); + auto actor = FetchActor(actor_name); + MS_EXCEPTION_IF_NULL(actor); + auto switch_actor = dynamic_cast<SwitchActor *>(actor); + + // Set branch index into switch actor. + size_t branch_index = switch_actor->branch_id_to_index_.size(); + if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) { + branch_index = switch_actor->branch_id_to_index_[kMainBranchID]; + } else { + switch_actor->branch_id_to_index_[kMainBranchID] = branch_index; + }
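The branch_id_to_index_ lookup above reuses an output slot when the branch id was already registered and appends a new slot otherwise. The same idiom, reduced to a standalone helper (SlotForBranch is an invented name):

  #include <cstddef>
  #include <iostream>
  #include <unordered_map>

  // One output slot per distinct branch id; repeated ids reuse their slot.
  size_t SlotForBranch(std::unordered_map<int, size_t> *branch_id_to_index, int branch_id) {
    const auto iter = branch_id_to_index->find(branch_id);
    if (iter != branch_id_to_index->end()) {
      return iter->second;
    }
    const size_t index = branch_id_to_index->size();
    (*branch_id_to_index)[branch_id] = index;
    return index;
  }

  int main() {
    std::unordered_map<int, size_t> table;
    std::cout << SlotForBranch(&table, 0) << SlotForBranch(&table, 5) << SlotForBranch(&table, 0) << "\n";  // 010
  }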
+ // Link output result arrow. + for (const auto output_pos : output_poses) { + auto op_arrow = std::make_shared<DataArrow>(output_index, to_actor->GetAID(), output_pos); + to_actor->device_contexts_[output_pos] = switch_actor->device_context_; + switch_actor->output_branch_result_arrows_[branch_index].emplace_back(op_arrow); + } + } + } + } + + const auto &switch_actors = actor_set->switch_actors_; for (const auto &from_actor : switch_actors) { MS_EXCEPTION_IF_NULL(from_actor); auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0); @@ -1646,7 +1772,7 @@ void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo // If the switch actor is in the output list, the output of switch actor should be sent to the output actor. // And need to link a control arrow to the loop count actor. - for (const auto pos : iter->second.second) { + for (const auto pos : iter->second) { to_actor->device_contexts_[pos] = from_actor->device_context_; } @@ -1656,14 +1782,14 @@ void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo MS_LOG(EXCEPTION) << "Invalid input num in switch actor:" << from_actor->GetAID(); } - for (const auto pos : iter->second.second) { + for (const auto pos : iter->second) { auto op_arrow = std::make_shared<DataArrow>(input_pos[0], to_actor->GetAID(), pos); from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow); } from_actor->output_branch_control_arrows_[i].emplace_back(loop_count_actor->GetAID()); } - loop_count_actor->branch_id_to_input_controls_num_[kMainBranchID]++; + loop_count_actor->input_controls_num_++; } } @@ -1709,7 +1835,7 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector auto real_to_actor = dynamic_cast<LoopCountActor *>(to_actor); MS_EXCEPTION_IF_NULL(real_to_actor); - real_to_actor->branch_id_to_input_controls_num_[kMainBranchID]++; + real_to_actor->input_controls_num_++; } else if (output_contorl.Name().find("copy_from") != string::npos) { auto real_to_actor = dynamic_cast<CopyActor *>(to_actor); MS_EXCEPTION_IF_NULL(real_to_actor); @@ -1730,188 +1856,247 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set) { + for (const auto &node : graph_compiler_info.control_nodes_) { + CNodePtr cnode = node->cast<CNodePtr>(); + const auto &from_func_graph = node->func_graph(); + auto inputs = cnode->inputs(); + + // Before linking data arrows, parameters of the call node in switch-call need to be added to the switch actor. + if (inputs[0]->isa<CNode>()) { + auto actor = FetchActor(inputs[0]->DebugString()); + MS_EXCEPTION_IF_NULL(actor); + auto switch_actor = dynamic_cast<SwitchActor *>(actor); + for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { + if (HasAbstractMonad(inputs[i])) { + continue; + } + switch_actor->AddCommonInput(inputs[i]); + } + } + } + + for (const auto &node : graph_compiler_info.control_nodes_) { + CNodePtr cnode = node->cast<CNodePtr>(); + const auto &from_func_graph = node->func_graph(); auto inputs = cnode->inputs(); // Link data arrow for switch node. if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) { - auto actor = actor_name_to_actor_[node->fullname_with_scope()]; + auto actor = actor_name_to_actor_[node->DebugString()]; LinkDataArrowForSwitchActor(graph_compiler_info, dynamic_cast<SwitchActor *>(actor)); } else if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) { // Link the data arrow for the input of the call node.
- auto func_graph = GetValueNode<FuncGraphPtr>(inputs[0]); - auto actor = actor_name_to_actor_[func_graph->ToString()]; - for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { - LinkDataArrowByControlNode(graph_compiler_info, inputs[i], actor, i - kCallInputStartPos); - } - } else if (inputs[0]->isa<CNode>()) { - // Link switch inputs which is on the call node. - if ((!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch)) && - (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer))) { - MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:" - << AnfAlgo::GetNodeDebugString(inputs[0]); - } + const auto &actor_name = node->DebugString(); + auto actor = FetchActor(actor_name); + MS_EXCEPTION_IF_NULL(actor); + auto gather_actor = dynamic_cast<GatherActor *>(actor); - auto switch_op_actor = FetchActor(inputs[0]->fullname_with_scope()); - if (switch_op_actor == nullptr) { - MS_LOG(EXCEPTION) << "Cannot find actor of switch node:" << AnfAlgo::GetNodeDebugString(inputs[0]); - } - auto switch_actor = dynamic_cast<SwitchActor *>(switch_op_actor); + const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]); + MS_EXCEPTION_IF_NULL(func_graph); + const auto &to_actor = FetchActor(func_graph->ToString()); + MS_EXCEPTION_IF_NULL(to_actor); + + size_t persist_input_num = 0; for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { - switch_actor->AddCommonInput(inputs[i]); - auto pos = switch_actor->FetchDataNodePosition(inputs[i]); - LinkDataArrowByControlNode(graph_compiler_info, inputs[i], switch_actor, pos - kCallInputStartPos); + MS_EXCEPTION_IF_NULL(actor); + if (inputs[i]->isa<ValueNode>()) { + const auto &node_value = inputs[i]->cast<ValueNodePtr>()->value(); + if (!node_value->isa<tensor::Tensor>()) { + persist_input_num++; + continue; + } + + gather_actor->device_tensor_store_keys_.push_back( + {i - kCallInputStartPos - persist_input_num, inputs[i].get()}); + gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] = + graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]); + } else if ((inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]->cast<ParameterPtr>())) || + AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState) || HasAbstractMonad(inputs[i])) { + persist_input_num++; + continue; + } else { + const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(inputs[i], 0); + LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, actor, + i - kCallInputStartPos - persist_input_num); + } + + auto op_arrow = std::make_shared<DataArrow>(i - kCallInputStartPos - persist_input_num, to_actor->GetAID(), + i - kCallInputStartPos - persist_input_num); + gather_actor->output_data_arrows_.emplace_back(op_arrow); } } } + // Link arrows for the output switch actor of each subgraph. + for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) { + const auto &func_graph = func_graph_with_call_num.first; + MS_EXCEPTION_IF_NULL(func_graph); + auto actor = FetchActor(func_graph->get_return()->DebugString()); + MS_EXCEPTION_IF_NULL(actor); + LinkDataArrowForSwitchActor(graph_compiler_info, dynamic_cast<SwitchActor *>(actor)); + } + + // Link arrows for the gather actors of kernel graphs that have call inputs.
+ for (const auto &call_input_kernel_graph : graph_compiler_info.control_node_parser_->call_input_kernel_graphs_) { + const auto &kernel_graph = call_input_kernel_graph.first; + MS_EXCEPTION_IF_NULL(kernel_graph); + auto actor = FetchActor(kernel_graph->ToString()); + MS_EXCEPTION_IF_NULL(actor); + auto gather_actor = dynamic_cast<GatherActor *>(actor); + + for (const auto &input_node : gather_actor->data_nodes_) { + const auto &from_func_graph = kernel_graph->GetFuncGraph(); + const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0); + LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor, + gather_actor->FetchDataNodePosition(input_node)); + } + } + LinkBranchArrowForSwitchActor(graph_compiler_info, actor_set); + LinkBranchArrowForGatherActor(graph_compiler_info, actor_set); - LinkControlArrowForGatherActor(&(actor_set->gather_actors_), actor_set->loop_count_actor_.get(), - graph_compiler_info.graphs_); + LinkControlArrowForGatherActor(&(actor_set->gather_actors_), &(actor_set->kernel_actors_), + actor_set->loop_count_actor_.get(), graph_compiler_info.graphs_, + graph_compiler_info.control_node_parser_); - LinkOutputResultArrowForGatherActor(graph_compiler_info, actor_set); + LinkControlArrowForSwitchActor(&(actor_set->switch_actors_), actor_set->loop_count_actor_.get(), + graph_compiler_info.origin_outputs_order_); LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set); } -void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor, - KernelWithIndex from_kernel_with_output_idx, - KernelWithIndex to_kernel_with_input_idx) { +void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node, + KernelActor *to_actor, const size_t to_index) { MS_EXCEPTION_IF_NULL(from_actor); MS_EXCEPTION_IF_NULL(to_actor); - - auto from_kernel = from_kernel_with_output_idx.first; - MS_EXCEPTION_IF_NULL(from_kernel); - auto to_input_index = to_kernel_with_input_idx.second; - - auto front_node = GetFrontNodeByBackendNode(from_kernel); - if (front_node == nullptr) { - MS_LOG(EXCEPTION) << "Cannot find front node of node:" << AnfAlgo::GetNodeDebugString(from_kernel); - } + MS_EXCEPTION_IF_NULL(front_node); auto position = from_actor->FetchDataNodePosition(front_node); - auto to_aid = to_actor->GetAID(); - auto op_arrow = std::make_shared<DataArrow>(position, to_aid, to_input_index); + auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index); from_actor->output_data_arrows_.emplace_back(op_arrow); to_actor->input_datas_num_++; } -void GraphScheduler::LinkDataArrowByCallInput(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &call_node, +void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index, + const ControlNodeParserPtr &parser, const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *to_actor, const size_t to_index) { // Fetch all the funcgraph that call node would call. - const auto cnode = call_node->cast<CNodePtr>(); + const auto cnode = call_node_with_index.first->cast<CNodePtr>(); std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode); // Collect the output of each funcgraph. for (const auto &func_graph : func_graphs) { - // The output of funcgraph can only have one.
- auto outputs = AnfAlgo::GetAllOutputWithIndex(func_graph->output()); - if (outputs.size() != 1) { - MS_LOG(EXCEPTION) << "Output of func graph is more than one, func graph:" << func_graph->ToString(); - } - - auto output_with_index = outputs[0]; - if (IsKernelActor(output_with_index.first)) { - // Input is a kernel actor. - const auto &iter = front_node_to_actor_.find(output_with_index.first); - if (iter == front_node_to_actor_.end()) { - MS_LOG(EXCEPTION) << "Cannot find kernel actor of front node:" - << AnfAlgo::GetNodeDebugString(output_with_index.first); - } - auto from_actor = iter->second; - auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), to_index); - from_actor->output_data_arrows_.emplace_back(op_arrow); - auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, output_with_index.second, false); - UpdateRefCount(device_tensor.get(), true); - } else if (output_with_index.first->isa<Parameter>()) { - // Input is a parameter from gather actor. - const auto &actor_name = func_graph->ToString(); - auto actor = FetchActor(actor_name); - MS_EXCEPTION_IF_NULL(actor); - auto gather_actor = dynamic_cast<GatherActor *>(actor); - MS_EXCEPTION_IF_NULL(gather_actor); - - const auto &iter = - find(gather_actor->data_nodes_.begin(), gather_actor->data_nodes_.end(), output_with_index.first); - if (iter == gather_actor->data_nodes_.end()) { - MS_LOG(EXCEPTION) << "Cannot find parameter:" << AnfAlgo::GetNodeDebugString(output_with_index.first) - << " in funcgraph"; - } - const size_t pos = iter - gather_actor->data_nodes_.begin(); - auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_index); - gather_actor->output_data_arrows_.emplace_back(op_arrow); - } else if (output_with_index.first->isa<ValueNode>()) { - // If the output is a value node, then the value node needs to be sent by the switch actor. + if (func_graph->output()->isa<ValueNode>()) { const auto &call_inputs = cnode->inputs(); if (AnfAlgo::CheckPrimitiveType(call_inputs[0], prim::kPrimSwitch)) { - const auto &actor_name = call_inputs[0]->fullname_with_scope(); + const auto &actor_name = call_inputs[0]->DebugString(); const auto &actor = FetchActor(actor_name); MS_EXCEPTION_IF_NULL(actor); auto switch_actor = dynamic_cast<SwitchActor *>(actor); MS_EXCEPTION_IF_NULL(switch_actor); + const auto &output_with_index = KernelWithIndex(func_graph->output(), 0); + const auto &iter = + find(switch_actor->input_nodes_.begin(), switch_actor->input_nodes_.end(), output_with_index); + if (iter == switch_actor->input_nodes_.end()) { + MS_LOG(EXCEPTION) << "Invalid input node for switch actor:" << switch_actor->GetAID() + << " node:" << AnfAlgo::GetNodeDebugString(func_graph->output()); + } + size_t pos = iter - switch_actor->input_nodes_.begin(); // Add output for each branch of switch.
for (size_t i = 0; i < switch_actor->branch_inputs_pos_.size(); ++i) { - if (switch_actor->branch_inputs_pos_[i].empty()) { - MS_LOG(EXCEPTION) << "No input for switch actor:" << actor_name << " branch:" << i; + const auto poses = switch_actor->branch_inputs_pos_[i]; + if (find(poses.begin(), poses.end(), pos) == poses.end()) { + continue; } - const auto from_index = switch_actor->branch_inputs_pos_[i][0]; - auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index); + auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_index); switch_actor->output_branch_arrows_[i].emplace_back(op_arrow); } } else { - MS_LOG(EXCEPTION) << "Invalid input for call node:" << AnfAlgo::GetNodeDebugString(call_node); + MS_LOG(EXCEPTION) << "Invalid funcgraph:" << func_graph->ToString(); } - } else { - MS_LOG(EXCEPTION) << "Output of func graph is not a parameter or kernel, func graph:" << func_graph->ToString() - << " output node:" << AnfAlgo::GetNodeDebugString(output_with_index.first); + continue; } + + const auto actor_name = func_graph->get_return()->DebugString(); + auto actor = FetchActor(actor_name); + MS_EXCEPTION_IF_NULL(actor); + auto switch_actor = dynamic_cast<SwitchActor *>(actor); + MS_EXCEPTION_IF_NULL(switch_actor); + const size_t branch_index = switch_actor->branch_id_to_index_.size(); + + const auto &func_graph_to_branch_id = parser->func_graph_to_branch_id_; + const auto &iter = func_graph_to_branch_id.find(from_func_graph); + + int branch_id = kMainBranchID; + if (iter != func_graph_to_branch_id.end()) { + branch_id = iter->second; + } + if (switch_actor->branch_id_to_index_.find(branch_id) != switch_actor->branch_id_to_index_.end()) { + LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index, + switch_actor->branch_id_to_index_[branch_id]); + continue; + } + LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index, branch_index); + switch_actor->branch_id_to_index_[branch_id] = branch_index; } } -void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor, - const size_t to_index) { +void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index, + OpActor<DeviceTensor> *to_actor, const size_t to_index, + const size_t branch_index) { MS_EXCEPTION_IF_NULL(from_actor); - - for (size_t i = 0; i < from_actor->output_branch_arrows_.size(); ++i) { - if (from_actor->branch_inputs_pos_[i].empty()) { + MS_EXCEPTION_IF_NULL(to_actor); + size_t start_branch = 0; + size_t max_branch = from_actor->output_branch_arrows_.size(); + if (branch_index != SIZE_MAX) { + start_branch = branch_index; + max_branch = branch_index + 1; + } + for (size_t i = start_branch; i < max_branch; ++i) { + if (from_actor->branch_inputs_pos_[i].size() <= from_index) { MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i; } - const auto from_index = from_actor->branch_inputs_pos_[i][0]; - auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index); + auto op_arrow = + std::make_shared<DataArrow>(from_actor->branch_inputs_pos_[i][from_index], to_actor->GetAID(), to_index); from_actor->output_branch_arrows_[i].emplace_back(op_arrow); } - to_actor->input_datas_num_++; }
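LinkDataArrowForSwitchActor above uses SIZE_MAX as a sentinel: the default links every branch, while an explicit branch_index narrows the loop to a single branch. The range trick in isolation, as a runnable sketch:

  #include <cstddef>
  #include <cstdint>
  #include <iostream>

  // [start_branch, max_branch) selection as used above: SIZE_MAX means
  // "link every branch", any other value selects exactly one branch.
  void VisitSelectedBranches(size_t branch_count, size_t branch_index) {
    size_t start_branch = 0;
    size_t max_branch = branch_count;
    if (branch_index != SIZE_MAX) {
      start_branch = branch_index;
      max_branch = branch_index + 1;
    }
    for (size_t i = start_branch; i < max_branch; ++i) {
      std::cout << "link branch " << i << "\n";
    }
  }

  int main() {
    VisitSelectedBranches(3, SIZE_MAX);  // branches 0, 1, 2
    VisitSelectedBranches(3, 1);         // branch 1 only
  }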
void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, - const AnfNodePtr &input_node, OpActor<DeviceTensor> *to_actor, + const KernelWithIndex &input_with_index, + const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *to_actor, + const size_t to_index) { const auto &parameters = graph_compiler_info.origin_parameters_order_; const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_; + const auto &input_node = input_with_index.first; if (IsCallNode(input_node)) { // The actor input is a call node. - LinkDataArrowByCallInput(graph_compiler_info, input_node, to_actor, to_index); + LinkDataArrowByCallInput(input_with_index, graph_compiler_info.control_node_parser_, from_func_graph, to_actor, + to_index); } else if (IsGatherActor(input_node, actor_name_to_actor_)) { // The actor input is a parameter in gather actor. auto from_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[input_node->func_graph()->ToString()]); auto position = from_actor->FetchDataNodePosition(input_node); auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index); from_actor->output_data_arrows_.emplace_back(op_arrow); + } else if (IsSwitchActor(input_node)) { + const auto &actor_name = input_node->DebugString(); + auto actor = FetchActor(actor_name); + MS_EXCEPTION_IF_NULL(actor); + LinkDataArrowForSwitchActor(dynamic_cast<SwitchActor *>(actor), 0, to_actor, to_index); } else if (IsKernelActor(input_node)) { // The actor input is a cnode. - auto input_witch_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0); - if (front_node_to_actor_.find(input_witch_index.first) == front_node_to_actor_.end()) { - MS_LOG(EXCEPTION) << "Cannot find switch actor input_node:" << AnfAlgo::GetNodeDebugString(input_node); + if (front_node_to_actor_.find(input_node) == front_node_to_actor_.end()) { + MS_LOG(EXCEPTION) << "Cannot find actor:" << to_actor->GetAID() + << " input_node:" << AnfAlgo::GetNodeDebugString(input_node); } - auto op_arrow = std::make_shared<DataArrow>(input_witch_index.second, to_actor->GetAID(), to_index); - auto from_actor = front_node_to_actor_[input_witch_index.first]; + auto op_arrow = std::make_shared<DataArrow>(input_with_index.second, to_actor->GetAID(), to_index); + auto from_actor = front_node_to_actor_[input_node]; from_actor->output_data_arrows_.emplace_back(op_arrow); - auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, input_witch_index.second, false); + auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, input_with_index.second, false); UpdateRefCount(device_tensor.get(), true); } else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) { // The actor input is a parameter in host data source actor. @@ -1948,10 +2133,12 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_ const auto &inputs = actor->input_nodes_; for (size_t i = 0; i < inputs.size(); ++i) { auto input = inputs[i]; - if (input->isa<ValueNode>()) { + if (input.first->isa<ValueNode>() || HasAbstractRef(input.first)) { continue; } - LinkDataArrowByControlNode(graph_compiler_info, input, actor, i); + + const FuncGraphPtr from_func_graph = actor->node_->func_graph(); + LinkDataArrowByControlNode(graph_compiler_info, input, from_func_graph, actor, i); } // Link switch output.
@@ -1975,8 +2162,10 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_ } } -void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor, - const std::vector<KernelGraphPtr> &graphs) { +void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, + std::vector<KernelActorPtr> *kernel_actors, + LoopCountActor *to_actor, const std::vector<KernelGraphPtr> &graphs, + const ControlNodeParserPtr &parser) { if (from_actors == nullptr || to_actor == nullptr) { return; } @@ -2010,15 +2199,30 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector<GatherActorPtr> } } - // link control arrow to loop count actor. - for (auto &from_actor : *from_actors) { - MS_EXCEPTION_IF_NULL(from_actor); + for (auto &kernel_actor : *kernel_actors) { + MS_EXCEPTION_IF_NULL(kernel_actor); - // If the gather actor has no output, then adds the output control to loop count actor. - if (from_actor->output_data_arrows_.size() == 0 && from_actor->output_control_arrows_.size() == 0) { - auto to_aid = to_actor->GetAID(); - from_actor->output_control_arrows_.emplace_back(to_aid); - to_actor->branch_id_to_input_controls_num_[kMainBranchID]++; + if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) && + !parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) { + // Check whether the kernel actor belongs to the root funcgraph. In general, all no-output nodes belong to the + // root funcgraph, and the corresponding output switch actor should be empty. In control flow, the control + // arrow of a no-output node in a sub funcgraph should be sent to the output switch actor of that funcgraph. + const auto &graph = kernel_actor->kernel_->func_graph(); + OpActor<DeviceTensor> *actor = nullptr; + + if (graph != nullptr) { + const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get()); + const auto func_graph = kernel_graph->GetFuncGraph(); + if (func_graph != nullptr) { + actor = FetchActor(func_graph->get_return()->DebugString()); + if (actor != nullptr) { + auto switch_actor = dynamic_cast<SwitchActor *>(actor); + kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID()); + switch_actor->input_controls_num_++; + } + } + } } } } @@ -2030,6 +2234,8 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> return; } + // If there is no output from the switch actor branch, it means that the subgraph has no input, + // so a control arrow needs to be connected to the corresponding gather actor. for (auto &switch_actor : (*switch_actors)) { for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) { const auto &arrows = switch_actor->output_branch_arrows_[i]; @@ -2045,17 +2251,18 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> } } + // Collect all the call nodes in the outputs. std::set<AnfNodePtr> call_nodes; for (const auto &output : origin_outputs_order) { if (IsCallNode(output.first.first)) { call_nodes.insert(output.first.first); } } + to_actor->input_controls_num_ += call_nodes.size(); - to_actor->branch_id_to_input_controls_num_[kMainBranchID] += call_nodes.size(); -
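origin_outputs_order holds one entry per output position, so the same call node can appear several times; the std::set above collapses the duplicates so that each call node contributes exactly one control input. Reduced to a runnable fragment:

  #include <cstddef>
  #include <iostream>
  #include <set>
  #include <string>
  #include <utility>
  #include <vector>

  int main() {
    // (node name, output index) pairs standing in for origin_outputs_order.
    const std::vector<std::pair<std::string, size_t>> outputs = {{"call1", 0}, {"call1", 1}, {"call2", 0}};
    std::set<std::string> call_nodes;
    for (const auto &output : outputs) {
      call_nodes.insert(output.first);  // duplicates collapse here
    }
    size_t input_controls_num = 0;
    input_controls_num += call_nodes.size();
    std::cout << input_controls_num << "\n";  // 2, not 3
  }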
+ // Link the output switch actor of the subgraph to the output actor. for (const auto &call_node : call_nodes) { - const auto &func_graphs = FetchFuncGraphbyCallNode(call_node->cast<CNodePtr>()); + const auto &func_graphs = FetchFuncGraphbyCallNode(call_node); for (const auto func_graph : func_graphs) { MS_EXCEPTION_IF_NULL(func_graph); const auto &actor_name = func_graph->get_return()->DebugString(); @@ -2104,87 +2311,26 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap return; } - const auto func_graph = graph_compiler_info.control_nodes_[0]->func_graph(); - const auto &loop_count_actor = actor_set->loop_count_actor_.get(); - const auto &output_actor = actor_set->output_actor_.get(); - - // If there is only one branch output, set the branch id of the loop count to 0, no need to send the branch id. - auto outputs = graph_compiler_info.control_node_parser_->front_output_nodes_; - if (outputs.size() == 1) { - return; + // Link branch arrow from gather actor to gather actor. + for (const auto &control_node : graph_compiler_info.control_nodes_) { + const auto &cnode = control_node->cast<CNodePtr>(); + const auto &inputs = cnode->inputs(); + if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) { + const auto &actor_name = control_node->DebugString(); + auto actor = FetchActor(actor_name); + MS_EXCEPTION_IF_NULL(actor); + auto gather_actor = dynamic_cast<GatherActor *>(actor); + gather_actor->output_branch_arrows_.emplace_back(gather_actor->gather_aid_); + } } - loop_count_actor->branch_id_ = kInvalidBranchID; - output_actor->branch_id_ = kInvalidBranchID; - - std::vector<FuncGraphPtr> output_func_graphs; - for_each(outputs.begin(), outputs.end(), - [&output_func_graphs](const AnfNodePtr &output) { output_func_graphs.push_back(output->func_graph()); }); - int func_graph_num = SizeToInt(output_func_graphs.size()); - std::unordered_map<FuncGraphPtr, size_t> graph_to_control_num; - - // Count the control arrow num of gather actor. - for (int i = 0; i < func_graph_num; ++i) { - auto output_func_graph = output_func_graphs[i]; - auto actor_name = output_func_graph->ToString(); + // Link branch arrow from gather actor to switch actor. + for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) { + const auto &actor_name = func_graph_with_call_num.first->ToString(); auto actor = FetchActor(actor_name); - if (actor == nullptr) { - continue; - } - const auto &from_actor = dynamic_cast<GatherActor *>(actor); - MS_EXCEPTION_IF_NULL(from_actor); - - from_actor->branch_id_ = i; - graph_to_control_num[output_func_graph] = 0; - if ((from_actor->output_data_arrows_.size() == 0) && (from_actor->output_control_arrows_.size() == 0)) { - graph_to_control_num[output_func_graph]++; - } - } - - // Count the control arrow num of kernel actor. - for (const auto &kernel_actor : actor_set->kernel_actors_) { - MS_EXCEPTION_IF_NULL(kernel_actor); - if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) { - MS_EXCEPTION_IF_NULL(kernel_actor->kernel_); - const auto &sub_func_graph = FetchFuncGraphByNode(kernel_actor->kernel_); - if (sub_func_graph == nullptr) { - MS_LOG(EXCEPTION) << "Cannot get funcgraph from kernel:" << kernel_actor->kernel_->fullname_with_scope(); - } - - if (graph_to_control_num.find(sub_func_graph) != graph_to_control_num.end()) { - graph_to_control_num[sub_func_graph]++; - } else { - for (auto &pair : graph_to_control_num) { - pair.second++; - } - } - } - } - - for (size_t i = 0; i < graph_to_control_num.size(); ++i) { - // Branch id starts from 1.
- auto branch_id = SizeToInt(i) + kSubBranchStartID; - auto sub_func_graph = output_func_graphs[i]; - auto gather_actor_name = sub_func_graph->ToString(); - auto actor = FetchActor(gather_actor_name); MS_EXCEPTION_IF_NULL(actor); auto gather_actor = dynamic_cast<GatherActor *>(actor); - MS_EXCEPTION_IF_NULL(gather_actor); - - gather_actor->branch_id_ = branch_id; - loop_count_actor->branch_id_to_input_controls_num_[branch_id] = graph_to_control_num[sub_func_graph]; - } - - // If the switch actor is linked to the output actor, it will link a control arrow to the loop count actor, - // and this should be recorded. - for (const auto &from_actor : actor_set->switch_actors_) { - MS_EXCEPTION_IF_NULL(from_actor); - auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0); - const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index); - if (iter == graph_compiler_info.origin_outputs_order_.end()) { - continue; - } - loop_count_actor->branch_id_to_input_controls_num_[iter->second.first]++; + gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_); } } @@ -2194,8 +2340,15 @@ void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo OutputActor *to_actor = actor_set->output_actor_.get(); MS_EXCEPTION_IF_NULL(to_actor); - for (const auto gather_actor : actor_set->gather_actors_) { - MS_EXCEPTION_IF_NULL(gather_actor); + for (const auto &func_graph_to_branch_id : graph_compiler_info.control_node_parser_->func_graph_to_branch_id_) { + if (func_graph_to_branch_id.second == kMainBranchID) { + continue; + } + + const auto &func_graph = func_graph_to_branch_id.first; + auto actor = FetchActor(func_graph->ToString()); + MS_EXCEPTION_IF_NULL(actor); + auto gather_actor = dynamic_cast<GatherActor *>(actor); for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) { const auto front_node = gather_actor->data_nodes_[i]; @@ -2205,11 +2358,7 @@ void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo continue; } - for (auto &output_position : iter->second.second) { - MS_LOG(INFO) << "Link output node:" << AnfAlgo::GetNodeDebugString(origin_output_with_index.first) - << " branch id:" << iter->second.first << " index:" << output_position - << " for gather actor:" << gather_actor->GetAID(); - + for (auto &output_position : iter->second) { auto op_arrow = std::make_shared<DataArrow>(i, to_actor->GetAID(), output_position); gather_actor->output_result_arrows_.emplace_back(op_arrow); const auto &backend_nodes = gather_actor->front_to_backend_parameter_[front_node]; @@ -2220,9 +2369,9 @@ void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo const auto &backend_node = backend_nodes[0].first; if (backend_node->isa<Parameter>()) { std::string actor_name = graph_compiler_info.name_ + "_HostDSActor"; - auto actor = FetchActor(actor_name); - MS_EXCEPTION_IF_NULL(actor); - auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor); + auto ds_op_actor = FetchActor(actor_name); + MS_EXCEPTION_IF_NULL(ds_op_actor); + auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(ds_op_actor); MS_EXCEPTION_IF_NULL(host_ds_actor); const auto &data_nodes = host_ds_actor->data_nodes_; @@ -2271,7 +2420,7 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt auto input_data_num = kernel_actor->input_datas_num_; auto device_tensor_store_num = kernel_actor->device_tensor_store_keys_.size(); if (input_data_num + device_tensor_store_num != input_num) { - MS_LOG(ERROR) << "The input building
of " << AnfAlgo::GetNodeDebugString(kernel_actor->kernel_) << " is wrong, input data num: " << input_data_num << ", device tensor store num: " << device_tensor_store_num << ", total input num: " << input_num; return false; @@ -2302,7 +2451,7 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt const auto &loop_count_actor = actor_set->loop_count_actor_; if ((loop_count_actor != nullptr) && (actor_set->data_source_actors_.size() + actor_set->kernel_actors_.size() + actor_set->copy_actors_.size() > 0)) { - if (loop_count_actor->branch_id_to_input_controls_num_[kMainBranchID] == 0) { + if (loop_count_actor->input_controls_num_ == 0) { MS_LOG(ERROR) << loop_count_actor->GetAID().Name() << " has no source."; return false; } @@ -2477,6 +2626,16 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf DumpCopyActor(copy_actor.get(), ofs); } + ofs << "\n\n[Gather actors]\n"; + for (const auto &gather_actor : actor_set->gather_actors_) { + DumpGatherActor(gather_actor.get(), ofs); + } + + ofs << "\n\n[Switch actors]\n"; + for (const auto &switch_actor : actor_set->switch_actors_) { + DumpSwitchActor(switch_actor.get(), ofs); + } + ofs << "\n\n[Loop count actor]\n"; const auto &loop_count_actor = actor_set->loop_count_actor_; if (loop_count_actor != nullptr) { @@ -2562,7 +2721,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const { MS_EXCEPTION_IF_NULL(actor); ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ - << "\tinput_controls_num:" << actor->branch_id_to_input_controls_num_.at(kMainBranchID) << "\n"; + << "\tinput_controls_num:" << actor->input_controls_num_ << "\n"; ofs << "\t\toutput_control_arrows:" << (actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size() + 1) << "\n "; @@ -2625,8 +2784,8 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ << "\toutputs_num:" << actor->outputs_num_ << "\n"; - ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.at(kMainBranchID).size() << "\n "; - for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_.at(kMainBranchID)) { + ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n "; + for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) { MS_EXCEPTION_IF_NULL(device_tensor_store_key.second); ofs << "\t\t\toutput_node_position:" << device_tensor_store_key.first << "\toutput_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n"; @@ -2745,7 +2904,7 @@ void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &of ofs << "\t\tactor input num:" << actor->input_nodes_.size() << "\n"; for (const auto &node : actor->input_nodes_) { - ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node) << '\n'; + ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << '\t' << node.second << '\n'; } ofs << "\t\tactor input pos:\n"; diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.h b/mindspore/ccsrc/runtime/framework/graph_scheduler.h index 931d0765308..7d02f63416f 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.h +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.h @@ -46,7 +46,7 @@ using mindspore::session::KernelWithIndex; // Position of kernel 
with index, the value pair<branch_id, vector<pos>> means the branch id of the kernel and the pos // of the kernel. Generally, there is only one branch, and the branch id is 0 at this time. In control flow, there are // multiple branch scenarios, and pos represents the position of the kernel in the branch. -using KernelMapPosition = std::map<KernelWithIndex, std::pair<int, std::vector<size_t>>, session::KernelWithIndexCmp>; +using KernelMapPosition = std::map<KernelWithIndex, std::vector<size_t>, session::KernelWithIndexCmp>; using ActorInfo = std::string; // The second element of pair represents the output index of op actor corresponding to the graph output node. @@ -212,7 +212,8 @@ class GraphScheduler { KernelWithIndex to_kernel_with_input_idx); // 2. The processing of linking control arrows. - void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set); + void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set, + const ControlNodeParserPtr &parser); void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node); // The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node. void LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node); @@ -227,20 +228,24 @@ class GraphScheduler { // 4. The processing of control flow linking. void LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *actor_set); - void LinkDataArrowForGatherActor(GatherActor *from_actor, KernelActor *to_actor, - KernelWithIndex from_kernel_with_output_idx, - KernelWithIndex to_kernel_with_input_idx); + void LinkDataArrowForGatherActor(GatherActor *from_actor, const AnfNodePtr &front_node, KernelActor *to_actor, + const size_t to_index); void LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info, SwitchActor *actor); // Connect the input of the actor. - void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &input_node, - OpActor<DeviceTensor> *to_actor, const size_t to_index); + void LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, const KernelWithIndex &input_node, + const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *to_actor, + const size_t to_index);
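The KernelMapPosition change above drops the branch-id half of the value: every output now maps straight to its list of positions. A simplified before/after, with std::string standing in for KernelWithIndex and the custom comparator omitted:

  #include <cstddef>
  #include <iostream>
  #include <map>
  #include <string>
  #include <utility>
  #include <vector>

  // Before: the value carried (branch id, positions); after: positions only.
  using OldKernelMapPosition = std::map<std::string, std::pair<int, std::vector<size_t>>>;
  using NewKernelMapPosition = std::map<std::string, std::vector<size_t>>;

  int main() {
    NewKernelMapPosition outputs_order;
    size_t position = 0;
    for (const std::string &output : {"a", "b", "a"}) {
      outputs_order[output].emplace_back(position++);  // no branch id to thread through
    }
    for (const auto &entry : outputs_order) {
      std::cout << entry.first << " -> " << entry.second.size() << " position(s)\n";  // a -> 2, b -> 1
    }
  }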
// When the input of the actor is a call node, the output of the funcgraph called by the call node needs to be // connected. - void LinkDataArrowByCallInput(const GraphCompilerInfo &graph_compiler_info, const AnfNodePtr &call_node, - OpActor<DeviceTensor> *to_actor, const size_t to_index); - void LinkDataArrowForSwitchActor(SwitchActor *from_actor, KernelActor *to_actor, const size_t to_index); - void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, LoopCountActor *to_actor, - const std::vector<KernelGraphPtr> &graphs); + void LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index, const ControlNodeParserPtr &parser, + const FuncGraphPtr &from_func_graph, OpActor<DeviceTensor> *to_actor, + const size_t to_index); + void LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index, OpActor<DeviceTensor> *to_actor, + const size_t to_index, const size_t branch_index = SIZE_MAX); + void LinkControlArrowForGatherActor(std::vector<GatherActorPtr> *from_actors, + std::vector<KernelActorPtr> *kernel_actors, LoopCountActor *to_actor, + const std::vector<KernelGraphPtr> &graphs, const ControlNodeParserPtr &parser); + void LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *switch_actors, LoopCountActor *to_actor, const KernelMapPosition &origin_outputs_order); // In control flow, there are scenarios where there are multi-branch outputs, and the gather actor needs to diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 9d9f8185325..16989ad22a4 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -712,26 +712,25 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con auto parser = std::make_shared<ControlNodeParser>(); parser->Parse(control_nodes_, graphs, device_contexts, root_graph); - // Get all the outputs. In control flow, there may be multiple branch output. runtime::KernelMapPosition outputs_order; size_t outputs_num = 0; - const auto &all_branch_output = parser->FetchAllBranchOutputs(root_graph); - for (int j = 0; j < SizeToInt(all_branch_output.size()); ++j) { - // In general, there is only one output branch, and the branch id is 0 at this time. In the control flow, - // there are multi-branch output scenarios. Different branches may have different weight nodes. When output - // actor run, the corresponding weight node needs to be obtained according to different branches. Therefore, - // the branch of the output nodes needs to be recorded. - const int branch_id = ((all_branch_output.size() == 1 ?
runtime::kMainBranchID : (j + runtime::kSubBranchStartID))); - const auto &branch_output = all_branch_output[j]; - size_t position = 0; - auto outputs = AnfAlgo::GetAllOutputWithIndex(branch_output); - outputs_num = outputs.size(); - for (const auto &output : outputs) { - if (outputs_order.count(output) == 0) { - outputs_order[output] = {branch_id, {position++}}; - } else { - outputs_order[output].second.emplace_back(position++); - } + const auto &root_output = + AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first; + size_t position = 0; + auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output); + if (runtime::IsCallNode(root_output)) { + std::vector<AnfNodePtr> call_nodes; + size_t call_output_num = runtime::FetchOutputSizebyCallNode(root_output, &call_nodes); + for (size_t i = 0; i < call_output_num; ++i) { + outputs.push_back({root_output, i}); + } + } + outputs_num = outputs.size(); + for (const auto &output : outputs) { + if (outputs_order.count(output) == 0) { + outputs_order[output] = {position++}; + } else { + outputs_order[output].emplace_back(position++); } } @@ -759,9 +758,9 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo( auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output()); for (const auto &output : outputs) { if (outputs_order.count(output) == 0) { - outputs_order[output] = {runtime::kMainBranchID, {position++}}; + outputs_order[output] = {position++}; } else { - outputs_order[output].second.emplace_back(position++); + outputs_order[output].emplace_back(position++); } } }
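When the root output is a call node, its arity cannot be read off the node itself, so the loop above appends one (node, index) entry per possible output; FetchOutputSizebyCallNode supplies that arity in the patch, and the constant below merely stands in for it:

  #include <cstddef>
  #include <iostream>
  #include <string>
  #include <utility>
  #include <vector>

  int main() {
    std::vector<std::pair<std::string, size_t>> outputs;
    const std::string root_output = "call_node";
    const size_t call_output_num = 3;  // stand-in for FetchOutputSizebyCallNode(root_output, &call_nodes)
    for (size_t i = 0; i < call_output_num; ++i) {
      outputs.push_back({root_output, i});  // one entry per call output index
    }
    std::cout << outputs.size() << "\n";  // 3
  }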