!27012 Fetch total front node in kernel graph.

Merge pull request !27012 from gaoyong10/runtime_second14
This commit is contained in:
i-robot 2021-12-02 01:30:12 +00:00 committed by Gitee
commit acd04621fc
3 changed files with 15 additions and 14 deletions

View File

@ -390,6 +390,7 @@ class KernelGraph : public FuncGraph {
void set_is_executing_sink(bool is_executing_sink) { is_executing_sink_ = is_executing_sink; }
bool is_loop_count_sink() const { return is_loop_count_sink_; }
void set_is_loop_count_sink(bool is_loop_count_sink) { is_loop_count_sink_ = is_loop_count_sink; }
const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &front_backend_anf_map() { return front_backend_anf_map_; }
AnfWithOutIndex GetElementInTupleBackendFrontIndexMap(const AnfNodePtr &back_node) {
auto iter = tuple_backend_front_anf_index_map_.find(back_node);

View File

@ -78,7 +78,15 @@ void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
// There are two kinds of run conditions for entrance actor:
// 1.Data comes from the data source actor, it is in the form of data arrow.
const auto &data_iter = input_op_datas_.find(sequential_num);
if (data_iter != input_op_datas_.end()) {
const auto &control_iter = input_op_controls_.find(sequential_num);
if (data_iter != input_op_datas_.end() || control_iter != input_op_controls_.end()) {
// If the data comes from the data source actor, use the default branch id.
output_branch_id_ = 0;
if (data_iter == input_op_datas_.end()) {
return;
}
for (auto &input_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(input_data);
if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
@ -90,8 +98,6 @@ void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(input_data->data_);
input_device_tensors_[input_data->index_] = input_data->data_;
}
// If the data comes from the data source actor, use the default branch id.
output_branch_id_ = 0;
} else {
// 2.Data comes from the gather actor, it is in the form of data with branch id.
output_branch_id_ = real_parameters_with_branch_id_[sequential_num].front().branch_id_;

View File

@ -844,7 +844,7 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def
const auto &cnode = return_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
const auto output_nodes = FetchAllOutputWithIndex(inputs[kReturnInputPos]);
const auto output_nodes = FetchInputNodeByNode(inputs[kReturnInputPos]);
std::vector<const DeviceContext *> return_device_contexts;
for (const auto &output_node : output_nodes) {
@ -909,19 +909,13 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def
void ControlNodeParser::FetchFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs) {
for (const auto &graph : graphs) {
MS_EXCEPTION_IF_NULL(graph);
if (graph->execution_order().empty()) {
continue;
}
for (auto &kernel : graph->execution_order()) {
auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
if (front_node != nullptr) {
front_node_to_kernel_graph_[front_node] = graph;
}
}
const auto &graph_outputs = graph->graph_output_map();
for (const auto &backend_to_front : graph_outputs) {
front_node_to_kernel_graph_[backend_to_front.second.first] = graph;
const auto &front_to_backend_nodes = graph->front_backend_anf_map();
for (const auto &front_to_backend_node : front_to_backend_nodes) {
front_node_to_kernel_graph_[front_to_backend_node.first] = graph;
}
}
}