!29662 Keep sequence nodes from being freed before calling PurifyElements() on the sequence abstract.

Merge pull request !29662 from 张清华/eliminate_tuple_unused_item
This commit is contained in:
i-robot 2022-01-30 01:26:03 +00:00 committed by Gitee
commit d80625be6b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 61 additions and 61 deletions

View File

@ -58,39 +58,11 @@ bool IsSideEffectOp(const AnfNodePtr &node) {
return effect_info.memory || effect_info.io;
}
void CheckSwitchWithSideEffect(const FuncGraphPtr &fg) {
AnfNodePtr switch_node = nullptr;
AnfNodePtr side_effect_node = nullptr;
auto all_graphs = fg->func_graphs_used_total();
all_graphs.add(fg);
for (auto &child_fg : all_graphs) {
for (const auto &node : child_fg->nodes()) {
if (switch_node == nullptr && IsPrimitiveCNode(node, prim::kPrimSwitch)) {
switch_node = node;
}
if (side_effect_node == nullptr && IsSideEffectOp(node)) {
side_effect_node = node;
}
if (switch_node != nullptr && side_effect_node != nullptr) {
MS_LOG(ERROR)
<< "Control flow with side effect op[" << GetCNodeFuncName(side_effect_node->cast<CNodePtr>())
<< "] in training situation is not supported and grads may be wrong. Please remove the control flow "
"statement or the side effect op.\n"
<< " Side effect node:" << side_effect_node->DebugString();
return;
}
}
}
}
AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const OptimizerPtr &optimizer) {
AnfNodePtr expanded_node = nullptr;
if (IsValueNode<FuncGraph>(vnode)) {
ScopeGuard scope_guard(vnode->scope());
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
// If a control flow network has side effect ops inside, which is not supported now, a error will be raised to
// alert wrong grads.
CheckSwitchWithSideEffect(func_graph);
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now";
auto newfg = ad::Grad(func_graph, optimizer);

View File

@ -89,14 +89,13 @@ FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisConte
// Call PurifyElements() to purify tuple/list elements.
static const auto enable_only_mark_unused_element = (common::GetEnv("MS_DEV_DDE_ONLY_MARK") == "1");
if (!enable_only_mark_unused_element) {
for (auto &p : sequence_abstract_list_) {
auto &sequence_abs = p.first;
for (auto &abstract_and_node : sequence_abstract_list_) {
auto &sequence_abs = abstract_and_node.first;
if (!sequence_abs->PurifyElements()) {
MS_LOG(INFO) << "Purify elements failed, node: " << p.second->DebugString();
MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString()
<< ", node: " << abstract_and_node.second->DebugString();
}
}
// Clear all nodes after purify abstract.
sequence_nodes_replaced_list_.clear();
}
return res;
}
@ -460,6 +459,29 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
return;
}
// Since the 'old_node' may not equal to 'old_abs' sequence node,
// if the new_node is built by the abstract of 'forward old node',
// we just set 'new_node' as 'old_abs' sequence node here.
if (IsValueNode<ValueTuple>(new_node) || IsValueNode<ValueList>(new_node)) {
// Just find a valid sequence node.
std::shared_ptr<std::vector<bool>> flags = nullptr;
for (auto &weak_node : *old_sequence_abs->sequence_nodes()) {
auto sequence_node = weak_node.lock();
if (sequence_node == nullptr) {
continue;
}
flags = GetSequenceNodeElementsUseFlags(sequence_node);
}
// Copy the flags to new node, and set new node to sequence abstract.
// Actually, here we needn't require unique sequence nodes pointer between abstract any more.
if (flags != nullptr) {
SetSequenceNodeElementsUseFlags(new_node, flags);
old_sequence_abs->InsertSequenceNode(new_node);
}
return;
}
for (auto &weak_node : *old_sequence_abs->sequence_nodes()) {
auto sequence_node = weak_node.lock();
if (sequence_node == nullptr) {
@ -482,6 +504,8 @@ void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
if (old_abs == new_abs) {
continue;
}
MS_LOG(ERROR) << "New abstract, " << old_node->DebugString() << " --> " << new_node->DebugString()
<< ", elements_use_flags: " << (*flags);
AbstractSequencePtr new_sequence_abs = dyn_cast<AbstractSequence>(new_abs);
if (new_sequence_abs == nullptr) {
MS_LOG(EXCEPTION) << "New node should be sequence type as well, but got " << new_abs->ToString();
@ -531,7 +555,7 @@ void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecial
// Always reset tuple value node's use flags as non-use.
SetSequenceNodeElementsUseFlags(new_input, flags);
MS_LOG(DEBUG) << "Update ValueTuple/ValueList, " << old_input->DebugString() << " --> " << new_input->DebugString()
<< ", which is inputs[" << index << "] of " << cnode->DebugString();
<< ", which is inputs[" << index << "] of " << cnode->DebugString() << ", flags: " << (*flags);
// Keep the node not to release before we purify its abstract.
specializer->sequence_abstract_list().emplace_back(std::pair(new_sequence_abs, old_input));
cnode->set_input(index, new_input);
@ -575,11 +599,6 @@ void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) {
for (size_t i = 0; i < (*flags).size(); ++i) {
auto old_input = cnode->input(i + 1);
if (!(*flags)[i]) {
// Keep the node not to release before we purify its abstract.
if (IsPrimitiveCNode(old_input, prim::kPrimMakeTuple) || IsPrimitiveCNode(old_input, prim::kPrimMakeList)) {
specializer_->sequence_nodes_replaced_list().emplace_back(old_input);
}
auto zero_value = NewValueNode(MakeValue(0));
zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0)));
inputs.emplace_back(zero_value);
@ -906,9 +925,9 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &c
const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
MS_EXCEPTION_IF_NULL(eval);
auto cache_iter = evalcaches_.find(eval);
if (cache_iter == evalcaches_.end()) {
evalcaches_[eval] = eval->evaluator_cache_mgr();
auto cache_iter = eval_cache_.find(eval);
if (cache_iter == eval_cache_.end()) {
eval_cache_[eval] = eval->evaluator_cache_mgr();
return eval->evaluator_cache_mgr();
}
return cache_iter->second;
@ -918,11 +937,11 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
const EvaluatorPtr &eval) {
MS_EXCEPTION_IF_NULL(eval);
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
EvalResultPtr ret = nullptr;
EvalResultPtr res = nullptr;
AbstractBasePtrList broaded_argvals;
std::vector<AbstractBasePtrList> args_vector;
auto eval_cache_iter = evalcaches_.find(eval);
if (eval_cache_iter == evalcaches_.end()) {
auto eval_cache_iter = eval_cache_.find(eval);
if (eval_cache_iter == eval_cache_.end()) {
MS_LOG(EXCEPTION) << "Evaluator:" << eval->ToString() << " not exist in cache.";
}
auto &origin_eval_cache = eval_cache_iter->second->GetCache();
@ -944,13 +963,13 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
joined_argvals = abstract::AbstractJoin(joined_argvals, args_vector[i]);
}
MS_LOG(DEBUG) << "Joined argvals: " << joined_argvals.size() << ", " << ::mindspore::ToString(joined_argvals);
EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
const auto joined_eval_result = origin_eval_cache.get(joined_argvals);
if (joined_eval_result != nullptr) {
MS_LOG(DEBUG) << "Find unique Choices in original eval cache, so use it: " << joined_eval_result->ToString();
real->SetValue(joined_argvals, joined_eval_result);
evalcaches_[eval] = real;
eval_cache_[eval] = real;
return std::make_pair(joined_argvals, joined_eval_result->abstract());
} else {
ConfigPtrList args_conf_list;
@ -958,11 +977,11 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
[](const AbstractBasePtr &v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
MS_LOG(WARNING) << "Cannot find joined argvals in cache, run with broaded argsvals: " << broaded_argvals.size()
<< ", " << ::mindspore::ToString(broaded_argvals);
ret = eval->SingleRun(engine_, args_conf_list, nullptr);
MS_EXCEPTION_IF_NULL(ret);
real->SetValue(broaded_argvals, ret);
evalcaches_[eval] = real;
return std::make_pair(broaded_argvals, ret->abstract());
res = eval->SingleRun(engine_, args_conf_list, nullptr);
MS_EXCEPTION_IF_NULL(res);
real->SetValue(broaded_argvals, res);
eval_cache_[eval] = real;
return std::make_pair(broaded_argvals, res->abstract());
}
}
MS_LOG(DEBUG) << "Choices.size: " << choices.size();
@ -1132,6 +1151,19 @@ SpecializeStatusCode FuncGraphSpecializer::AcquireUniqueEvalVal(const AbstractFu
*res = BuildFromBroadedArgsVal(eval);
if (!res->first.empty()) {
MS_LOG(DEBUG) << "Build for generalized argvals successfully.";
// Synchronize the new evaluated abstract with the abstract from common evaluating routine.
MS_EXCEPTION_IF_NULL(res->second);
auto new_sequence_abs = dyn_cast<abstract::AbstractSequence>(res->second);
if (new_sequence_abs != nullptr) {
// Just synchronize with the first one.
auto &first_choice = choices.begin()->second;
MS_EXCEPTION_IF_NULL(first_choice);
MS_EXCEPTION_IF_NULL(first_choice->abstract());
auto old_sequence_abs = dyn_cast<abstract::AbstractSequence>(first_choice->abstract());
if (old_sequence_abs != nullptr) {
SynchronizeSequenceElementsUseFlagsRecursively(old_sequence_abs, new_sequence_abs);
}
}
return kSpecializeSuccess;
}
MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism, "

View File

@ -70,8 +70,6 @@ class ProgramSpecializer {
std::vector<std::pair<AbstractSequencePtr, AnfNodePtr>> &sequence_abstract_list() { return sequence_abstract_list_; }
std::vector<AnfNodePtr> &sequence_nodes_replaced_list() { return sequence_nodes_replaced_list_; }
private:
std::shared_ptr<AnalysisEngine> engine_;
mindspore::HashSet<AnfNodePtr> seen_;
@ -81,8 +79,6 @@ class ProgramSpecializer {
AnalysisContextPtr top_context_;
// The list to purify tuple/list elements.
std::vector<std::pair<AbstractSequencePtr, AnfNodePtr>> sequence_abstract_list_;
// The list to hold the weak node ptr before purify abstract.
std::vector<AnfNodePtr> sequence_nodes_replaced_list_;
// Map for unspecialized abstract function to specialized abstract;
std::unordered_map<AbstractFunctionPtr, AbstractBasePtr, AbstractFunctionHasher, AbstractFunctionEqual>
specialized_abs_map_;
@ -109,7 +105,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
ClonerPtr cloner_;
std::vector<AnfNodePtr> todo_;
mindspore::HashSet<AnfNodePtr> marked_;
mindspore::HashMap<EvaluatorPtr, EvaluatorCacheMgrPtr> evalcaches_;
mindspore::HashMap<EvaluatorPtr, EvaluatorCacheMgrPtr> eval_cache_;
void FirstPass();
void SecondPass();

View File

@ -563,11 +563,11 @@ bool AbstractSequence::PurifyElements() {
}
if (elements_use_flags_ptr == nullptr) {
if (not_free_node == nullptr) {
MS_LOG(INFO) << "Check if all sequence nodes are released, or none elements use flags in them. nodes size: "
<< sequence_nodes_->size();
MS_LOG(ERROR) << "Check if all sequence nodes are released, or none elements use flags in them. nodes size: "
<< sequence_nodes_->size();
} else {
MS_LOG(INFO) << "Check if none elements use flags in sequence ndoes. one of node: "
<< not_free_node->DebugString();
MS_LOG(ERROR) << "Check if none elements use flags in sequence ndoes. one of node: "
<< not_free_node->DebugString();
}
return false;
}