forked from mindspore-Ecosystem/mindspore
!20850 fix issues of blue team
Merge pull request !20850 from lanzhineng/infer_optv3
This commit is contained in:
commit
b02e5c86e0
|
@ -67,7 +67,7 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
}
|
||||
|
@ -240,6 +240,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
|
|||
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
}
|
||||
|
@ -254,9 +255,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
|
|||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
|
||||
|
||||
ScopePtr scope = kDefaultScope;
|
||||
if (out_conf != nullptr) {
|
||||
scope = out_conf->node()->scope();
|
||||
}
|
||||
scope = out_conf->node()->scope();
|
||||
ScopeGuard scope_guard(scope);
|
||||
|
||||
FuncGraphPtr func_graph = out_conf->node()->func_graph();
|
||||
|
|
|
@ -455,11 +455,11 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
} // namespace
|
||||
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbstractClosure> &func) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
auto inf_pair = evaluators_.find(func);
|
||||
if (inf_pair != evaluators_.end()) {
|
||||
return inf_pair->second;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
auto primitive = func->prim();
|
||||
auto evaluator = GetPrimEvaluator(primitive, shared_from_this());
|
||||
evaluators_[func] = evaluator;
|
||||
|
@ -467,11 +467,11 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PrimitiveAbs
|
|||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbstractClosure> &func) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
auto inf_pair = evaluators_.find(func);
|
||||
if (inf_pair != evaluators_.end()) {
|
||||
return inf_pair->second;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
std::shared_ptr<FuncGraphEvaluator> func_graph_evaluator =
|
||||
std::make_shared<FuncGraphEvaluator>(func->func_graph(), func->context());
|
||||
evaluators_[func] = func_graph_evaluator;
|
||||
|
@ -479,11 +479,12 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<FuncGraphAbs
|
|||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<MetaFuncGraphAbstractClosure> &func) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
auto inf_pair = evaluators_.find(func);
|
||||
if (inf_pair != evaluators_.end()) {
|
||||
return inf_pair->second;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
|
||||
std::shared_ptr<MetaFuncGraphEvaluator> evaluator =
|
||||
std::make_shared<MetaFuncGraphEvaluator>(func->meta_func_graph(), func->GetScope());
|
||||
evaluators_[func] = evaluator;
|
||||
|
@ -558,11 +559,7 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
|||
if (func->tracking_id() != nullptr) {
|
||||
MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
|
||||
// protect the constructors
|
||||
static std::recursive_mutex constructors_mutex;
|
||||
// std::lock_guard<std::recursive_mutex> lock(constructors_mutex);
|
||||
if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
|
||||
func->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
EvaluatorPtr evaluator = _GetEvaluatorFor(func);
|
||||
|
@ -936,6 +933,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||
auto possible_parent_fg = out_conf->node()->func_graph();
|
||||
for (auto eval : evaluators) {
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
(void)SetUndeterminedFlag(eval, possible_parent_fg);
|
||||
|
||||
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
|
||||
|
@ -944,7 +942,6 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
|
||||
if (it == eval_trace_.rend()) {
|
||||
eval_trace_.push_back(current_inf);
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
|
||||
auto eval_abstract = eval_result->abstract();
|
||||
MS_EXCEPTION_IF_NULL(eval_abstract);
|
||||
|
|
Loading…
Reference in New Issue