!20850 fix issues of blue team

Merge pull request !20850 from lanzhineng/infer_optv3
This commit is contained in:
i-robot 2021-07-27 01:27:08 +00:00 committed by Gitee
commit b02e5c86e0
2 changed files with 8 additions and 12 deletions

View File

@ -67,7 +67,7 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
}
}
}
MS_EXCEPTION_IF_NULL(out_conf);
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
@ -240,6 +240,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
AbstractBasePtrList args_spec_list;
MS_EXCEPTION_IF_NULL(out_conf);
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
@ -254,9 +255,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) {
scope = out_conf->node()->scope();
}
scope = out_conf->node()->scope();
ScopeGuard scope_guard(scope);
FuncGraphPtr func_graph = out_conf->node()->func_graph();

View File

@ -455,11 +455,11 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
} // namespace
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
MS_EXCEPTION_IF_NULL(func);
auto inf_pair = evaluators_.find(func);
if (inf_pair != evaluators_.end()) {
return inf_pair->second;
}
MS_EXCEPTION_IF_NULL(func);
auto primitive = func->prim();
auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
evaluators_[func] = evaluator;
@ -467,11 +467,11 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbs
}
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
MS_EXCEPTION_IF_NULL(func);
auto inf_pair = evaluators_.find(func);
if (inf_pair != evaluators_.end()) {
return inf_pair->second;
}
MS_EXCEPTION_IF_NULL(func);
std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator =
std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
evaluators_[func] = func_graph_evaluator;
@ -479,11 +479,12 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbs
}
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
MS_EXCEPTION_IF_NULL(func);
auto inf_pair = evaluators_.find(func);
if (inf_pair != evaluators_.end()) {
return inf_pair->second;
}
MS_EXCEPTION_IF_NULL(func);
std::shared_ptr<MetaFuncGraphEvaluator> evaluator =
std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->GetScope());
evaluators_[func] = evaluator;
@ -558,11 +559,7 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
if (func->tracking_id() != nullptr) {
MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
}
MS_EXCEPTION_IF_NULL(func);
// protect the constructors
static std::recursive_mutex constructors_mutex;
// std::lock_guard<std::recursive_mutex> lock(constructors_mutex);
if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
func->isa<abstract::FuncGraphAbstractClosure>()) {
EvaluatorPtr evaluator = _GetEvaluatorFor(func);
@ -936,6 +933,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
MS_EXCEPTION_IF_NULL(out_conf->node());
auto possible_parent_fg = out_conf->node()->func_graph();
for (auto eval : evaluators) {
MS_EXCEPTION_IF_NULL(eval);
(void)SetUndeterminedFlag(eval, possible_parent_fg);
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
@ -944,7 +942,6 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
if (it == eval_trace_.rend()) {
eval_trace_.push_back(current_inf);
MS_EXCEPTION_IF_NULL(eval);
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
auto eval_abstract = eval_result->abstract();
MS_EXCEPTION_IF_NULL(eval_abstract);