!219 Remove the repeats of inferring and optimize the sorting routine.

Merge pull request !219 from ZhangQinghua/master
This commit is contained in:
mindspore-ci-bot 2020-04-11 20:38:20 +08:00 committed by Gitee
commit 09a49e0d69
1 changed files with 33 additions and 9 deletions

View File

@ -17,6 +17,7 @@
#include "pipeline/static_analysis/evaluator.h"
#include <algorithm>
#include <unordered_set>
#include "ir/func_graph_cloner.h"
#include "pipeline/static_analysis/utils.h"
@ -61,6 +62,29 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &
return context;
}
static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
std::vector<AnfNodePtr> sorted_nodes;
std::unordered_set<AnfNodePtr> checked_cnodes;
std::size_t index = 0;
sorted_nodes.emplace_back(ret_node);
while (index < sorted_nodes.size()) {
auto current = sorted_nodes[index];
index++;
MS_EXCEPTION_IF_NULL(current);
if (current->isa<CNode>()) {
auto &inputs = current->cast<CNodePtr>()->inputs();
for (auto it = inputs.begin(); it != inputs.end(); it++) {
AnfNodePtr input = *it;
if (input != nullptr && input->isa<CNode>() && checked_cnodes.find(input) == checked_cnodes.end()) {
sorted_nodes.emplace_back(input);
(void)checked_cnodes.insert(input);
}
}
}
}
return sorted_nodes;
}
AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
MS_EXCEPTION_IF_NULL(fg);
@ -86,20 +110,20 @@ AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const Ab
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString()
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
const std::vector<AnfNodePtr> &all_nodes = TopoSort(func_node);
for (const auto &node : all_nodes) {
AbstractBasePtr ret_base = nullptr;
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
const auto &node = *it;
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
AbstractBasePtr base = engine->GetEvaluatedValue(node_conf);
ret_base = engine->GetEvaluatedValue(node_conf);
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
<< ", abstract: " << base->ToString();
<< ", abstract: " << ret_base->ToString();
}
AnfNodeConfigPtr ret_conf = engine->MakeConfig(func_node, graph_context_);
AbstractBasePtr base = engine->GetEvaluatedValue(ret_conf);
MS_EXCEPTION_IF_NULL(base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << base->ToString();
return base;
MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << ret_base->ToString();
return ret_base;
}
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {