!27012 Fetch total front node in kernel graph.
Merge pull request !27012 from gaoyong10/runtime_second14
This commit is contained in:
commit
acd04621fc
|
@ -390,6 +390,7 @@ class KernelGraph : public FuncGraph {
|
|||
void set_is_executing_sink(bool is_executing_sink) { is_executing_sink_ = is_executing_sink; }
|
||||
bool is_loop_count_sink() const { return is_loop_count_sink_; }
|
||||
void set_is_loop_count_sink(bool is_loop_count_sink) { is_loop_count_sink_ = is_loop_count_sink; }
|
||||
const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &front_backend_anf_map() { return front_backend_anf_map_; }
|
||||
|
||||
AnfWithOutIndex GetElementInTupleBackendFrontIndexMap(const AnfNodePtr &back_node) {
|
||||
auto iter = tuple_backend_front_anf_index_map_.find(back_node);
|
||||
|
|
|
@ -78,7 +78,15 @@ void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
|
|||
// There are two kinds of run conditions for entrance actor:
|
||||
// 1.Data comes from the data source actor, it is in the form of data arrow.
|
||||
const auto &data_iter = input_op_datas_.find(sequential_num);
|
||||
if (data_iter != input_op_datas_.end()) {
|
||||
const auto &control_iter = input_op_controls_.find(sequential_num);
|
||||
if (data_iter != input_op_datas_.end() || control_iter != input_op_controls_.end()) {
|
||||
// If the data comes from the data source actor, use the default branch id.
|
||||
output_branch_id_ = 0;
|
||||
|
||||
if (data_iter == input_op_datas_.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto &input_data : data_iter->second) {
|
||||
MS_EXCEPTION_IF_NULL(input_data);
|
||||
if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
|
||||
|
@ -90,8 +98,6 @@ void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
|
|||
MS_EXCEPTION_IF_NULL(input_data->data_);
|
||||
input_device_tensors_[input_data->index_] = input_data->data_;
|
||||
}
|
||||
// If the data comes from the data source actor, use the default branch id.
|
||||
output_branch_id_ = 0;
|
||||
} else {
|
||||
// 2.Data comes from the gather actor, it is in the form of data with branch id.
|
||||
output_branch_id_ = real_parameters_with_branch_id_[sequential_num].front().branch_id_;
|
||||
|
|
|
@ -844,7 +844,7 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def
|
|||
const auto &cnode = return_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &inputs = cnode->inputs();
|
||||
const auto output_nodes = FetchAllOutputWithIndex(inputs[kReturnInputPos]);
|
||||
const auto output_nodes = FetchInputNodeByNode(inputs[kReturnInputPos]);
|
||||
std::vector<const DeviceContext *> return_device_contexts;
|
||||
|
||||
for (const auto &output_node : output_nodes) {
|
||||
|
@ -909,19 +909,13 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def
|
|||
|
||||
void ControlNodeParser::FetchFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs) {
|
||||
for (const auto &graph : graphs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (graph->execution_order().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto &kernel : graph->execution_order()) {
|
||||
auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
|
||||
if (front_node != nullptr) {
|
||||
front_node_to_kernel_graph_[front_node] = graph;
|
||||
}
|
||||
}
|
||||
const auto &graph_outputs = graph->graph_output_map();
|
||||
for (const auto &backend_to_front : graph_outputs) {
|
||||
front_node_to_kernel_graph_[backend_to_front.second.first] = graph;
|
||||
const auto &front_to_backend_nodes = graph->front_backend_anf_map();
|
||||
for (const auto &front_to_backend_node : front_to_backend_nodes) {
|
||||
front_node_to_kernel_graph_[front_to_backend_node.first] = graph;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue