diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc index 97082263b4f..456ce60269e 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc @@ -2273,79 +2273,38 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor, } MS_EXCEPTION_IF_NULL(to_actor); - for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) { - const auto &graph = graph_compiler_info.graphs_[i]; - MS_EXCEPTION_IF_NULL(graph); - auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output()); - std::set> unique_output_positions; - std::set unique_outputs; - for (const auto &output : outputs) { - MS_EXCEPTION_IF_NULL(output.first); - if (IsInternalParameter(output.first, graph)) { - MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString(); - continue; - } - (void)unique_outputs.insert(output); + for (const auto &origin_output_order : graph_compiler_info.origin_outputs_order_) { + const auto &front_output_with_index = origin_output_order.first; + if (graph_output_to_actor_.count(front_output_with_index) == 0) { + MS_LOG(EXCEPTION) << "Can't find graph output by front node:" << front_output_with_index.first->DebugString(); } - for (const auto &unique_output : unique_outputs) { - MS_EXCEPTION_IF_NULL(unique_output.first); - auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(unique_output); - const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index); - if (iter == graph_compiler_info.origin_outputs_order_.end()) { - continue; + const auto &graph_output_pair = graph_output_to_actor_.at(front_output_with_index); + const auto &from_actor = graph_output_pair.first; + const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(graph_output_pair.second); + auto real_from_kernel = output_with_index.first; + auto real_from_index = output_with_index.second; + MS_EXCEPTION_IF_NULL(real_from_kernel); + if (IsPersistentDeviceTensor(real_from_kernel)) { + // In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0 as + // output_idx directly. + real_from_index = 0; + } else { + if (from_actor == nullptr) { + MS_LOG(EXCEPTION) << "Can't find output actor by front node:" << front_output_with_index.first->DebugString() + << ", output node:" << real_from_kernel->DebugString(); } - - // Skip duplicate position. - if (unique_output_positions.count(iter->second) > 0) { - continue; + // Update the real node in the host data source actor. + if (from_actor->type() == KernelTransformType::kHostDataSourceActor) { + auto host_queue_ds_actor = dynamic_cast(from_actor); + MS_EXCEPTION_IF_NULL(host_queue_ds_actor); + auto position = host_queue_ds_actor->FetchNodePosition({real_from_kernel, 0}); + UpdateRefCount(real_from_kernel, real_from_index, true); + real_from_kernel = host_queue_ds_actor->FetchNode(position).first; } - (void)unique_output_positions.insert(iter->second); - // The data arrow need skip the monad node. - const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(unique_output); - MS_EXCEPTION_IF_NULL(output_with_index.first); - for (auto &output_position : iter->second) { - if (output_position >= to_actor->device_contexts_.size()) { - MS_LOG(EXCEPTION) << "The output position is out of range."; - } - to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[i]; + } - // The graph output is from device tensor store. - if (IsPersistentDeviceTensor(output_with_index.first)) { - (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first); - if (!AnfAlgo::OutputAddrExist(output_with_index.first, 0, false)) { - MS_EXCEPTION_IF_NULL(output_with_index.first); - MS_LOG(WARNING) << output_with_index.first->DebugString() << " device address not exit"; - continue; - } - // In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0 - // as output_idx directly. - auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, 0, false); - MS_EXCEPTION_IF_NULL(device_tensor); - // The output actor need use the relevant information of node to create output tensor. - device_tensor->SetNodeIndex(output_with_index.first, 0); - continue; - } - - // The graph output is from kernel actor or data source actor. - auto kernel_type = FetchKernelTransformType( - output_with_index.first, graph, graph_compiler_info.origin_parameters_order_, graph_compiler_info.strategy_); - auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, output_with_index.first, graph); - if (from_actor == nullptr) { - continue; - } - - auto real_from_kernel = output_with_index.first; - // Update the real node in the host data source actor. - if (kernel_type == KernelTransformType::kHostDataSourceActor) { - auto host_queue_ds_actor = dynamic_cast(from_actor); - MS_EXCEPTION_IF_NULL(host_queue_ds_actor); - auto position = host_queue_ds_actor->FetchNodePosition({output_with_index.first, 0}); - real_from_kernel = host_queue_ds_actor->FetchNode(position).first; - UpdateRefCount(output_with_index.first, output_with_index.second, true); - } - SchedulerHelper::AddResultArrow(from_actor, to_actor, real_from_kernel, output_with_index.second, - output_position); - } + for (auto &output_position : origin_output_order.second) { + SchedulerHelper::AddResultArrow(from_actor, to_actor, real_from_kernel, real_from_index, output_position); } } } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/scheduler_helper.cc b/mindspore/ccsrc/runtime/graph_scheduler/scheduler_helper.cc index 1c6b28b3ee7..80fcd4b8d75 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/scheduler_helper.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/scheduler_helper.cc @@ -185,23 +185,37 @@ void SchedulerHelper::AddDataArrow(AbstractActor *const from_actor, AbstractActo void SchedulerHelper::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor, const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) { - MS_EXCEPTION_IF_NULL(from_actor); MS_EXCEPTION_IF_NULL(to_actor); MS_EXCEPTION_IF_NULL(from_kernel); - auto result_arrow = std::make_shared(from_output_index, to_actor->GetAID(), output_position); - (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow); - (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel); - to_actor->input_datas_num_++; - (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), result_arrow.get())); + if (from_actor == nullptr) { + (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, from_kernel); + } else { + auto result_arrow = std::make_shared(from_output_index, to_actor->GetAID(), output_position); + (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow); + (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel); + to_actor->input_datas_num_++; + (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), result_arrow.get())); + } + if (!AnfAlgo::OutputAddrExist(from_kernel, from_output_index, false)) { + MS_LOG(WARNING) << from_kernel->DebugString() << " device address not exit"; + return; + } auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false); MS_EXCEPTION_IF_NULL(device_tensor); // The output actor need use the relevant information of node to create output tensor. device_tensor->SetNodeIndex(from_kernel, from_output_index); - // The device tensor of graph out need be taken over by host tensor, so set the max reference count. UpdateRefCount(device_tensor.get(), true); + + // Set the device contexts of to_actor. + if (output_position >= to_actor->device_contexts_.size()) { + MS_LOG(EXCEPTION) << "The output position is out of range."; + } + auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {device_tensor->device_name(), device_tensor->device_id()}); + to_actor->device_contexts_[output_position] = device_context; } void SchedulerHelper::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {