!46307 Optimize the output actor link to resolve the inconsistency between front-end and back-end outputs
Merge pull request !46307 from limingqi107/r2.0.0-alpha
Commit 44e82bc68b
@@ -2273,79 +2273,38 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
   }
   MS_EXCEPTION_IF_NULL(to_actor);
 
-  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
-    const auto &graph = graph_compiler_info.graphs_[i];
-    MS_EXCEPTION_IF_NULL(graph);
-    auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
-    std::set<std::vector<size_t>> unique_output_positions;
-    std::set<KernelWithIndex> unique_outputs;
-    for (const auto &output : outputs) {
-      MS_EXCEPTION_IF_NULL(output.first);
-      if (IsInternalParameter(output.first, graph)) {
-        MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString();
-        continue;
-      }
-      (void)unique_outputs.insert(output);
-    }
-    for (const auto &unique_output : unique_outputs) {
-      MS_EXCEPTION_IF_NULL(unique_output.first);
-      auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(unique_output);
-      const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
-      if (iter == graph_compiler_info.origin_outputs_order_.end()) {
-        continue;
-      }
-
-      // Skip duplicate position.
-      if (unique_output_positions.count(iter->second) > 0) {
-        continue;
-      }
-      (void)unique_output_positions.insert(iter->second);
-      // The data arrow need skip the monad node.
-      const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(unique_output);
-      MS_EXCEPTION_IF_NULL(output_with_index.first);
-      for (auto &output_position : iter->second) {
-        if (output_position >= to_actor->device_contexts_.size()) {
-          MS_LOG(EXCEPTION) << "The output position is out of range.";
-        }
-        to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[i];
-
-        // The graph output is from device tensor store.
-        if (IsPersistentDeviceTensor(output_with_index.first)) {
-          (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
-          if (!AnfAlgo::OutputAddrExist(output_with_index.first, 0, false)) {
-            MS_EXCEPTION_IF_NULL(output_with_index.first);
-            MS_LOG(WARNING) << output_with_index.first->DebugString() << " device address not exit";
-            continue;
-          }
-          // In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0
-          // as output_idx directly.
-          auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, 0, false);
-          MS_EXCEPTION_IF_NULL(device_tensor);
-          // The output actor need use the relevant information of node to create output tensor.
-          device_tensor->SetNodeIndex(output_with_index.first, 0);
-          continue;
-        }
-
-        // The graph output is from kernel actor or data source actor.
-        auto kernel_type = FetchKernelTransformType(
-          output_with_index.first, graph, graph_compiler_info.origin_parameters_order_, graph_compiler_info.strategy_);
-        auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, output_with_index.first, graph);
-        if (from_actor == nullptr) {
-          continue;
-        }
-
-        auto real_from_kernel = output_with_index.first;
-        // Update the real node in the host data source actor.
-        if (kernel_type == KernelTransformType::kHostDataSourceActor) {
-          auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
-          MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
-          auto position = host_queue_ds_actor->FetchNodePosition({output_with_index.first, 0});
-          real_from_kernel = host_queue_ds_actor->FetchNode(position).first;
-          UpdateRefCount(output_with_index.first, output_with_index.second, true);
-        }
-        SchedulerHelper::AddResultArrow(from_actor, to_actor, real_from_kernel, output_with_index.second,
-                                        output_position);
-      }
-    }
-  }
+  for (const auto &origin_output_order : graph_compiler_info.origin_outputs_order_) {
+    const auto &front_output_with_index = origin_output_order.first;
+    if (graph_output_to_actor_.count(front_output_with_index) == 0) {
+      MS_LOG(EXCEPTION) << "Can't find graph output by front node:" << front_output_with_index.first->DebugString();
+    }
+    const auto &graph_output_pair = graph_output_to_actor_.at(front_output_with_index);
+    const auto &from_actor = graph_output_pair.first;
+    const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(graph_output_pair.second);
+    auto real_from_kernel = output_with_index.first;
+    auto real_from_index = output_with_index.second;
+    MS_EXCEPTION_IF_NULL(real_from_kernel);
+    if (IsPersistentDeviceTensor(real_from_kernel)) {
+      // In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0 as
+      // output_idx directly.
+      real_from_index = 0;
+    } else {
+      if (from_actor == nullptr) {
+        MS_LOG(EXCEPTION) << "Can't find output actor by front node:" << front_output_with_index.first->DebugString()
+                          << ", output node:" << real_from_kernel->DebugString();
+      }
+      // Update the real node in the host data source actor.
+      if (from_actor->type() == KernelTransformType::kHostDataSourceActor) {
+        auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
+        MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
+        auto position = host_queue_ds_actor->FetchNodePosition({real_from_kernel, 0});
+        UpdateRefCount(real_from_kernel, real_from_index, true);
+        real_from_kernel = host_queue_ds_actor->FetchNode(position).first;
+      }
+    }
+
+    for (auto &output_position : origin_output_order.second) {
+      SchedulerHelper::AddResultArrow(from_actor, to_actor, real_from_kernel, real_from_index, output_position);
+    }
+  }
 }
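The hunk above replaces the old per-backend-graph traversal with a single walk over graph_compiler_info.origin_outputs_order_, so every user-visible output position is filled by looking up the producer of the corresponding front-end output in graph_output_to_actor_. The stand-alone C++ sketch below illustrates only that lookup-by-front-output idea; it is not MindSpore code, and names such as BackendProducer, origin_outputs_order and graph_output_to_actor model the real structures loosely.

// Minimal sketch (hypothetical names): fill each output position from the
// front-end output order instead of walking the backend graphs.
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct BackendProducer {        // stands in for the (actor, backend node) pair
  std::string actor_name;
  std::string kernel_name;
};

int main() {
  // Front-end output -> the positions it occupies in the user-visible result tuple.
  const std::map<std::string, std::vector<size_t>> origin_outputs_order = {
    {"front_out_a", {0}},
    {"front_out_b", {1, 2}},   // the same front output can feed several positions
  };
  // Front-end output -> backend producer, recorded when the graph actors were built.
  const std::map<std::string, BackendProducer> graph_output_to_actor = {
    {"front_out_a", {"kernel_actor_matmul", "MatMul-op0"}},
    {"front_out_b", {"kernel_actor_add", "Add-op1"}},
  };

  std::vector<std::string> result_arrows(3);  // one slot per output position
  for (const auto &[front_output, positions] : origin_outputs_order) {
    const auto it = graph_output_to_actor.find(front_output);
    if (it == graph_output_to_actor.end()) {
      std::cerr << "Can't find graph output by front node: " << front_output << "\n";
      return 1;
    }
    for (const auto position : positions) {
      // Position i is always wired from the producer of front output i.
      result_arrows[position] = it->second.actor_name + " -> output[" + std::to_string(position) + "]";
    }
  }

  for (const auto &arrow : result_arrows) {
    std::cout << arrow << "\n";
  }
  return 0;
}

Because the loop is driven by the front-end order, position i can never be filled from a backend output that the front end did not map to position i, which is the inconsistency the commit title refers to.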
@@ -185,23 +185,37 @@ void SchedulerHelper::AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
 
 void SchedulerHelper::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor,
                                      const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) {
-  MS_EXCEPTION_IF_NULL(from_actor);
   MS_EXCEPTION_IF_NULL(to_actor);
   MS_EXCEPTION_IF_NULL(from_kernel);
 
-  auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
-  (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
-  (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
-  to_actor->input_datas_num_++;
-  (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), result_arrow.get()));
+  if (from_actor == nullptr) {
+    (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, from_kernel);
+  } else {
+    auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
+    (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
+    (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
+    to_actor->input_datas_num_++;
+    (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), result_arrow.get()));
+  }
 
+  if (!AnfAlgo::OutputAddrExist(from_kernel, from_output_index, false)) {
+    MS_LOG(WARNING) << from_kernel->DebugString() << " device address not exit";
+    return;
+  }
   auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
   MS_EXCEPTION_IF_NULL(device_tensor);
   // The output actor need use the relevant information of node to create output tensor.
   device_tensor->SetNodeIndex(from_kernel, from_output_index);
 
   // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
   UpdateRefCount(device_tensor.get(), true);
 
+  // Set the device contexts of to_actor.
+  if (output_position >= to_actor->device_contexts_.size()) {
+    MS_LOG(EXCEPTION) << "The output position is out of range.";
+  }
+  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
+    {device_tensor->device_name(), device_tensor->device_id()});
+  to_actor->device_contexts_[output_position] = device_context;
 }
 
 void SchedulerHelper::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {
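The second hunk centralizes the result-arrow wiring in SchedulerHelper::AddResultArrow: a null from_actor now means the graph output comes from the device tensor store and only a store key is recorded, otherwise a data arrow is created from the producing actor, and in both cases the device-address check, SetNodeIndex, ref-count update and device-context assignment happen in one place. The sketch below models only the null/non-null branching with hypothetical types (OutputCollector, ResultArrow); it illustrates the contract and is not the real MindSpore API.

// Minimal sketch (hypothetical names): null producer -> record a store key,
// otherwise wire a data arrow from the producing actor to the output position.
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct ResultArrow {
  std::string from_actor;
  size_t from_output_index;
  size_t output_position;
};

struct OutputCollector {
  std::vector<std::pair<size_t, std::string>> device_tensor_store_keys;
  std::vector<ResultArrow> input_arrows;
};

void AddResultArrowSketch(const std::string *from_actor, OutputCollector *to_actor,
                          const std::string &from_kernel, size_t from_output_index, size_t output_position) {
  if (from_actor == nullptr) {
    // The value is persistent (e.g. a weight), so the output actor fetches it from the store.
    to_actor->device_tensor_store_keys.emplace_back(output_position, from_kernel);
  } else {
    // A real producer exists, so wire a data arrow from it to the output position.
    to_actor->input_arrows.push_back({*from_actor, from_output_index, output_position});
  }
}

int main() {
  OutputCollector collector;
  const std::string kernel_actor = "kernel_actor_add";
  AddResultArrowSketch(&kernel_actor, &collector, "Add-op1", 0, 0);   // graph output from an actor
  AddResultArrowSketch(nullptr, &collector, "weight_w0", 0, 1);       // graph output from the tensor store

  std::cout << "arrows: " << collector.input_arrows.size()
            << ", store keys: " << collector.device_tensor_store_keys.size() << "\n";
  return 0;
}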