forked from mindspore-Ecosystem/mindspore
Remove the repeats of inferring and optimize the sorting routine.
Total Renormalizes: ----- 69.05010 --> 62.28941 -----
This commit is contained in:
parent
268d358a1d
commit
87714b3c7f
|
@ -17,6 +17,7 @@
|
|||
#include "pipeline/static_analysis/evaluator.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "pipeline/static_analysis/utils.h"
|
||||
|
@ -61,6 +62,29 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &
|
|||
return context;
|
||||
}
|
||||
|
||||
static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
|
||||
std::vector<AnfNodePtr> sorted_nodes;
|
||||
std::unordered_set<AnfNodePtr> checked_cnodes;
|
||||
std::size_t index = 0;
|
||||
sorted_nodes.emplace_back(ret_node);
|
||||
while (index < sorted_nodes.size()) {
|
||||
auto current = sorted_nodes[index];
|
||||
index++;
|
||||
MS_EXCEPTION_IF_NULL(current);
|
||||
if (current->isa<CNode>()) {
|
||||
auto &inputs = current->cast<CNodePtr>()->inputs();
|
||||
for (auto it = inputs.begin(); it != inputs.end(); it++) {
|
||||
AnfNodePtr input = *it;
|
||||
if (input != nullptr && input->isa<CNode>() && checked_cnodes.find(input) == checked_cnodes.end()) {
|
||||
sorted_nodes.emplace_back(input);
|
||||
(void)checked_cnodes.insert(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return sorted_nodes;
|
||||
}
|
||||
|
||||
AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
|
||||
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
|
@ -86,20 +110,20 @@ AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const Ab
|
|||
|
||||
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString()
|
||||
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
|
||||
const std::vector<AnfNodePtr> &all_nodes = TopoSort(func_node);
|
||||
for (const auto &node : all_nodes) {
|
||||
AbstractBasePtr ret_base = nullptr;
|
||||
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
|
||||
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
|
||||
const auto &node = *it;
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
|
||||
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
|
||||
AbstractBasePtr base = engine->GetEvaluatedValue(node_conf);
|
||||
ret_base = engine->GetEvaluatedValue(node_conf);
|
||||
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
|
||||
<< ", abstract: " << base->ToString();
|
||||
<< ", abstract: " << ret_base->ToString();
|
||||
}
|
||||
|
||||
AnfNodeConfigPtr ret_conf = engine->MakeConfig(func_node, graph_context_);
|
||||
AbstractBasePtr base = engine->GetEvaluatedValue(ret_conf);
|
||||
MS_EXCEPTION_IF_NULL(base);
|
||||
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << base->ToString();
|
||||
return base;
|
||||
MS_EXCEPTION_IF_NULL(ret_base);
|
||||
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << ret_base->ToString();
|
||||
return ret_base;
|
||||
}
|
||||
|
||||
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
|
||||
|
|
Loading…
Reference in New Issue