!46307 optimize output actor link to solve the inconsistency between front and back outputs

Merge pull request !46307 from limingqi107/r2.0.0-alpha
This commit is contained in:
i-robot 2022-12-01 11:43:01 +00:00 committed by Gitee
commit 44e82bc68b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 49 additions and 76 deletions

View File

@ -2273,79 +2273,38 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
} }
MS_EXCEPTION_IF_NULL(to_actor); MS_EXCEPTION_IF_NULL(to_actor);
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) { for (const auto &origin_output_order : graph_compiler_info.origin_outputs_order_) {
const auto &graph = graph_compiler_info.graphs_[i]; const auto &front_output_with_index = origin_output_order.first;
MS_EXCEPTION_IF_NULL(graph); if (graph_output_to_actor_.count(front_output_with_index) == 0) {
auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output()); MS_LOG(EXCEPTION) << "Can't find graph output by front node:" << front_output_with_index.first->DebugString();
std::set<std::vector<size_t>> unique_output_positions;
std::set<KernelWithIndex> unique_outputs;
for (const auto &output : outputs) {
MS_EXCEPTION_IF_NULL(output.first);
if (IsInternalParameter(output.first, graph)) {
MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString();
continue;
} }
(void)unique_outputs.insert(output); const auto &graph_output_pair = graph_output_to_actor_.at(front_output_with_index);
} const auto &from_actor = graph_output_pair.first;
for (const auto &unique_output : unique_outputs) { const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(graph_output_pair.second);
MS_EXCEPTION_IF_NULL(unique_output.first);
auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(unique_output);
const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
continue;
}
// Skip duplicate position.
if (unique_output_positions.count(iter->second) > 0) {
continue;
}
(void)unique_output_positions.insert(iter->second);
// The data arrow need skip the monad node.
const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(unique_output);
MS_EXCEPTION_IF_NULL(output_with_index.first);
for (auto &output_position : iter->second) {
if (output_position >= to_actor->device_contexts_.size()) {
MS_LOG(EXCEPTION) << "The output position is out of range.";
}
to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[i];
// The graph output is from device tensor store.
if (IsPersistentDeviceTensor(output_with_index.first)) {
(void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
if (!AnfAlgo::OutputAddrExist(output_with_index.first, 0, false)) {
MS_EXCEPTION_IF_NULL(output_with_index.first);
MS_LOG(WARNING) << output_with_index.first->DebugString() << " device address not exit";
continue;
}
// In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0
// as output_idx directly.
auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
// The output actor need use the relevant information of node to create output tensor.
device_tensor->SetNodeIndex(output_with_index.first, 0);
continue;
}
// The graph output is from kernel actor or data source actor.
auto kernel_type = FetchKernelTransformType(
output_with_index.first, graph, graph_compiler_info.origin_parameters_order_, graph_compiler_info.strategy_);
auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, output_with_index.first, graph);
if (from_actor == nullptr) {
continue;
}
auto real_from_kernel = output_with_index.first; auto real_from_kernel = output_with_index.first;
auto real_from_index = output_with_index.second;
MS_EXCEPTION_IF_NULL(real_from_kernel);
if (IsPersistentDeviceTensor(real_from_kernel)) {
// In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0 as
// output_idx directly.
real_from_index = 0;
} else {
if (from_actor == nullptr) {
MS_LOG(EXCEPTION) << "Can't find output actor by front node:" << front_output_with_index.first->DebugString()
<< ", output node:" << real_from_kernel->DebugString();
}
// Update the real node in the host data source actor. // Update the real node in the host data source actor.
if (kernel_type == KernelTransformType::kHostDataSourceActor) { if (from_actor->type() == KernelTransformType::kHostDataSourceActor) {
auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor); auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
MS_EXCEPTION_IF_NULL(host_queue_ds_actor); MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
auto position = host_queue_ds_actor->FetchNodePosition({output_with_index.first, 0}); auto position = host_queue_ds_actor->FetchNodePosition({real_from_kernel, 0});
UpdateRefCount(real_from_kernel, real_from_index, true);
real_from_kernel = host_queue_ds_actor->FetchNode(position).first; real_from_kernel = host_queue_ds_actor->FetchNode(position).first;
UpdateRefCount(output_with_index.first, output_with_index.second, true);
} }
SchedulerHelper::AddResultArrow(from_actor, to_actor, real_from_kernel, output_with_index.second,
output_position);
} }
for (auto &output_position : origin_output_order.second) {
SchedulerHelper::AddResultArrow(from_actor, to_actor, real_from_kernel, real_from_index, output_position);
} }
} }
} }

View File

@ -185,23 +185,37 @@ void SchedulerHelper::AddDataArrow(AbstractActor *const from_actor, AbstractActo
void SchedulerHelper::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor, void SchedulerHelper::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor,
const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) { const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor); MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(from_kernel); MS_EXCEPTION_IF_NULL(from_kernel);
if (from_actor == nullptr) {
(void)to_actor->device_tensor_store_keys_.emplace_back(output_position, from_kernel);
} else {
auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position); auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
(void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow); (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
(void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel); (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
to_actor->input_datas_num_++; to_actor->input_datas_num_++;
(void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), result_arrow.get())); (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), result_arrow.get()));
}
if (!AnfAlgo::OutputAddrExist(from_kernel, from_output_index, false)) {
MS_LOG(WARNING) << from_kernel->DebugString() << " device address not exit";
return;
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false); auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor); MS_EXCEPTION_IF_NULL(device_tensor);
// The output actor need use the relevant information of node to create output tensor. // The output actor need use the relevant information of node to create output tensor.
device_tensor->SetNodeIndex(from_kernel, from_output_index); device_tensor->SetNodeIndex(from_kernel, from_output_index);
// The device tensor of graph out need be taken over by host tensor, so set the max reference count. // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
UpdateRefCount(device_tensor.get(), true); UpdateRefCount(device_tensor.get(), true);
// Set the device contexts of to_actor.
if (output_position >= to_actor->device_contexts_.size()) {
MS_LOG(EXCEPTION) << "The output position is out of range.";
}
auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device_tensor->device_name(), device_tensor->device_id()});
to_actor->device_contexts_[output_position] = device_context;
} }
void SchedulerHelper::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) { void SchedulerHelper::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {