Synchronize the inputs abstract sequence node info. before save.

This commit is contained in:
Zhang Qinghua 2022-01-13 21:57:52 +08:00
parent a5c07794f6
commit 4642b96624
6 changed files with 47 additions and 12 deletions

View File

@ -413,6 +413,7 @@ class AnalysisResultCacheMgr {
return instance;
}
void Clear();
const AnalysisConfigResultCache &GetCache() { return cache_; }
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
void InitSwitchValue(const AnfNodeConfigPtr &conf);

View File

@ -240,8 +240,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, context, fg);
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
MS_LOG(DEBUG) << GetInferThread() << "Set parameter[" << i << "] for " << fg->ToString()
<< ", conf: " << conf->ToString() << ", arg: " << arg->ToString();
MS_LOG(DEBUG) << GetInferThread() << ", Save argument[" << i << "] result for " << fg->ToString()
<< ", NodeConfig: " << conf->ToString() << ", result: " << arg << "/" << arg->ToString();
}
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
<< ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString()
@ -403,10 +403,12 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
auto new_sequence = dyn_cast<AbstractTuple>(args_spec_list[i]);
auto old_sequence = dyn_cast<AbstractTuple>(iter->first[i]);
if (old_sequence != nullptr && new_sequence != nullptr) {
MS_LOG(DEBUG) << "Before synchronize sequence nodes use flags, old_sequence: " << old_sequence->ToString()
MS_LOG(DEBUG) << "Before synchronize sequence nodes use flags for NodeConfig: " << out_conf->ToString()
<< ", old_sequence: " << old_sequence->ToString()
<< ", new_sequence: " << new_sequence->ToString();
SynchronizeSequenceNodesElementsUseFlags(old_sequence->sequence_nodes(), new_sequence->sequence_nodes());
MS_LOG(DEBUG) << "After synchronize sequence nodes use flags, old_sequence: " << old_sequence->ToString()
MS_LOG(DEBUG) << "After synchronize sequence nodes use flags for NodeConfig: " << out_conf->ToString()
<< ", old_sequence: " << old_sequence->ToString()
<< ", new_sequence: " << new_sequence->ToString();
}
}

View File

@ -616,7 +616,9 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
if (replace_node == nullptr) {
replace_node = BuildReplacedNode(input_conf);
replace_node->set_abstract(abs);
MS_LOG(DEBUG) << "Set replaced: " << replace_node->DebugString() << ", to abstract: " << abs->ToString();
MS_LOG(DEBUG) << "Set replaced input[" << i << "]: " << replace_node->DebugString()
<< ", NodeConfig: " << input_conf->ToString() << ", result: " << abs.get() << "/"
<< abs->ToString();
} else {
MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
<< ", abs: " << abs->ToString() << ", replace_node: " << replace_node->DebugString();

View File

@ -116,7 +116,10 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
const auto &arg_abs = args_abs_list[i];
const auto &node = fg->parameters()[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, new_context, new_context->func_graph());
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg_abs, nullptr));
auto result = std::make_shared<EvalResult>(arg_abs, nullptr);
MS_LOG(DEBUG) << "Save argument[" << i << "] result, NodeConfig: " << conf->ToString()
<< ", result: " << result->abstract().get() << "/" << result->abstract()->ToString();
engine->SaveEvalResultInCache(conf, result);
}
// Create a new stack frame and set arguments for it.
@ -181,9 +184,11 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last
// Continue saving node's result for parent func graph.
auto &current_node = NextNode();
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
<< ", current_context_: " << current_context_->ToString();
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_, current_context_->func_graph());
MS_LOG(DEBUG) << "current_node: " << current_node->DebugString()
<< ", current_context_: " << current_context_->ToString()
<< ", Save result, NodeConfig: " << node_conf->ToString() << ", result: " << result->abstract().get()
<< "/" << result->abstract()->ToString();
engine->SaveEvalResultInCache(node_conf, result);
// Leave the call CNode.

View File

@ -172,6 +172,29 @@ void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const E
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(result);
static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
auto iter = cache_mgr.GetCache().find(conf);
if (iter != cache_mgr.GetCache().end()) {
MS_LOG(DEBUG) << "Found previous result for NodeConfig: " << conf->ToString()
<< ", result: " << iter->second->abstract().get() << "/" << iter->second->abstract()->ToString();
// Update sequence nodes info, if matched in cache.
static const auto eliminate_unused_element = common::GetEnv("MS_DEV_ELIMINATE_SEQUENCE_UNUSED_ELEMENT");
static const auto enable_eliminate_unused_element = (eliminate_unused_element == "1");
if (enable_eliminate_unused_element) {
auto new_sequence = dyn_cast<AbstractTuple>(result->abstract());
auto old_sequence = dyn_cast<AbstractTuple>(iter->second->abstract());
if (old_sequence != nullptr && new_sequence != nullptr) {
MS_LOG(DEBUG) << "Before synchronize sequence nodes use flags for NodeConfig: " << conf->ToString()
<< ", old_sequence: " << old_sequence->ToString()
<< ", new_sequence: " << new_sequence->ToString();
SynchronizeSequenceNodesElementsUseFlags(old_sequence->sequence_nodes(), new_sequence->sequence_nodes());
MS_LOG(DEBUG) << "After synchronize sequence nodes use flags for NodeConfig: " << conf->ToString()
<< ", old_sequence: " << old_sequence->ToString()
<< ", new_sequence: " << new_sequence->ToString();
}
}
}
MS_LOG(DEBUG) << "Save result for NodeConfig: " << conf->ToString() << ", result: " << result->abstract().get() << "/"
<< result->abstract()->ToString();
cache_mgr.SetValue(conf, result);
// Set intermediate abstract value.
@ -195,6 +218,8 @@ EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &
static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
auto result = cache_mgr.GetValue(conf);
if (result != nullptr) {
MS_LOG(DEBUG) << "Evaluate cache found for NodeConfig: " << conf->ToString()
<< ", result: " << result->abstract().get() << "/" << result->abstract()->ToString();
return result;
}
MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
@ -202,8 +227,8 @@ EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &
if (result == nullptr) {
MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr";
}
MS_LOG(DEBUG) << "Evaluate node on demond for NodeConfig: " << conf->ToString()
<< ", result: " << result->abstract().get() << ", " << result->abstract()->ToString();
MS_LOG(DEBUG) << "Evaluate node on demand for NodeConfig: " << conf->ToString()
<< ", result: " << result->abstract().get() << "/" << result->abstract()->ToString();
SaveEvalResultInCache(conf, result);
return result;
}

View File

@ -332,8 +332,8 @@ void SynchronizeSequenceNodesElementsUseFlagsInner(const AnfNodeWeakPtrList &seq
unique_flags = current_flags;
} else {
if (current_count > 1 && latter_count > 1) {
MS_LOG(INFO) << "Allow only one side has more than one use count. count: " << current_count << ", "
<< latter_count;
MS_LOG(DEBUG) << "Allow only one side has more than one use count. count: " << current_count << ", "
<< latter_count;
}
if (current_count > latter_count) {
unique_flags = current_flags;