forked from mindspore-Ecosystem/mindspore
!30476 Fix 'if parallel call' issues.
Merge pull request !30476 from 张清华/opt_if_parallel2
This commit is contained in:
commit
4a49f43128
|
@ -45,6 +45,7 @@
|
|||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/ad/grad.h"
|
||||
#include "frontend/optimizer/py_pass_manager.h"
|
||||
#include "frontend/optimizer/irpass/parameter_eliminate.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "backend/graph_compiler/transform.h"
|
||||
|
@ -527,6 +528,19 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) {
|
|||
return InferenceOptPreparePass(res);
|
||||
}
|
||||
|
||||
bool EliminateUnusedParameterAction(const ResourcePtr &res) {
|
||||
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1");
|
||||
if (!transform_tail_call_to_parallel_call) {
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
bool changed = opt::irpass::ParameterEliminator()(func_graph, nullptr);
|
||||
MS_LOG(DEBUG) << "Eliminate parameter, changed: " << changed;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AbstractSpecializeAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->func_graph() == nullptr) {
|
||||
|
@ -1335,15 +1349,17 @@ static std::vector<ActionItem> CommonPipeline() {
|
|||
}
|
||||
|
||||
(void)actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
|
||||
// Evaluate type and shape, and specialize
|
||||
// Eliminate unused parameters before renormalize.
|
||||
(void)actions.emplace_back(std::make_pair("elininate_unused_parameter", EliminateUnusedParameterAction));
|
||||
// Evaluate type and shape, and specialize.
|
||||
(void)actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
|
||||
// Auto-monad for side-effects handling.
|
||||
(void)actions.emplace_back(std::make_pair("auto_monad", AutoMonadAction));
|
||||
// Do data structure simplifications and inline
|
||||
// Do data structure simplifications and inline.
|
||||
(void)actions.emplace_back(std::make_pair("inline", OptInlineAction));
|
||||
// Add pre-ad, post-inline python pass stub
|
||||
// Add pre-ad, post-inline python pass stub.
|
||||
(void)actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub));
|
||||
// Do PipelineSplit
|
||||
// Do PipelineSplit action.
|
||||
(void)actions.emplace_back(std::make_pair("pipeline_split", PipelineSplitAction));
|
||||
|
||||
return actions;
|
||||
|
|
|
@ -173,15 +173,19 @@ void Parser::TransformParallelCall() {
|
|||
auto former_call_graph = call_graphs_pair.first->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
|
||||
auto middle_call_graph = call_graphs_pair.second->func_graph();
|
||||
// Transform the call of {middle_graph -> latter_graph}.
|
||||
auto middle_graph_return = middle_call_graph->get_return();
|
||||
if (middle_graph_return == nullptr) {
|
||||
MS_LOG(INFO) << "middle_graph_return is null, middle_call_graph: " << middle_call_graph->ToString();
|
||||
continue;
|
||||
}
|
||||
constexpr auto recur_3 = 3;
|
||||
MS_LOG(DEBUG) << "Tail call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
|
||||
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
|
||||
|
||||
// Transform the call of {middle_graph -> latter_graph}.
|
||||
auto middle_graph_return = middle_call_graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(middle_graph_return);
|
||||
auto middle_graph_output = middle_call_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(middle_graph_output);
|
||||
if (middle_graph_output == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "middle_graph_output is null, middle_call_graph: " << middle_call_graph->ToString();
|
||||
}
|
||||
auto middle_graph_output_cnode = dyn_cast<CNode>(middle_graph_output);
|
||||
MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
|
||||
if (IsDependOfIsolatedNodes(middle_graph_output_cnode)) {
|
||||
|
@ -210,7 +214,7 @@ void Parser::TransformParallelCall() {
|
|||
auto latter_call_graph = GetValueNode<FuncGraphPtr>(latter_graph_node);
|
||||
if (latter_call_graph == nullptr) {
|
||||
constexpr auto recur_2 = 2;
|
||||
MS_LOG(DEBUG) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
|
||||
MS_LOG(ERROR) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
|
||||
continue;
|
||||
}
|
||||
if (latter_call_graphs_set.find(latter_call_graph) != latter_call_graphs_set.end()) {
|
||||
|
|
|
@ -536,13 +536,21 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecial
|
|||
}
|
||||
ValuePtrList elements;
|
||||
for (size_t i = 0; i < (*flags).size(); ++i) {
|
||||
ValuePtr old_sequence_value = sequence_value->value()[i];
|
||||
auto old_sequence_str_value = old_sequence_value->cast<StringImmPtr>();
|
||||
if (!(*flags)[i]) {
|
||||
auto zero = MakeValue(0);
|
||||
(void)elements.emplace_back(zero);
|
||||
MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << old_input->DebugString() << ", which is inputs["
|
||||
<< index << "] of " << cnode->DebugString();
|
||||
} else if (old_sequence_str_value != nullptr && old_sequence_str_value->value() == kDeadNodeName) {
|
||||
auto zero = MakeValue(0);
|
||||
elements.emplace_back(zero);
|
||||
(*flags)[i] = false; // Change the use flag as 0.
|
||||
MS_LOG(DEBUG) << "Erase elements[" << i << "] DeadNode as zero for " << old_input->DebugString()
|
||||
<< ", which is inputs[" << index << "] of " << cnode->DebugString();
|
||||
} else {
|
||||
(void)elements.emplace_back(sequence_value->value()[i]);
|
||||
(void)elements.emplace_back(old_sequence_value);
|
||||
}
|
||||
}
|
||||
auto new_sequence_value = std::make_shared<T>(elements);
|
||||
|
@ -601,12 +609,20 @@ void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) {
|
|||
(void)inputs.emplace_back(cnode->input(0));
|
||||
for (size_t i = 0; i < (*flags).size(); ++i) {
|
||||
auto old_input = cnode->input(i + 1);
|
||||
auto old_input_value = GetValueNode<StringImmPtr>(old_input);
|
||||
if (!(*flags)[i]) {
|
||||
auto zero_value = NewValueNode(MakeValue(0));
|
||||
zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0)));
|
||||
(void)inputs.emplace_back(zero_value);
|
||||
constexpr int recursive_level = 2;
|
||||
MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << cnode->DebugString(recursive_level);
|
||||
} else if (old_input_value != nullptr && old_input_value->value() == kDeadNodeName) {
|
||||
auto zero_value = NewValueNode(MakeValue(0));
|
||||
zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0)));
|
||||
inputs.emplace_back(zero_value);
|
||||
(*flags)[i] = false; // Change the use flag as 0.
|
||||
constexpr int recursive_level = 2;
|
||||
MS_LOG(DEBUG) << "Erase elements[" << i << "] DeadNode as zero for " << cnode->DebugString(recursive_level);
|
||||
} else {
|
||||
(void)inputs.emplace_back(old_input);
|
||||
}
|
||||
|
|
|
@ -252,7 +252,7 @@ void FuncGraph::DropNode(const AnfNodePtr &node) {
|
|||
(void)parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end());
|
||||
}
|
||||
// Remove the node from order list.
|
||||
if (graph) {
|
||||
if (graph != nullptr) {
|
||||
graph->EraseUnusedNodeInOrder(node);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue