diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc index d41df36e657..05e61a1df68 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc @@ -15,26 +15,39 @@ */ #include "pipeline/jit/static_analysis/async_eval_result.h" -#include +#include #include "utils/symbolic.h" #include "debug/common.h" #include "pipeline/jit/base.h" #include "utils/utils.h" -#include "abstract/utils.h" namespace mindspore { namespace abstract { HealthPointMgr HealthPointMgr::instance_; -void HealthPointMgr::Clear() { point_ = 1; } +void HealthPointMgr::Clear() { + MS_LOG(DEBUG) << " Point = " << point_; + point_ = 1; +} void HealthPointMgr::HandleException() { + // Just record the first exception information. + if (!StaticAnalysisException::Instance().HasException()) { + std::ostringstream oss; + trace::GetEvalStackInfo(oss); + if (!oss.str().empty()) { + MS_LOG(ERROR) << oss.str(); + } + StaticAnalysisException::Instance().SetException(); + } + // Free all the locks. Let all the threads continue to run. std::lock_guard lock(lock_); for (auto &item : asyncAbstractList_) { item->SetRunable(); } asyncAbstractList_.clear(); } + void HealthPointMgr::SetNextRunable() { std::lock_guard lock(lock_); if (asyncAbstractList_.empty()) { @@ -46,9 +59,9 @@ void HealthPointMgr::SetNextRunable() { [](const auto &item) { return item->HasResult(); }); if (it == asyncAbstractList_.end()) { // Enter endless loop if there is not ready result. - MS_LOG(EXCEPTION) << "Enter endless loop. Please check the code. point = " << HealthPointMgr::GetInstance().point() - << " Called times : " << asyncAbstractList_.front()->count(); + MS_LOG(EXCEPTION) << "Enter endless loop. There are no more nodes that can be evaluated. Please check the code."; } + // Push back the not ready async. 
asyncAbstractList_.insert(asyncAbstractList_.end(), asyncAbstractList_.begin(), it); asyncAbstractList_.erase(asyncAbstractList_.begin(), it); @@ -65,8 +78,10 @@ void AnalysisResultCacheMgr::Clear() { cache_.clear(); switch_cache_.clear(); todo_.clear(); + waiting_.clear(); } +// The thread id format is XXXX.YYYY.ZZZZ thread_local static std::string local_threadid; void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) { std::ostringstream buffer; @@ -76,7 +91,7 @@ void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) { std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; } -void AnalysisResultCacheMgr::PushTowait(std::future &&future) { +void AnalysisResultCacheMgr::PushToWait(std::future &&future) { std::lock_guard lock(lock_); waiting_.emplace_back(std::move(future)); } @@ -94,6 +109,7 @@ void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) { switch_cache_.set(conf, async_eval_result); } } + AbstractBasePtr AnalysisResultCacheMgr::TryGetSwitchValue(const AnfNodeConfigPtr &conf) { // don't call lock_.lock(). switch_cache is protected. and it waits for result. 
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf); @@ -125,7 +141,7 @@ AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &c return nullptr; } -void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr arg) { +void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) { MS_EXCEPTION_IF_NULL(conf); if (arg == nullptr) { MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr"; @@ -159,6 +175,14 @@ void AnalysisResultCacheMgr::Todo() { while (!todo_.empty()) { AnfNodeConfigPtr conf = todo_.front(); todo_.pop_front(); + if (GetValue(conf) == nullptr) { + MS_LOG(WARNING) << conf->node()->ToString() << " not in globalCache"; + continue; + } + if (TryGetSwitchValue(conf) == nullptr) { + MS_LOG(WARNING) << conf->node()->ToString() << " not in switchCache"; + continue; + } if (!(*GetValue(conf)->abstract() == *TryGetSwitchValue(conf))) { MS_LOG(WARNING) << " Switch Value is not eq. " << " switchCache: " << TryGetSwitchValue(conf)->ToString() @@ -172,7 +196,6 @@ void AnalysisResultCacheMgr::Wait() { // Check all the async to finish. 
HealthPointScopedDrop hpCheck; while (true) { - StaticAnalysisException::Instance().CheckException(); lock_.lock(); if (waiting_.empty()) { lock_.unlock(); @@ -188,6 +211,7 @@ void AnalysisResultCacheMgr::Wait() { if (IS_OUTPUT_ON(DEBUG)) { Todo(); } + MS_LOG(INFO) << "Infer finished."; } std::string ArgsToString(const AbstractBasePtrList &args_spec_list) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h index c24d45a22c2..6889e9edaba 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h @@ -51,7 +51,7 @@ class HealthPointMgr { if (point_ == 0) { SetNextRunable(); } else if (point_ < 0) { - MS_LOG(EXCEPTION) << "There is something wrong."; + MS_LOG(WARNING) << "There is something wrong. point = " << point_; } } @@ -66,7 +66,7 @@ class HealthPointMgr { ++point_; } - int point() { return point_; } + int point() const { return point_; } void Add2Schedule(const AsyncAbstractPtr &asyncAbastract) { std::lock_guard lock(lock_); @@ -187,48 +187,47 @@ class AsyncAbstract : public std::enable_shared_from_this { ~AsyncAbstract() = default; // Wait AbstractBasePtr GetResult() { - StaticAnalysisException::Instance().CheckException(); + static HealthPointMgr &healthPointMgr = HealthPointMgr::GetInstance(); + static StaticAnalysisException &exceptionMgr = StaticAnalysisException::Instance(); + exceptionMgr.CheckException(); std::unique_lock lock(lock_); while (true) { ++count_; // The point should be dropped if it can't run. It will be added when it can run. bool hasDropPoint = false; if (!runable_) { - HealthPointMgr::GetInstance().DropPoint(); + healthPointMgr.DropPoint(); hasDropPoint = true; } - MS_LOG(DEBUG) << this << " runable: " << runable_ << " result: " << (result_ ? 
result_.get() : 0); condition_var_.wait(lock, [this] { return runable_; }); + if (hasDropPoint) { + healthPointMgr.AddPoint(); + } MS_LOG(DEBUG) << this << " continue runable: " << runable_ << " result: " << (result_ ? result_.get() : 0); - StaticAnalysisException::Instance().CheckException(); + + exceptionMgr.CheckException(); runable_ = false; if (result_ != nullptr) { - if (hasDropPoint) { - HealthPointMgr::GetInstance().AddPoint(); - } MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0); return result_; } // Push to list - HealthPointMgr::GetInstance().Add2Schedule(shared_from_this()); - if (hasDropPoint) { - HealthPointMgr::GetInstance().AddPoint(); - } + healthPointMgr.Add2Schedule(shared_from_this()); // Notify the next asyncAbastract to run. - HealthPointMgr::GetInstance().SetNextRunable(); + healthPointMgr.SetNextRunable(); MS_LOG(DEBUG) << this << " SetNextRunable " << " runable: " << runable_ << " result: " << (result_ ? result_.get() : 0) - << " point:" << HealthPointMgr::GetInstance().point(); + << " point:" << healthPointMgr.point(); } - return nullptr; } + void SetRunable() { MS_LOG(DEBUG) << this << " Runable."; runable_ = true; condition_var_.notify_one(); } - int count() { return count_; } + int count() const { return count_; } bool HasResult() { return result_ != nullptr; } // Not wait @@ -287,7 +286,7 @@ class AnalysisResultCacheMgr { inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); } // Wait for async Eval(conf) to finish. 
void Wait(); - void PushTowait(std::future &&future); + void PushToWait(std::future &&future); void PushTodo(const AnfNodeConfigPtr &conf); void Todo(); static void UpdateCaller(const std::string &caller); @@ -295,7 +294,7 @@ class AnalysisResultCacheMgr { void InitSwitchValue(const AnfNodeConfigPtr &conf); AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf); AbstractBasePtr TryGetSwitchValue(const AnfNodeConfigPtr &conf); - void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr vale); + void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &value); private: using AnalysisConfigAsyncResultMap = diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index da0c80b0cbd..fdb1e619160 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -509,9 +509,15 @@ EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrLis } EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &out_conf) { - auto result = this->Run(engine, args_conf_list, out_conf); - + EvalResultPtr result; + try { + result = this->Run(engine, args_conf_list, out_conf); + } catch (const std::exception &e) { + MS_LOG(WARNING) << "Eval " << ToString() << " threw exception."; + HealthPointMgr::GetInstance().HandleException(); + } AnalysisResultCacheMgr::GetInstance().Wait(); + StaticAnalysisException::Instance().CheckException(); return result; } } // namespace abstract diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 6b591c71080..7466cf55c7f 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -120,32 +120,37 @@ bool 
AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeCon AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) { StaticAnalysisException::Instance().ClearException(); HealthPointMgr::GetInstance().Clear(); - ConfigPtrList args_conf_list; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), - [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - MS_EXCEPTION_IF_NULL(func_graph_manager_); - func_graph_manager_->AddFuncGraph(func_graph); - - root_func_graph_ = func_graph; - - AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); - - // Running the analyzer. - ResetFunctionCallDepth(); - ResetStackFrameDepth(); - AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); - MS_EXCEPTION_IF_NULL(root_context); - MS_EXCEPTION_IF_NULL(root_context->func_graph()); - AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context); - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(INFO) << func_graph->ToString() << ": Run finished."; - AnalysisResult result; - MS_EXCEPTION_IF_NULL(output_conf); - result.inferred = output_conf->ObtainEvalResult(); - result.context = root_context; + try { + ConfigPtrList args_conf_list; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), + [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); + MS_EXCEPTION_IF_NULL(func_graph_manager_); + func_graph_manager_->AddFuncGraph(func_graph); + root_func_graph_ = func_graph; + + AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); + + // Running the analyzer. 
+ ResetFunctionCallDepth(); + ResetStackFrameDepth(); + AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); + MS_EXCEPTION_IF_NULL(root_context); + MS_EXCEPTION_IF_NULL(root_context->func_graph()); + AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context); + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(INFO) << func_graph->ToString() << ": Run finished."; + + MS_EXCEPTION_IF_NULL(output_conf); + result.inferred = output_conf->ObtainEvalResult(); + result.context = root_context; + } catch (const std::exception &e) { + MS_LOG(WARNING) << "Eval " << func_graph->ToString() << " threw exception."; + HealthPointMgr::GetInstance().HandleException(); + } AnalysisResultCacheMgr::GetInstance().Wait(); + StaticAnalysisException::Instance().CheckException(); return result; } @@ -374,6 +379,8 @@ void AnalysisEngine::ClearEvaluatorCache() { MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr()); evaluator->evaluator_cache_mgr()->Clear(); } + // Release Exception to avoid hang at exit. + StaticAnalysisException::Instance().ClearException(); } void AnalysisEngine::Clear() { @@ -789,40 +796,38 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar AsyncAbstractPtr async_run_flag) { AnalysisResultCacheMgr::UpdateCaller(caller); try { + trace::ClearTraceStack(); + // Wait for Signal to run MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " waiting."; (void)async_run_flag->GetResult(); MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " running."; - // Acquire GIL - py::gil_scoped_acquire pyGuard; - trace::ClearTraceStack(); - auto result = eval->Run(engine, args_conf_list, out_conf); + // Acquire GIL for eval to callback python. 
+ EvalResultPtr result; + { + py::gil_scoped_acquire pyGuard; + result = eval->Run(engine, args_conf_list, out_conf); + } MS_EXCEPTION_IF_NULL(result); MS_EXCEPTION_IF_NULL(result->abstract()); // Broaden the result of switch(c,t,f)() auto broadAbstract = result->abstract()->Broaden(); - // Let main thread to continue. + // Notify the thread waiting for the switch node and the main thread to continue. AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadAbstract); async_result_branch->JoinResult(broadAbstract); async_result_main->JoinResult(broadAbstract); + // Health Point will be dropped when thread exits. + HealthPointMgr::GetInstance().DropPoint(); MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString() << " asyncResult address = " << async_result_branch.get() << " value = " << async_result_branch->TryGetResult()->ToString(); - // Decrease infer thread. - HealthPointMgr::GetInstance().DropPoint(); } catch (const std::exception &e) { - std::ostringstream oss; - oss << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception."; - trace::GetEvalStackInfo(oss); - if (!oss.str().empty()) { - MS_LOG(ERROR) << oss.str(); - } - auto abstractErrPtr = std::make_shared(std::make_shared(oss.str()), out_conf->node()); + MS_LOG(WARNING) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception."; + auto abstractErrPtr = std::make_shared(std::make_shared("Exception"), out_conf->node()); AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr); async_result_main->JoinResult(abstractErrPtr); - StaticAnalysisException::Instance().SetException(); HealthPointMgr::GetInstance().HandleException(); } } @@ -830,61 +835,54 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, const ConfigPtrList 
&args_conf_list) { - // Release GIL; + MS_EXCEPTION_IF_NULL(out_conf); + MS_EXCEPTION_IF_NULL(out_conf->node()); + static HealthPointMgr &healthPointMgr = HealthPointMgr::GetInstance(); + static AnalysisResultCacheMgr &resultCacheMgr = AnalysisResultCacheMgr::GetInstance(); + + // Release GIL for C++ py::gil_scoped_release infer_gil_release; // Wait for the last switch node to finish. MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString(); - auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf); + auto eval_result = resultCacheMgr.GetSwitchValue(out_conf); if (eval_result == nullptr) { MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString(); - AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf); + resultCacheMgr.InitSwitchValue(out_conf); } else { - if (eval_result->isa() || eval_result->isa()) { - MS_LOG(ERROR) << "Eval " << out_conf->node()->ToString() << " threw exception."; - StaticAnalysisException::Instance().CheckException(); - } return std::make_shared(eval_result, nullptr); } - MS_EXCEPTION_IF_NULL(out_conf); - MS_EXCEPTION_IF_NULL(out_conf->node()); auto possible_parent_fg = out_conf->node()->func_graph(); - // Eval result of the branches and main. - AsyncAbstractPtr asyncResult_main = std::make_shared(); std::string threadId = AnalysisResultCacheMgr::GetThreadid(); + // Eval result of the main. + AsyncAbstractPtr asyncResult_main = std::make_shared(); + // Eval result of the branches std::vector branchAsyncResults; for (auto &evaluator : evaluators) { + SetUndeterminedFlag(evaluator, possible_parent_fg); AsyncAbstractPtr branchAsyncResult = std::make_shared(); // Control the order to run. AsyncAbstractPtr asyncRunOrder = std::make_shared(); - SetUndeterminedFlag(evaluator, possible_parent_fg); + // Add point to the async thread. 
+ healthPointMgr.AddPoint(); MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString(); - // Add point to infer thread - HealthPointMgr::GetInstance().AddPoint(); auto future = std::async(std::launch::async, ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf, threadId, branchAsyncResult, asyncResult_main, asyncRunOrder); // Wait for async threads to finish. - AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future)); + resultCacheMgr.PushToWait(std::move(future)); // Push to list of running loop asyncRunOrder->JoinResult(std::make_shared(1)); - HealthPointMgr::GetInstance().Add2Schedule(asyncRunOrder); // Activate order + healthPointMgr.Add2Schedule(asyncRunOrder); // Activate order branchAsyncResults.emplace_back(std::move(branchAsyncResult)); } MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString() - << " or " << evaluators[1]->ToString(); - HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); // Third order + << " or " << evaluators[1]->ToString() << "..."; + healthPointMgr.Add2Schedule(asyncResult_main); // Third order auto firstResult = asyncResult_main->GetResult(); - if (firstResult == nullptr || firstResult->isa()) { - MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString() - << " Please check the code if there are recursive functions."; - } - if (firstResult->isa()) { - MS_LOG(DEBUG) << "async " << out_conf->node()->ToString() << " threw exception."; - StaticAnalysisException::Instance().CheckException(); - } + MS_EXCEPTION_IF_NULL(firstResult); MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = " << firstResult->ToString(); @@ -893,19 +891,15 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve if (NeedWaitForBranches(firstResult)) { for (size_t i = 0; i < len; ++i) { MS_LOG(DEBUG) << GetInferThread() << "async 
waiting for " << evaluators[i]->ToString(); - HealthPointMgr::GetInstance().Add2Schedule(branchAsyncResults[i]); + healthPointMgr.Add2Schedule(branchAsyncResults[i]); auto result = branchAsyncResults[i]->GetResult(); - if (result == nullptr || result->isa()) { - MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out." - << " Please check the code if there is recursive function."; - } + MS_EXCEPTION_IF_NULL(result); out_specs.push_back(result); } } else { - // Next time to get the result of branches. - HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); + // Give one more chance to wait for the result of the branches. + healthPointMgr.Add2Schedule(asyncResult_main); (void)asyncResult_main->GetResult(); - for (size_t i = 0; i < len; ++i) { // Not wait to get the result of branch. auto result = branchAsyncResults[i]->TryGetResult(); diff --git a/tests/st/control/test_recrusive_fun.py b/tests/st/control/test_recrusive_fun.py index 022c98ae775..a69672f0b52 100644 --- a/tests/st/control/test_recrusive_fun.py +++ b/tests/st/control/test_recrusive_fun.py @@ -51,6 +51,38 @@ def fr(x): return z +@ms_function +def f_pythonerr(x): + if x > 0: + return f_pythonerr(x - 1) + return NOT_DEF + + +def test_python_error(): + context.set_context(mode=context.GRAPH_MODE) + x = Tensor([5], mstype.int32) + try: + f_pythonerr(x) + except NameError as e: + assert 'not defined' in str(e) + + +@ms_function +def f_recrusive_endless(x): + if x > 0: + return f_recrusive_endless(x - 1) + return f_recrusive_endless(x + 1) + + +def test_recrusive_endless(): + context.set_context(mode=context.GRAPH_MODE) + x = Tensor([5], mstype.int32) + try: + f_recrusive_endless(x) + except RuntimeError as e: + assert 'endless loop' in str(e) + + def test_endless(): context.set_context(mode=context.GRAPH_MODE) x = Tensor([5], mstype.int32) @@ -60,6 +92,22 @@ def test_endless(): assert 'endless loop' in str(e) +@ms_function +def f_ok(x): + if x > 0: + return f_ok(x - 1) + 1 + return 
ONE + + +@pytest.mark.skip(reason="backend is not supported yet") +def test_f_ok(): + context.set_context(mode=context.GRAPH_MODE) + x = Tensor([3], mstype.int32) + ret = f_ok(x) + expect = Tensor([4], mstype.int32) + assert ret == expect + + @pytest.mark.skip(reason="backend is not supported yet") def test_recrusive_fun(): context.set_context(mode=context.GRAPH_MODE)