!30476 Fix 'if parallel call' issues.

Merge pull request !30476 from 张清华/opt_if_parallel2
This commit is contained in:
i-robot 2022-02-25 04:35:39 +00:00 committed by Gitee
commit 4a49f43128
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 48 additions and 12 deletions

View File

@ -45,6 +45,7 @@
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "backend/graph_compiler/transform.h"
@ -527,6 +528,19 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) {
return InferenceOptPreparePass(res);
}
bool EliminateUnusedParameterAction(const ResourcePtr &res) {
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_PARALLEL_CALL") == "1");
if (!transform_tail_call_to_parallel_call) {
return true;
}
MS_EXCEPTION_IF_NULL(res);
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = opt::irpass::ParameterEliminator()(func_graph, nullptr);
MS_LOG(DEBUG) << "Eliminate parameter, changed: " << changed;
return true;
}
bool AbstractSpecializeAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (res->func_graph() == nullptr) {
@ -1335,15 +1349,17 @@ static std::vector<ActionItem> CommonPipeline() {
}
(void)actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
// Evaluate type and shape, and specialize
// Eliminate unused parameters before renormalize.
(void)actions.emplace_back(std::make_pair("elininate_unused_parameter", EliminateUnusedParameterAction));
// Evaluate type and shape, and specialize.
(void)actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
// Auto-monad for side-effects handling.
(void)actions.emplace_back(std::make_pair("auto_monad", AutoMonadAction));
// Do data structure simplifications and inline
// Do data structure simplifications and inline.
(void)actions.emplace_back(std::make_pair("inline", OptInlineAction));
// Add pre-ad, post-inline python pass stub
// Add pre-ad, post-inline python pass stub.
(void)actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub));
// Do PipelineSplit
// Do PipelineSplit action.
(void)actions.emplace_back(std::make_pair("pipeline_split", PipelineSplitAction));
return actions;

View File

@ -173,15 +173,19 @@ void Parser::TransformParallelCall() {
auto former_call_graph = call_graphs_pair.first->func_graph();
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
auto middle_call_graph = call_graphs_pair.second->func_graph();
// Transform the call of {middle_graph -> latter_graph}.
auto middle_graph_return = middle_call_graph->get_return();
if (middle_graph_return == nullptr) {
MS_LOG(INFO) << "middle_graph_return is null, middle_call_graph: " << middle_call_graph->ToString();
continue;
}
constexpr auto recur_3 = 3;
MS_LOG(DEBUG) << "Tail call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
// Transform the call of {middle_graph -> latter_graph}.
auto middle_graph_return = middle_call_graph->get_return();
MS_EXCEPTION_IF_NULL(middle_graph_return);
auto middle_graph_output = middle_call_graph->output();
MS_EXCEPTION_IF_NULL(middle_graph_output);
if (middle_graph_output == nullptr) {
MS_LOG(EXCEPTION) << "middle_graph_output is null, middle_call_graph: " << middle_call_graph->ToString();
}
auto middle_graph_output_cnode = dyn_cast<CNode>(middle_graph_output);
MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
if (IsDependOfIsolatedNodes(middle_graph_output_cnode)) {
@ -210,7 +214,7 @@ void Parser::TransformParallelCall() {
auto latter_call_graph = GetValueNode<FuncGraphPtr>(latter_graph_node);
if (latter_call_graph == nullptr) {
constexpr auto recur_2 = 2;
MS_LOG(DEBUG) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
MS_LOG(ERROR) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
continue;
}
if (latter_call_graphs_set.find(latter_call_graph) != latter_call_graphs_set.end()) {

View File

@ -536,13 +536,21 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecial
}
ValuePtrList elements;
for (size_t i = 0; i < (*flags).size(); ++i) {
ValuePtr old_sequence_value = sequence_value->value()[i];
auto old_sequence_str_value = old_sequence_value->cast<StringImmPtr>();
if (!(*flags)[i]) {
auto zero = MakeValue(0);
(void)elements.emplace_back(zero);
MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << old_input->DebugString() << ", which is inputs["
<< index << "] of " << cnode->DebugString();
} else if (old_sequence_str_value != nullptr && old_sequence_str_value->value() == kDeadNodeName) {
auto zero = MakeValue(0);
elements.emplace_back(zero);
(*flags)[i] = false; // Change the use flag as 0.
MS_LOG(DEBUG) << "Erase elements[" << i << "] DeadNode as zero for " << old_input->DebugString()
<< ", which is inputs[" << index << "] of " << cnode->DebugString();
} else {
(void)elements.emplace_back(sequence_value->value()[i]);
(void)elements.emplace_back(old_sequence_value);
}
}
auto new_sequence_value = std::make_shared<T>(elements);
@ -601,12 +609,20 @@ void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) {
(void)inputs.emplace_back(cnode->input(0));
for (size_t i = 0; i < (*flags).size(); ++i) {
auto old_input = cnode->input(i + 1);
auto old_input_value = GetValueNode<StringImmPtr>(old_input);
if (!(*flags)[i]) {
auto zero_value = NewValueNode(MakeValue(0));
zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0)));
(void)inputs.emplace_back(zero_value);
constexpr int recursive_level = 2;
MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << cnode->DebugString(recursive_level);
} else if (old_input_value != nullptr && old_input_value->value() == kDeadNodeName) {
auto zero_value = NewValueNode(MakeValue(0));
zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0)));
inputs.emplace_back(zero_value);
(*flags)[i] = false; // Change the use flag as 0.
constexpr int recursive_level = 2;
MS_LOG(DEBUG) << "Erase elements[" << i << "] DeadNode as zero for " << cnode->DebugString(recursive_level);
} else {
(void)inputs.emplace_back(old_input);
}

View File

@ -252,7 +252,7 @@ void FuncGraph::DropNode(const AnfNodePtr &node) {
(void)parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end());
}
// Remove the node from order list.
if (graph) {
if (graph != nullptr) {
graph->EraseUnusedNodeInOrder(node);
}
}