diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 0913b8efe18..ee126d2ec1b 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -45,6 +45,7 @@ #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/ad/grad.h" #include "frontend/optimizer/py_pass_manager.h" +#include "frontend/optimizer/irpass/parameter_eliminate.h" #include "utils/ms_context.h" #include "utils/ms_utils.h" #include "backend/graph_compiler/transform.h" @@ -527,6 +528,19 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) { return InferenceOptPreparePass(res); } +bool EliminateUnusedParameterAction(const ResourcePtr &res) { + static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1"); + if (!transform_tail_call_to_parallel_call) { + return true; + } + MS_EXCEPTION_IF_NULL(res); + FuncGraphPtr func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + bool changed = opt::irpass::ParameterEliminator()(func_graph, nullptr); + MS_LOG(DEBUG) << "Eliminate parameter, changed: " << changed; + return true; +} + bool AbstractSpecializeAction(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res); if (res->func_graph() == nullptr) { @@ -1335,15 +1349,17 @@ static std::vector<ActionItem> CommonPipeline() { } (void)actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); - // Evaluate type and shape, and specialize + // Eliminate unused parameters before renormalize. + (void)actions.emplace_back(std::make_pair("eliminate_unused_parameter", EliminateUnusedParameterAction)); + // Evaluate type and shape, and specialize. (void)actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); // Auto-monad for side-effects handling. (void)actions.emplace_back(std::make_pair("auto_monad", AutoMonadAction)); - // Do data structure simplifications and inline + // Do data structure simplifications and inline. 
(void)actions.emplace_back(std::make_pair("inline", OptInlineAction)); - // Add pre-ad, post-inline python pass stub + // Add pre-ad, post-inline python pass stub. (void)actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub)); - // Do PipelineSplit + // Do PipelineSplit action. (void)actions.emplace_back(std::make_pair("pipeline_split", PipelineSplitAction)); return actions; diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 332d3ab8a80..8e09fdbd622 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -173,15 +173,19 @@ void Parser::TransformParallelCall() { auto former_call_graph = call_graphs_pair.first->func_graph(); MS_EXCEPTION_IF_NULL(call_graphs_pair.second); auto middle_call_graph = call_graphs_pair.second->func_graph(); + // Transform the call of {middle_graph -> latter_graph}. + auto middle_graph_return = middle_call_graph->get_return(); + if (middle_graph_return == nullptr) { + MS_LOG(INFO) << "middle_graph_return is null, middle_call_graph: " << middle_call_graph->ToString(); + continue; + } constexpr auto recur_3 = 3; MS_LOG(DEBUG) << "Tail call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3) << ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}"; - - // Transform the call of {middle_graph -> latter_graph}. 
- auto middle_graph_return = middle_call_graph->get_return(); - MS_EXCEPTION_IF_NULL(middle_graph_return); auto middle_graph_output = middle_call_graph->output(); - MS_EXCEPTION_IF_NULL(middle_graph_output); + if (middle_graph_output == nullptr) { + MS_LOG(EXCEPTION) << "middle_graph_output is null, middle_call_graph: " << middle_call_graph->ToString(); + } auto middle_graph_output_cnode = dyn_cast<CNode>(middle_graph_output); MS_EXCEPTION_IF_NULL(middle_graph_output_cnode); if (IsDependOfIsolatedNodes(middle_graph_output_cnode)) { @@ -210,7 +214,7 @@ void Parser::TransformParallelCall() { auto latter_call_graph = GetValueNode<FuncGraphPtr>(latter_graph_node); if (latter_call_graph == nullptr) { constexpr auto recur_2 = 2; - MS_LOG(DEBUG) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2); + MS_LOG(ERROR) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2); continue; } if (latter_call_graphs_set.find(latter_call_graph) != latter_call_graphs_set.end()) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index d06b1b5f58c..3f0856e8f8c 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -536,13 +536,21 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecializer } ValuePtrList elements; for (size_t i = 0; i < (*flags).size(); ++i) { + ValuePtr old_sequence_value = sequence_value->value()[i]; + auto old_sequence_str_value = old_sequence_value->cast<StringImmPtr>(); if (!(*flags)[i]) { auto zero = MakeValue(0); (void)elements.emplace_back(zero); MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << old_input->DebugString() << ", which is inputs[" << index << "] of " << cnode->DebugString(); + } else if (old_sequence_str_value != nullptr && old_sequence_str_value->value() == kDeadNodeName) { + auto 
zero = MakeValue(0); + elements.emplace_back(zero); + (*flags)[i] = false; // Change the use flag as 0. + MS_LOG(DEBUG) << "Erase elements[" << i << "] DeadNode as zero for " << old_input->DebugString() + << ", which is inputs[" << index << "] of " << cnode->DebugString(); } else { - (void)elements.emplace_back(sequence_value->value()[i]); + (void)elements.emplace_back(old_sequence_value); } } auto new_sequence_value = std::make_shared<T>(elements); @@ -601,12 +609,20 @@ void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) { (void)inputs.emplace_back(cnode->input(0)); for (size_t i = 0; i < (*flags).size(); ++i) { auto old_input = cnode->input(i + 1); + auto old_input_value = GetValueNode<StringImmPtr>(old_input); if (!(*flags)[i]) { auto zero_value = NewValueNode(MakeValue(0)); zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0))); (void)inputs.emplace_back(zero_value); constexpr int recursive_level = 2; MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << cnode->DebugString(recursive_level); + } else if (old_input_value != nullptr && old_input_value->value() == kDeadNodeName) { + auto zero_value = NewValueNode(MakeValue(0)); + zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0))); + inputs.emplace_back(zero_value); + (*flags)[i] = false; // Change the use flag as 0. + constexpr int recursive_level = 2; + MS_LOG(DEBUG) << "Erase elements[" << i << "] DeadNode as zero for " << cnode->DebugString(recursive_level); } else { (void)inputs.emplace_back(old_input); } diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index e8b55f208e5..6c504d2bd45 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -252,7 +252,7 @@ void FuncGraph::DropNode(const AnfNodePtr &node) { (void)parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end()); } // Remove the node from order list. 
- if (graph) { + if (graph != nullptr) { graph->EraseUnusedNodeInOrder(node); } }