diff --git a/mindspore/ccsrc/debug/trace.cc b/mindspore/ccsrc/debug/trace.cc index 0c5598d2f1e..568effff2e6 100644 --- a/mindspore/ccsrc/debug/trace.cc +++ b/mindspore/ccsrc/debug/trace.cc @@ -34,6 +34,7 @@ #include "debug/anf_ir_utils.h" #include "debug/common.h" #include "pipeline/jit/static_analysis/evaluator.h" +#include "pipeline/jit/static_analysis/async_eval_result.h" #include "utils/log_adapter.h" #include "abstract/abstract_value.h" @@ -177,7 +178,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(engine_); auto cfg = engine_->MakeConfig(node, cur_ctx_); - auto ret = engine_->analysis_cache().GetValue(cfg); + auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg); if (ret == nullptr) { return "Undefined"; } @@ -190,7 +191,7 @@ AbstractBasePtr AnalyzedFuncGraphExporter::GetNodeAbstract(const AnfNodePtr &nod } MS_EXCEPTION_IF_NULL(engine_); auto cfg = engine_->MakeConfig(node, cur_ctx_); - auto ret = engine_->analysis_cache().GetValue(cfg); + auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg); return ret == nullptr ? 
nullptr : ret->abstract(); } @@ -548,9 +549,9 @@ void GetEvalStackInfo(std::ostringstream &oss) { } // trace the graph evaluator stack -static std::stack> graph_infer_stack; +thread_local static std::stack> graph_infer_stack; // trace the cnode infer debug info -static std::vector cnode_debug_stack{}; +thread_local static std::vector cnode_debug_stack{}; void TraceGraphEvalEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) { if (eval == nullptr) { diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index c5d1e9026b8..0aabf8fa4a6 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -106,7 +106,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { prim::kPrimMirrorMiniStep); mini_step_allgather_replace_ = MakeSubstitution(std::make_shared(), "mini_step_allgather_replace", prim::kPrimMiniStepAllGather); - virtual_add_elim_ = MakeSubstitution(std::make_shared(), "virtual add", prim::kPrimVirtualAdd); + virtual_add_elim_ = MakeSubstitution(std::make_shared(), "virtual_add", prim::kPrimVirtualAdd); check_bprop_eliminate_ = MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); reset_defer_inline_ = diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 3c6f4372286..22c9eb334aa 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -30,6 +30,7 @@ #include "pipeline/jit/pass.h" #include "pipeline/jit/parse/data_converter.h" #include "frontend/optimizer/ad/dfunctor.h" +#include "pipeline/jit/static_analysis/async_eval_result.h" #include "debug/anf_ir_dump.h" #include "debug/dump_proto.h" #include "debug/anf_ir_utils.h" @@ -728,10 +729,12 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py: MS_LOG(DEBUG) << PrintArgs(args); ret_value = CompileInner(obj, args, phase, use_vm); } catch 
(const py::error_already_set &ex) { - // print function call stack info before release - std::string exception_info = GetCompileExceptionInfo(); - if (!exception_info.empty()) { - MS_LOG(ERROR) << exception_info; + if (!StaticAnalysisException::Instance().HasException()) { + // print function call stack info before release + std::string exception_info = GetCompileExceptionInfo(); + if (!exception_info.empty()) { + MS_LOG(ERROR) << exception_info; + } } ReleaseResource(phase); @@ -1281,6 +1284,7 @@ void ClearResAtexit() { ReleaseGeTsd(); parse::python_adapter::ResetPythonScope(); + abstract::AnalysisResultCacheMgr::GetInstance().Clear(); #ifdef ENABLE_DEBUGGER Debugger::GetInstance()->Reset(); #endif diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc new file mode 100644 index 00000000000..5987e629f16 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.cc @@ -0,0 +1,223 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "pipeline/jit/static_analysis/async_eval_result.h"
+#include
+#include "utils/symbolic.h"
+#include "debug/common.h"
+#include "pipeline/jit/base.h"
+#include "utils/utils.h"
+#include "abstract/utils.h"
+
+namespace mindspore {
+namespace abstract {
+// Returns the joined result if one is already available; otherwise waits up to `ms` milliseconds for
+// JoinResult() to publish one. With ms <= 0 the call never blocks. May return nullptr.
+EvalResultPtr AsyncEvalResult::TryGetResult(int ms) {
+  // Take the lock before touching result_: JoinResult() writes it under the same mutex.
+  std::unique_lock lock(lock_);
+  if (result_ != nullptr || ms <= 0) {
+    return result_;
+  }
+  // `ms` is a millisecond count; the duration must use the matching unit.
+  (void)condition_var_.wait_for(lock, std::chrono::milliseconds(ms), [this] { return result_ != nullptr; });
+  return result_;
+}
+
+// Blocks until a result has been joined or kInferTimeout seconds have elapsed.
+// Returns nullptr when the wait timed out.
+EvalResultPtr AsyncEvalResult::GetResult() {
+  std::unique_lock lock(lock_);
+  // wait_for() checks the predicate before sleeping, so an already published result returns at once.
+  (void)condition_var_.wait_for(lock, std::chrono::seconds(kInferTimeout), [this] { return result_ != nullptr; });
+  return result_;
+}
+
+// Text form of the joined abstract, or "NOT SET" while no usable result has been joined.
+std::string AsyncEvalResult::ToString() {
+  std::lock_guard lock(lock_);
+  // SetSwitchValue() may join a result whose abstract is null; report it as unset instead of dereferencing it.
+  if (result_ == nullptr || result_->abstract() == nullptr) {
+    return "NOT SET";
+  }
+  return result_->abstract()->ToString();
+}
+
+// Publishes `result` (must not be null) and wakes every thread blocked in GetResult()/TryGetResult().
+void AsyncEvalResult::JoinResult(const EvalResultPtr &result) {
+  MS_EXCEPTION_IF_NULL(result);
+  {
+    std::lock_guard lock(lock_);
+    result_ = result;
+  }
+  // Notify after releasing the mutex so that woken waiters can acquire it immediately.
+  condition_var_.notify_all();
+}
+
+// Drops every cached result and every pending consistency check.
+void AnalysisResultCacheMgr::Clear() {
+  cache_.clear();
+  switch_cache_.clear();
+  // todo_ is shared with PushTodo()/Todo(), which only access it under lock_.
+  std::lock_guard lock(lock_);
+  todo_.clear();
+}
+
+AnalysisResultCacheMgr &AnalysisResultCacheMgr::GetInstance() {
+  static AnalysisResultCacheMgr instance;
+  return instance;
+}
+
+// Writes the text dump of cache_ to "<filename>.cache" under the save-graphs path.
+// Failures are logged and the dump is skipped.
+void AnalysisResultCacheMgr::DumpCache(const std::string &filename) {
+  auto path = pipeline::GetSaveGraphsPathName(Common::AddId(filename, ".cache"));
+
+  auto realpath = Common::GetRealPath(path);
+  if (!realpath.has_value()) {
+    MS_LOG(ERROR) << "Get real path failed. path=" << path;
+    return;
+  }
+  ChangeFileMode(realpath.value(), S_IRWXU);
+  std::ofstream fout(realpath.value());
+  if (!fout.is_open()) {
+    MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!";
+    return;
+  }
+  fout << cache_.dump();
+  fout.close();
+  // Set file mode to read only by user
+  ChangeFileMode(realpath.value(), S_IRUSR);
+}
+
+// Per-thread tag of the form "<caller>.<thread id>", set by UpdateCaller() and read by GetThreadid().
+thread_local static std::string local_threadid;
+void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
+  std::ostringstream buffer;
+  buffer << caller << "." << std::this_thread::get_id();
+  local_threadid = buffer.str();
+}
+std::mutex AnalysisResultCacheMgr::tiggerToken_;
+std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; }
+
+// Queues two futures that Wait() will block on later.
+void AnalysisResultCacheMgr::PushTowait(const std::shared_future &future0,
+                                        const std::shared_future &future1) {
+  std::lock_guard lock(lock_);
+  waiting_.push_back(future0);
+  waiting_.push_back(future1);
+}
+
+// Queues `conf` for the consistency check performed by Todo().
+void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
+  std::lock_guard lock(lock_);
+  todo_.push_back(conf);
+}
+
+// Ensures switch_cache_ holds an (initially empty) AsyncEvalResult for `conf`.
+void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
+  std::lock_guard lock(lock_);
+  AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
+  if (async_eval_result == nullptr) {
+    async_eval_result = std::make_shared();
+    switch_cache_.set(conf, async_eval_result);
+  }
+}
+
+// Returns the switch result registered for `conf`, or nullptr when `conf` was never initialised.
+// Blocks in AsyncEvalResult::GetResult() while the result is still pending.
+EvalResultPtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
+  AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
+  // Conf has been visited and set value.
+  if (async_eval_result != nullptr) {
+    // Maybe blocked for waiting. AsyncEvalResult maybe null, if time out.
+    auto result = async_eval_result->GetResult();
+    if (result == nullptr) {
+      result = std::make_shared(std::make_shared(), nullptr);
+      MS_LOG(ERROR) << "AsyncEvalResult for NodeConfig " << conf->ToString() << " is nullptr, maybe timeout.";
+    }
+    return result;
+  }
+  return nullptr;
+}
+
+// Joins `arg` into the switch result of `conf`. The entry must have been created by InitSwitchValue().
+// When a previous result exists, the two abstracts are joined and the joined value replaces it.
+void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const EvalResultPtr arg) {
+  MS_EXCEPTION_IF_NULL(conf);
+  // A null result cannot be joined; fail here instead of dereferencing it further down.
+  MS_EXCEPTION_IF_NULL(arg);
+  if (arg->abstract() == nullptr) {
+    MS_LOG(WARNING) << conf->ToString() << " value is nullptr";
+  }
+  std::lock_guard lock(lock_);
+  AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
+  if (async_eval_result == nullptr) {
+    // MS_LOG(EXCEPTION) leaves this function, so no fallback entry is created here.
+    MS_LOG(EXCEPTION) << conf->ToString() << " Not key.";
+  }
+  auto ab1 = async_eval_result->TryGetResult();
+  if (ab1 == nullptr) {
+    // First branch to finish: publish its result unchanged.
+    async_eval_result->JoinResult(arg);
+    return;
+  }
+  AbstractBasePtrList absList;
+  absList.push_back(arg->abstract());
+  absList.push_back(ab1->abstract());
+  // Join two branches's result
+  auto joined_spec = AbstractJoin(absList);
+  MS_EXCEPTION_IF_NULL(joined_spec);
+  MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
+  auto joined_result = std::make_shared(joined_spec, std::make_shared());
+  async_eval_result->JoinResult(joined_result);
+  // Compare the abstracts by value: joined_result is a freshly created object, so comparing the
+  // pointers would report a change on every call.
+  const auto &previous_spec = ab1->abstract();
+  if (previous_spec == nullptr || !(*joined_spec == *previous_spec)) {
+    PushTodo(conf);
+  }
+}
+
+// Drains todo_ and warns about every conf whose global cache value differs from its switch value.
+void AnalysisResultCacheMgr::Todo() {
+  while (true) {
+    AnfNodeConfigPtr conf;
+    {
+      std::lock_guard lock(lock_);
+      if (todo_.empty()) {
+        break;
+      }
+      conf = todo_.front();
+      todo_.pop_front();
+    }
+    // Look each value up once: GetSwitchValue() may block while the result is pending.
+    const auto global_result = GetValue(conf);
+    const auto switch_result = GetSwitchValue(conf);
+    if (global_result == nullptr || global_result->abstract() == nullptr || switch_result == nullptr ||
+        switch_result->abstract() == nullptr) {
+      MS_LOG(WARNING) << "Missing cached value for Conf: " << conf->ToString();
+      continue;
+    }
+    if (!(*global_result->abstract() == *switch_result->abstract())) {
+      MS_LOG(WARNING) << " Switch Value is not eq. "
+                      << " switchCache: " << switch_result->abstract()->ToString()
+                      << " globleCache: " << global_result->abstract()->ToString() << "\t\tConf: " << conf->ToString();
+    }
+  }
+}
+
+// Blocks until every future queued through PushTowait() has completed, then runs the Todo()
+// consistency check when DEBUG output is enabled.
+void AnalysisResultCacheMgr::Wait() {
+  while (true) {
+    std::shared_future future;
+    {
+      std::lock_guard lock(lock_);
+      if (waiting_.empty()) {
+        break;
+      }
+      future = std::move(waiting_.front());
+      waiting_.pop_front();
+    }
+    // Wait outside the lock so that other threads can still call PushTowait()/PushTodo().
+    future.wait();
+  }
+  if (IS_OUTPUT_ON(DEBUG)) {
+    Todo();
+  }
+}
+
+// Formats an argument list as "(a # b #  )" for log messages; null entries are printed as "nullptr".
+std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
+  std::ostringstream buffer;
+  buffer << "(";
+  for (const auto &item : args_spec_list) {
+    buffer << (item == nullptr ? "nullptr" : item->ToString()) << " # ";
+  }
+  buffer << " )";
+  return buffer.str();
+}
+}  // namespace abstract
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h
new file mode 100644
index 00000000000..525d89a5c94
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/async_eval_result.h
@@ -0,0 +1,192 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
+#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "pipeline/jit/static_analysis/static_analysis.h"
+
+namespace mindspore {
+namespace abstract {
+// Timeout, in seconds, applied when waiting for an asynchronous evaluation result.
+constexpr size_t kInferTimeout = 60;
+
+// Map wrapper whose get/set/clear/size/dump operations are serialised by an internal mutex.
+// get() returns nullptr for a missing key, so ValueType must be constructible from nullptr.
+template
+class MultiThreadCache {
+ public:
+  using iterator = typename CacheType::iterator;
+  using const_iterator = typename CacheType::const_iterator;
+
+  // Returns the value stored for `key`, or nullptr when the key is absent.
+  ValueType get(const KeyType &key) {
+    std::lock_guard lock(lock_);
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      return it->second;
+    }
+    return nullptr;
+  }
+
+  // Inserts `data` for `key`, replacing any previous value.
+  void set(const KeyType &key, const ValueType &data) {
+    std::lock_guard lock(lock_);
+    cache_[key] = data;
+  }
+
+  void clear() {
+    std::lock_guard lock(lock_);
+    cache_.clear();
+  }
+
+  // Takes the lock like every other accessor: set()/clear() may run concurrently.
+  size_t size() {
+    std::lock_guard lock(lock_);
+    return cache_.size();
+  }
+
+  bool empty() { return size() == 0; }
+
+  // Renders every entry as "{key:value}", one per line. The lock is held for the whole walk so
+  // that a concurrent set()/clear() cannot invalidate the iteration.
+  std::string dump() {
+    std::ostringstream buf;
+    std::lock_guard lock(lock_);
+    for (auto &item : cache_) {
+      buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << '\n';
+    }
+    return buf.str();
+  }
+
+  // NOTE: the iterators below are handed out without holding lock_; callers must make sure that no
+  // other thread modifies the cache while they iterate.
+  iterator begin() { return cache_.begin(); }
+  iterator end() { return cache_.end(); }
+
+  const_iterator begin() const { return cache_.cbegin(); }
+  const_iterator end() const { return cache_.cend(); }
+
+  const_iterator cbegin() const { return cache_.cbegin(); }
+  const_iterator cend() const { return cache_.cend(); }
+
+ private:
+  std::mutex lock_;
+  CacheType cache_;
+};
+
+class AsyncEvalResult;
+using AsyncEvalResultPtr = std::shared_ptr;
+
+using EvaluatorCacheMap =
+  std::unordered_map;
+using EvalResultCache = MultiThreadCache;
+
+// Holder for an evaluation result that is published by one thread (JoinResult) and awaited by others.
+class AsyncEvalResult {
+ public:
+  AsyncEvalResult() = default;
+  ~AsyncEvalResult() = default;
+  // Wait (for at most kInferTimeout seconds) until a result has been joined.
+  EvalResultPtr GetResult();
+  // Do not wait by default; wait for at most `ms` when it is positive.
+  EvalResultPtr TryGetResult(int ms = 0);
+  // Publish `result` and wake all waiting threads.
+  void JoinResult(const EvalResultPtr &result);
+  std::string ToString();
+
+ private:
+  EvalResultPtr result_{nullptr};
+  std::mutex lock_;
+  std::condition_variable condition_var_;
+};
+
+// Cache of evaluation results keyed by the abstract argument list.
+class EvaluatorCacheMgr {
+ public:
+  EvaluatorCacheMgr() = default;
+  ~EvaluatorCacheMgr() = default;
+
+  void Clear() { eval_result_cache_.clear(); }
+  EvalResultCache &GetCache() { return eval_result_cache_; }
+  EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
+  void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
+  size_t GetSize() { return eval_result_cache_.size(); }
+
+ private:
+  EvalResultCache eval_result_cache_;
+};
+
+// AnalysisCache
+// Process-wide singleton that stores evaluation results per node config, tracks results that are
+// still being computed asynchronously, and keeps the futures Wait() has to block on.
+class AnalysisResultCacheMgr {
+ public:
+  ~AnalysisResultCacheMgr() = default;
+  AnalysisResultCacheMgr(const AnalysisResultCacheMgr &) = delete;
+  AnalysisResultCacheMgr &operator=(const AnalysisResultCacheMgr &) = delete;
+  static AnalysisResultCacheMgr &GetInstance();
+  static std::mutex &tiggerToken() { return tiggerToken_; }
+  void Clear();
+
+  using AnalysisConfigAsyncResultMap =
+    std::unordered_map;
+  using AnalysisConfigAsyncResultCache =
+    MultiThreadCache;
+
+  using AnalysisConfigResultMap =
+    std::unordered_map;
+  using AnalysisConfigResultCache = MultiThreadCache;
+
+  inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
+  inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
+
+  // Dump all the conf and result
+  void DumpCache(const std::string &filename);
+  // Wait for async Eval(conf) to finish.
+  void Wait();
+  // Queue two futures for Wait().
+  void PushTowait(const std::shared_future &future0, const std::shared_future &future1);
+  // Queue `conf` for the consistency check done by Todo().
+  void PushTodo(const AnfNodeConfigPtr &conf);
+  void Todo();
+  // Set / read the calling thread's "<caller>.<thread id>" tag.
+  static void UpdateCaller(const std::string &caller);
+  static std::string &GetThreadid();
+
+  // InitSwitchValue() must be called for `conf` before SetSwitchValue() is used on it.
+  void InitSwitchValue(const AnfNodeConfigPtr &conf);
+  EvalResultPtr GetSwitchValue(const AnfNodeConfigPtr &conf);
+  void SetSwitchValue(const AnfNodeConfigPtr &conf, const EvalResultPtr arg);
+
+ private:
+  AnalysisResultCacheMgr() = default;
+
+  static std::mutex tiggerToken_;
+  // Guards waiting_ and todo_; recursive because SetSwitchValue() calls PushTodo() while holding it.
+  std::recursive_mutex lock_;
+  std::list> waiting_;
+  std::list todo_;
+
+  AnalysisConfigResultCache cache_;
+  AnalysisConfigAsyncResultCache switch_cache_;
+};
+
+// Scoped guard that unlocks the token on construction and locks it again on destruction.
+// std::mutex::unlock() requires the calling thread to own the mutex, so the token must already be held.
+class TiggerToken_scoped_release {
+ public:
+  TiggerToken_scoped_release() { AnalysisResultCacheMgr::tiggerToken().unlock(); }
+  ~TiggerToken_scoped_release() { AnalysisResultCacheMgr::tiggerToken().lock(); }
+  // A copy would repeat the lock() in its own destructor.
+  TiggerToken_scoped_release(const TiggerToken_scoped_release &) = delete;
+  TiggerToken_scoped_release &operator=(const TiggerToken_scoped_release &) = delete;
+};
+// Scoped guard that locks the token on construction and unlocks it on destruction.
+class TiggerToken_scoped_acquire {
+ public:
+  TiggerToken_scoped_acquire() { AnalysisResultCacheMgr::tiggerToken().lock(); }
+  ~TiggerToken_scoped_acquire() { AnalysisResultCacheMgr::tiggerToken().unlock(); }
+  // A copy would repeat the unlock() in its own destructor.
+  TiggerToken_scoped_acquire(const TiggerToken_scoped_acquire &) = delete;
+  TiggerToken_scoped_acquire &operator=(const TiggerToken_scoped_acquire &) = delete;
+};
+std::string ArgsToString(const AbstractBasePtrList &args_spec_list);
+
+// Log prefix of the form " INFER:<thread tag>:".
+inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisResultCacheMgr::GetThreadid() + ":"; }
+
+}  // namespace abstract
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
index b134c3b74e4..9c793970242 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
@@ -25,6 +25,7 @@
 #include "debug/trace.h"
 #include "utils/ms_context.h"
 #include "pipeline/jit/static_analysis/stack_frame.h"
+#include "pipeline/jit/static_analysis/async_eval_result.h"
 
 namespace mindspore {
 namespace abstract {
@@ -80,10 +81,9
@@ void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, co // Increase & Check the func graph call depth. engine->IncreaseFunctionCallDepth(); engine->IncreaseStackFrameDepth(); - if (engine->function_call_depth() - engine->stack_frame_depth() > - MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH)) { - MS_LOG(EXCEPTION) << "Exceed function call depth limit " - << MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH) + const uint32_t max_depth = MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH); + if (engine->function_call_depth() > max_depth) { + MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth << ", (function call depth: " << engine->function_call_depth() << ", simulate call depth: " << engine->stack_frame_depth() << "), please call 'context.set_context(max_call_depth=value)' to adjust this value."; @@ -116,7 +116,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr auto current_stack_frame = std::make_shared(shared_from_base(), fg, context_, parent_context_); MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame; stack_frames.push(current_stack_frame); - while (1) { + while (true) { current_stack_frame = stack_frames.top(); if (current_stack_frame->Done()) { MS_EXCEPTION_IF_NULL(res_base); @@ -135,8 +135,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr // Save func graph eval result for specialize. auto evaluator = current_stack_frame->evaluator(); MS_EXCEPTION_IF_NULL(evaluator); - EvaluatorCacheMap &evaluator_cache_map = *evaluator->evaluator_cache_map(); - evaluator_cache_map[current_stack_frame->args_abs_list()] = eval_result; + evaluator->evaluator_cache_mgr()->SetValue(current_stack_frame->args_abs_list(), eval_result); // Leave current func graph. 
LeaveStackFrame(engine, current_stack_frame); @@ -167,7 +166,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) { const AnfNodePtr &func_node = fg->get_return(); - const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType { + const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType { if (node->isa() || node->isa()) { return EXCLUDE; } @@ -180,20 +179,26 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine << ", node_conf: " << node_conf->ToString(); auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf); res_base = node_eval_result->abstract(); - MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg << "/" << fg->ToString() - << ", node_conf: " << node_conf->ToString() << ", abstract: " << res_base->ToString(); + MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << res_base->ToString(); } MS_EXCEPTION_IF_NULL(res_base); return res_base; } EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) { + auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list); + if (eval_result != nullptr) { + MS_LOG(ERROR) << ToString() << ArgsToString(args_abs_list) << " entered again. 
There is something wrong."; + return eval_result; + } else { + MS_LOG(DEBUG) << ToString() << " entered first."; + } + MS_EXCEPTION_IF_NULL(engine); engine->IncreaseFunctionCallDepth(); - if (engine->function_call_depth() - engine->stack_frame_depth() > - MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH)) { - MS_LOG(EXCEPTION) << "Exceed function call depth limit " - << MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH) + const uint32_t max_depth = MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH); + if (engine->function_call_depth() > max_depth) { + MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth << ", (function call depth: " << engine->function_call_depth() << ", simulate call depth: " << engine->stack_frame_depth() << "), please call 'context.set_context(max_call_depth=value)' to adjust this value."; @@ -212,6 +217,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr << args_abs_list.size() << "."; } MS_EXCEPTION_IF_NULL(parent_context_); + MS_LOG(DEBUG) << GetInferThread() << "@" << fg->ToString() << ArgsToString(args_abs_list) << " { "; + if (parent_context_->func_graph() != nullptr) { + MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisResultCacheMgr::GetThreadid() << ":" + << parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":" + << fg->ToString() << "();"; + } + context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list); const auto ¶meters = fg->parameters(); for (size_t i = 0; i < nargs; i++) { @@ -219,6 +231,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr const auto &node = parameters[i]; AnfNodeConfigPtr conf = engine->MakeConfig(node, context_); engine->SaveEvalResultInCache(conf, std::make_shared(arg, nullptr)); + MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString(); } MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " 
<< fg << "/" << fg->ToString() << ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString() @@ -238,11 +251,14 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr res_base = std::make_shared(); } + MS_LOG(DEBUG) << GetInferThread() << "} //" << fg->ToString() << " = " << res_base->ToString(); engine->DecreaseFunctionCallDepth(); MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString() << "), leave, function call depth: " << engine->function_call_depth() << " - " << engine->stack_frame_depth(); - return std::make_shared(res_base, nullptr); + auto res = std::make_shared(res_base, nullptr); + evaluator_cache_mgr_->SetValue(args_abs_list, res); + return res; } AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { @@ -280,6 +296,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { return args_spec_list; } + if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { if (parent_context_) { MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() @@ -307,7 +324,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa return joined_args_spec_list_1; } } - if (trace_.size() != 0) { + if (!trace_.empty()) { MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back()); // Join the last eval arguments and current arguments to check if there are loop variant. 
@@ -336,27 +353,26 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { auto iter = func_graph_cache_.find(args_spec_list); - FuncGraphPtr ret = nullptr; + FuncGraphPtr res; if (iter == func_graph_cache_.end()) { auto fg = func_graph(); MS_EXCEPTION_IF_NULL(fg); - TraceGuard guard(std::make_shared(fg->debug_info())); FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list); func_graph_cache_[args_spec_list] = generated_graph; MS_EXCEPTION_IF_NULL(engine); engine->func_graph_manager()->AddFuncGraph(generated_graph); - ret = generated_graph; + res = generated_graph; } else { - ret = iter->second; + res = iter->second; } // For the top graph, if it is replaced by generated graph, update the top graph to the new one. if (parse::Parser::GetTopFuncGraph() == func_graph()) { - if (ret != func_graph()) { - parse::Parser::UpdateTopFuncGraph(ret); + if (res != func_graph()) { + parse::Parser::UpdateTopFuncGraph(res); } } - return ret; + return res; } FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { @@ -366,7 +382,7 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons } MS_EXCEPTION_IF_NULL(meta_func_graph_); - FuncGraphPtr generated_func_graph = nullptr; + FuncGraphPtr generated_func_graph; if (this->bound_node() != nullptr) { TraceGuard trace_guard(std::make_shared(bound_node()->debug_info())); generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list); @@ -381,8 +397,18 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons return cloned_func_graph; } -EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { - const std::string &evaluator_name = ToString(); +EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, 
const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf) { + const string evaluator_name = ToString(); + std::unique_lock eval_lock(eval_loc_, std::try_to_lock); + if (!eval_lock.owns_lock()) { + auto py_tstate = PyEval_SaveThread(); + eval_lock.try_lock_for(std::chrono::seconds(kInferTimeout)); + PyEval_RestoreThread(py_tstate); + if (!eval_lock.owns_lock()) { + MS_LOG(EXCEPTION) << "It is timeout to run " << ToString(); + } + } AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), @@ -394,30 +420,30 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args args_spec_list = BroadenUndeterminedArgs(args_spec_list); trace::TraceGraphEvalEnter(shared_from_base(), out_conf); MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base(), args_spec_list, out_conf); - MS_EXCEPTION_IF_NULL(evaluator_cache_map_); - auto iter = evaluator_cache_map_->find(args_spec_list); - if (iter == evaluator_cache_map_->end()) { + MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_); + auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list); + if (eval_result == nullptr) { MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; - EvalResultPtr ret = Eval(engine, args_spec_list); - if (ret->abstract() == nullptr) { + EvalResultPtr res = Eval(engine, args_spec_list); + if (res->abstract() == nullptr) { EvalFailLogging(shared_from_base(), args_spec_list, out_conf); MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; } - MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << "."; - (*evaluator_cache_map_)[args_spec_list] = ret; + MS_LOG(DEBUG) << evaluator_name << " set cache. 
return: " << res->abstract()->ToString() << "."; + evaluator_cache_mgr_->SetValue(args_spec_list, res); trace::TraceGraphEvalLeave(shared_from_base()); - return ret; + return res; } else { - MS_EXCEPTION_IF_NULL(iter->second); - MS_EXCEPTION_IF_NULL(iter->second->abstract()); - MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << "."; + MS_EXCEPTION_IF_NULL(eval_result); + MS_EXCEPTION_IF_NULL(eval_result->abstract()); + MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << eval_result->abstract()->ToString() << "."; trace::TraceGraphEvalLeave(shared_from_base()); - return iter->second; + return eval_result; } } EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr) { + const AnfNodeConfigPtr &) { AbstractBasePtrList args_spec_list; auto is_py_eval = (identifier_ == "PythonPrimEvaluator"); (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), @@ -432,58 +458,57 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt } return abstract; }); - EvalResultPtr ret = EvalPrim(engine, args_spec_list); - return ret; + return EvalPrim(engine, args_spec_list); } EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { + const AnfNodeConfigPtr &out_conf) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); return conf->ObtainEvalResult()->abstract(); }); - if (args_conf_list.size() == 0) { + if (args_conf_list.empty()) { MS_LOG(EXCEPTION) << "Size should greater than 0"; } - EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); + EvalResultPtr res = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); // 
No need to cache. - return ret; + return res; } -EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { - EvalResultPtr ret = EvalPrim(args_conf_list); - return ret; +EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &) { + return EvalPrim(args_conf_list); } EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { + const AnfNodeConfigPtr &out_conf) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); return conf->ObtainEvalResult()->abstract(); }); - EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); + EvalResultPtr res = sub_evaluator_->Run(engine, args_conf_list, out_conf); // Don't lookup from cache, as different out_conf with same node but different context // may add different entry to anfnode_config_map_, like getattr primitive. 
- (*evaluator_cache_map_)[args_spec_list] = ret; - return ret; + evaluator_cache_mgr_->SetValue(args_spec_list, res); + return res; } EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { + const AnfNodeConfigPtr &out_conf) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); return conf->ObtainEvalResult()->abstract(); }); - MS_EXCEPTION_IF_NULL(evaluator_cache_map_); - auto iter = evaluator_cache_map_->find(args_spec_list); - if (iter != evaluator_cache_map_->end()) { - return iter->second; + MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_); + auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list); + if (eval_result != nullptr) { + return eval_result; } ConfigPtrList partial_args_conf_list; @@ -493,23 +518,22 @@ EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtr (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list), [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); - - (*evaluator_cache_map_)[args_spec_list] = ret; - return ret; + EvalResultPtr res = evaluator_->Run(engine, partial_args_conf_list, out_conf); + evaluator_cache_mgr_->SetValue(args_spec_list, res); + return res; } -EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { +EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); return 
conf->ObtainEvalResult()->abstract(); }); - MS_EXCEPTION_IF_NULL(evaluator_cache_map_); - auto iter = evaluator_cache_map_->find(args_spec_list); - if (iter != evaluator_cache_map_->end()) { - return iter->second; + MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_); + auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list); + if (eval_result != nullptr) { + return eval_result; } // Call the original evaluator, get the result: y = f(x) @@ -536,9 +560,9 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg // J(f)(J(x)) return a tuple (y, bprop_f) AbstractBasePtrList jargs = {result->abstract(), bprop}; AbstractBasePtr jtuple = std::make_shared(jargs); - auto infer_reuslt = std::make_shared(jtuple, std::make_shared()); - (*evaluator_cache_map_)[args_spec_list] = infer_reuslt; - return infer_reuslt; + auto res = std::make_shared(jtuple, std::make_shared()); + evaluator_cache_mgr_->SetValue(args_spec_list, res); + return res; } EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { @@ -553,5 +577,16 @@ EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrLis } return std::make_shared(output_, std::make_shared()); } +EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf) { + auto result = this->Run(engine, args_conf_list, out_conf); + + StaticAnalysisException::Instance().CheckException(); + pybind11::gil_scoped_release release; + AnalysisResultCacheMgr::GetInstance().Wait(); + StaticAnalysisException::Instance().CheckException(); + + return result; +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h index cd20f28a1fe..78b85192a69 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h @@ 
-26,23 +26,22 @@ #include #include "pipeline/jit/static_analysis/static_analysis.h" +#include "pipeline/jit/static_analysis/async_eval_result.h" #include "utils/ms_context.h" namespace mindspore { namespace abstract { -using EvaluatorCacheMap = - std::unordered_map; -using EvaluatorCacheMapPtr = std::shared_ptr; - +using EvaluatorCacheMgrPtr = std::shared_ptr; using EvaluatorAttrMap = std::unordered_map; -using EvaluatorAttrMapPtr = std::shared_ptr; +using EvaluatorAttrCache = MultiThreadCache; +using EvaluatorAttrCachePtr = std::shared_ptr; class Evaluator : public Base { public: explicit Evaluator(const std::string &id) - : evaluator_cache_map_(std::make_shared()), - attr_cache_(std::make_shared()), + : evaluator_cache_mgr_(std::make_shared()), + attr_cache_(std::make_shared()), identifier_(id) {} ~Evaluator() override = default; MS_DECLARE_PARENT(Evaluator, Base); @@ -50,9 +49,12 @@ class Evaluator : public Base { // difference between Run() and Eval(): // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. 
// Run() will modify cache_ member, so it cannot marked as const; - virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); + virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf); virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; + virtual EvalResultPtr SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf); virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } @@ -87,14 +89,16 @@ class Evaluator : public Base { virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } - EvaluatorCacheMapPtr &evaluator_cache_map() { return evaluator_cache_map_; } - EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; } + EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; } + EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; } + + EvaluatorCacheMgrPtr evaluator_cache_mgr_; + EvaluatorAttrCachePtr attr_cache_; - EvaluatorCacheMapPtr evaluator_cache_map_; - EvaluatorAttrMapPtr attr_cache_; std::string identifier_; AnfNodeWeakPtr bound_node_; + std::recursive_timed_mutex eval_loc_; }; class PrimEvaluator : public Evaluator { @@ -112,7 +116,8 @@ class TrivialPrimEvaluator : public PrimEvaluator { explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} ~TrivialPrimEvaluator() override = default; MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf) final; virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) 
= 0; }; @@ -121,7 +126,8 @@ class TransitionPrimEvaluator : public PrimEvaluator { explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} ~TransitionPrimEvaluator() override = default; MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf) final; // Parameter in_conf0 : the first element in args_conf_list; virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; @@ -132,7 +138,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator { explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} ~SymbolicPrimEvaluator() override = default; MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf) final; virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; }; @@ -173,7 +180,8 @@ class TrackedEvaluator : public Evaluator { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; } - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf) override; std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } private: @@ -211,9 +219,9 @@ class BaseFuncGraphEvaluator : public Evaluator { AbstractBasePtr 
LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg); // Add functions for stack frame routine. AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg); - void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame, - const StackFramePtr &new_stack_frame); - void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame); + static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame, + const StackFramePtr &new_stack_frame); + static void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame); AnalysisContextPtr context_; }; @@ -287,7 +295,8 @@ class PartialAppEvaluator : public Evaluator { MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; } - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf) override; std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } private: @@ -333,7 +342,8 @@ class JEvaluator : public Evaluator { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; } - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + const AnfNodeConfigPtr &out_conf) override; std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } private: diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index b82e96b24be..765f556af95 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++
b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -51,7 +51,7 @@ std::unordered_set prims_to_skip_undetermined_infer{ "MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"}; EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { + const AnfNodeConfigPtr &out_conf) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); }); @@ -136,7 +136,7 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s } EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { + const AnfNodeConfigPtr &out_conf) { if (out_conf->node() == nullptr || !out_conf->node()->isa()) { MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; } @@ -238,7 +238,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac } EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { + const AnfNodeConfigPtr &out_conf) { AbstractBasePtrList args_spec_list; if (out_conf->node() == nullptr || !out_conf->node()->isa()) { MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; @@ -603,7 +603,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c return ret_abstract; } } - if (prim_->prim_type() == PrimType::kPrimTypePyCheck) { return EvalPyCheckPrim(engine, args); } @@ -640,10 +639,12 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs } MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); - const auto &iter = evaluator_cache_map_->find(args); - if (iter != evaluator_cache_map_->end()) { - return iter->second; + const auto eval_result = 
evaluator_cache_mgr_->GetValue(args); + if (eval_result != nullptr) { + return eval_result; } + + pybind11::gil_scoped_acquire gil; auto py_args = PreparePyInputs(prim_py_, args); prim_py_->BeginRecordAddAttr(); py::dict output = prim_py_->RunInfer(py_args); @@ -653,7 +654,7 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs auto res_spec = PyInferRes2Abstract(prim_py_, output); MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; auto infer_result = std::make_shared(res_spec, std::make_shared(added_attrs)); - (*evaluator_cache_map_)[args] = infer_result; + evaluator_cache_mgr_->SetValue(args, infer_result); return infer_result; } @@ -1016,6 +1017,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { return nullptr; } AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract(); + AbstractRefPtr ref_abs = abs->cast(); if (ref_abs == nullptr) { MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); @@ -1082,7 +1084,7 @@ class GetAttrEvaluator : public TransitionPrimEvaluator { } // don't lookup from cache, as different out_conf with same node but different context // may add different entry to anfnode_config_map, like getattr primitive; - (*evaluator_cache_map_)[args_spec_list] = ret; + evaluator_cache_mgr_->SetValue(args_spec_list, ret); return ret; } }; @@ -1169,7 +1171,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); auto infer_result = std::make_shared(ret, std::make_shared()); - (*evaluator_cache_map_)[args_spec_list] = infer_result; + evaluator_cache_mgr_->SetValue(args_spec_list, infer_result); return infer_result; } @@ -1197,7 +1199,7 @@ class PartialEvaluator : public Evaluator { PartialEvaluator() : Evaluator("PartialEvaluator") {} ~PartialEvaluator() override = default; EvalResultPtr Run(AnalysisEnginePtr engine, const 
ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf = nullptr) override { + const AnfNodeConfigPtr &out_conf = nullptr) override { if (args_conf_list.size() == 0) { MS_LOG(EXCEPTION) << "Args size should be greater than 0"; } @@ -1214,7 +1216,7 @@ class PartialEvaluator : public Evaluator { MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() << " as func is: " << arg0_value->ToString(); auto eval_result = std::make_shared(ret, std::make_shared()); - (*evaluator_cache_map_)[args_spec_list] = eval_result; + evaluator_cache_mgr_->SetValue(args_spec_list, eval_result); return eval_result; } auto func = CheckArg("partial", args_spec_list, 0); @@ -1248,7 +1250,7 @@ class PartialEvaluator : public Evaluator { auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); auto eval_result = std::make_shared(ret, std::make_shared()); - (*evaluator_cache_map_)[args_spec_list] = eval_result; + evaluator_cache_mgr_->SetValue(args_spec_list, eval_result); return eval_result; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h index 4dfcfd265ac..ae128713f4b 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -71,7 +71,7 @@ class DoSignatureEvaluator : public Evaluator { explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} ~DoSignatureEvaluator() override = default; EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; + const AnfNodeConfigPtr &out_config = nullptr) override; EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; @@ -86,7 +86,7 @@ class UnpackGraphEvaluator : public Evaluator { explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : 
Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} ~UnpackGraphEvaluator() override = default; EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; + const AnfNodeConfigPtr &out_config = nullptr) override; EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; @@ -102,7 +102,7 @@ class MixedPrecisionCastEvaluator : public Evaluator { : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {} ~MixedPrecisionCastEvaluator() override = default; EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; + const AnfNodeConfigPtr &out_config = nullptr) override; EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index 9d9d142a714..094d5479ba1 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -494,11 +494,11 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n return wrapped_node; } -const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) { +const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) { auto cache_iter = evalcaches_.find(eval); if (cache_iter == evalcaches_.end()) { - evalcaches_[eval] = eval->evaluator_cache_map(); - return eval->evaluator_cache_map(); + evalcaches_[eval] = eval->evaluator_cache_mgr(); + return eval->evaluator_cache_mgr(); } return cache_iter->second; } @@ -509,7 +509,8 @@ std::pair FuncGraphSpecializer::BuildFromB std::unordered_set choices; 
EvalResultPtr ret = nullptr; AbstractBasePtrList broaded_argvals; - for (auto &argvals_map : *evalcaches_[eval]) { + EvalResultCache &cache = evalcaches_[eval]->GetCache(); + for (auto &argvals_map : cache) { auto argvals = argvals_map.first; broaded_argvals.clear(); @@ -524,11 +525,10 @@ std::pair FuncGraphSpecializer::BuildFromB (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list), [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared(v); }); - // if broaden return null - ret = eval->Run(engine_, args_conf_list, nullptr); - EvaluatorCacheMapPtr real = std::make_shared(); - - (*real)[broaded_argvals] = ret; + // If broaden return null + ret = eval->SingleRun(engine_, args_conf_list, nullptr); + EvaluatorCacheMgrPtr real = std::make_shared(); + real->SetValue(broaded_argvals, ret); evalcaches_[eval] = real; return std::make_pair(broaded_argvals, ret->abstract()); } else { @@ -615,11 +615,12 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { } namespace { -void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) { +void DumpEvaluatorCache(const EvaluatorCacheMgrPtr &evaluator_cache_mgr, const AbstractBasePtrList &argvals) { MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". 
Check cache all items."; int64_t i = 0; - for (const auto &item : evaluator_cache_map) { - MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first; + const EvalResultCache &map = evaluator_cache_mgr->GetCache(); + for (const auto &item : map) { + MS_LOG(DEBUG) << "evaluator_cache[" << i++ << "]: " << item.first; } } @@ -650,24 +651,24 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct MS_EXCEPTION_IF_NULL(eval); MS_EXCEPTION_IF_NULL(result); - EvaluatorCacheMap &evaluator_cache_map = *eval->evaluator_cache_map(); - if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) { - *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract()); + EvaluatorCacheMgrPtr evaluator_cache_mgr = eval->evaluator_cache_mgr(); + auto data = evaluator_cache_mgr->GetValue(argvals); + if (data != nullptr) { + *result = std::make_pair(argvals, data->abstract()); return kSpecializeSuccess; } - DumpEvaluatorCache(evaluator_cache_map, argvals); + DumpEvaluatorCache(evaluator_cache_mgr, argvals); - const EvaluatorCacheMapPtr &choices = GetEvalCache(eval); - MS_EXCEPTION_IF_NULL(choices); - - if (choices->count(argvals)) { - *result = std::make_pair(argvals, (*choices)[argvals]->abstract()); + MS_EXCEPTION_IF_NULL(GetEvalCache(eval)); + EvalResultCache &choices = GetEvalCache(eval)->GetCache(); + if (choices.get(argvals) != nullptr) { + *result = std::make_pair(argvals, GetEvalCache(eval)->GetValue(argvals)->abstract()); return kSpecializeSuccess; - } else if (choices->size() == 1) { + } else if (choices.size() == 1) { MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it."; - *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract()); + *result = std::make_pair(choices.begin()->first, choices.begin()->second->abstract()); return kSpecializeSuccess; - } else if (choices->empty()) { + } else if (choices.empty()) { MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase 
" << func->ToString() << " | " << func->type_name(); return kSpecializeFindUniqueArgvalDead; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h index ba7ba085697..a2db9251867 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h @@ -91,7 +91,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this *repl_node_; std::vector todo_; std::unordered_set marked_; - std::unordered_map evalcaches_; + std::unordered_map evalcaches_; void FirstPass(); void SecondPass(); @@ -127,7 +127,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this *result); // Get cache, it may be eval's cache or cache built from broaded argument values. - const EvaluatorCacheMapPtr &GetEvalCache(const EvaluatorPtr &eval); + const EvaluatorCacheMgrPtr GetEvalCache(const EvaluatorPtr &eval); // Try to build unique argvals from the broaded arg vals if it is unique. std::pair BuildFromBroadedArgsVal(const EvaluatorPtr &eval); void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc index 5cd0c711177..5cb9631f73c 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc @@ -16,6 +16,7 @@ #include "pipeline/jit/static_analysis/stack_frame.h" #include "debug/trace.h" +#include "pipeline/jit/static_analysis/async_eval_result.h" namespace mindspore { namespace abstract { @@ -66,8 +67,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode); // Check if already evaluated before. 
- EvaluatorCacheMap &evaluator_cache_map = *evaluator->evaluator_cache_map(); - if (evaluator_cache_map.find(args_abs_list) != evaluator_cache_map.end()) { + if (evaluator->evaluator_cache_mgr()->GetValue(args_abs_list) != nullptr) { return nullptr; } @@ -128,6 +128,8 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) { << ", current_context_: " << current_context_->ToString(); AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_); auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf); + MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString() + << ") = " << node_eval_result->abstract()->ToString(); return node_eval_result; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index f207e3b9646..5680d6b2f0a 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -17,20 +17,21 @@ */ #include "pipeline/jit/static_analysis/static_analysis.h" - #include #include - +#include "abstract/abstract_value.h" #include "abstract/utils.h" #include "pipeline/jit/static_analysis/prim.h" #include "frontend/operator/ops.h" #include "utils/symbolic.h" +#include "utils/ms_exception.h" #include "ir/tensor.h" #include "ir/func_graph_cloner.h" #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/static_analysis/evaluator.h" #include "debug/trace.h" #include "debug/anf_ir_dump.h" +#include "pipeline/jit/static_analysis/async_eval_result.h" namespace mindspore { namespace abstract { @@ -39,12 +40,9 @@ bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) { auto v = arg_spec->GetValueTrack(); if (v->isa()) { return true; - } else { - return false; } - } else { - return false; } + return false; } AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) { @@ -54,36 +52,6 @@ 
AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase return nullptr; } -void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { - MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() - << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString() - << ", Pointer: " << result->abstract().get(); - analysis_cache_map_[conf] = result; - - // Set intermediate abstract value. - if (IsIntermediateAbstract(result->abstract())) { - if (conf->node()->intermediate_abstract() == nullptr) { - conf->node()->set_intermediate_abstract(result->abstract()); - MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString(); - } else { - auto old_spec = conf->node()->intermediate_abstract(); - auto joined_spec = IntermediateJoin(result->abstract(), old_spec); - conf->node()->set_intermediate_abstract(joined_spec); - MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" - << result->abstract()->ToString() << "\njoined_spec:\t" - << (joined_spec != nullptr ? 
joined_spec->ToString() : "nullptr"); - } - } -} - -EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { - auto value = analysis_cache_map_.find(conf); - if (value == analysis_cache_map_.end()) { - return nullptr; - } - return value->second; -} - std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const { MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf->node()); @@ -105,6 +73,7 @@ bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeCon } AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) { + StaticAnalysisException::Instance().ClearException(); ConfigPtrList args_conf_list; (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); @@ -127,6 +96,11 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac MS_EXCEPTION_IF_NULL(output_conf); result.inferred = output_conf->ObtainEvalResult(); result.context = root_context; + + StaticAnalysisException::Instance().CheckException(); + pybind11::gil_scoped_release release; + AnalysisResultCacheMgr::GetInstance().Wait(); + StaticAnalysisException::Instance().CheckException(); return result; } @@ -137,15 +111,35 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana return eval->context(); } +void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { + MS_EXCEPTION_IF_NULL(conf); + MS_EXCEPTION_IF_NULL(result); + static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance(); + cache_mgr.SetValue(conf, result); + + // Set intermediate abstract value. 
+ if (IsIntermediateAbstract(result->abstract())) { + if (conf->node()->intermediate_abstract() == nullptr) { + conf->node()->set_intermediate_abstract(result->abstract()); + MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString(); + } else { + auto old_spec = conf->node()->intermediate_abstract(); + auto joined_spec = IntermediateJoin(result->abstract(), old_spec); + conf->node()->set_intermediate_abstract(joined_spec); + MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" + << result->abstract()->ToString() << "\njoined_spec:\t" + << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr"); + } + } +} + EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) { MS_EXCEPTION_IF_NULL(conf); - EvalResultPtr result = analysis_cache_.GetValue(conf); + static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance(); + auto result = cache_mgr.GetValue(conf); if (result != nullptr) { - MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() - << ", Value: " << result->abstract().get() << ", " << result->abstract()->ToString(); return result; } - MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString(); result = Eval(conf); if (result == nullptr) { @@ -187,8 +181,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { trace::TraceEvalCNodeLeave(); } else { MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString() << "(" << node->type_name() - << "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph") - << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + << "), fg: " << (node->func_graph() != nullptr ? 
node->func_graph()->ToString() : "nullgraph"); } #ifdef DEBUG @@ -280,6 +273,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf MS_LOG(DEBUG) << "EvalCNode eval Undetermined"; return std::make_shared(maybe_func->Clone(), std::make_shared()); } + AbstractFunctionPtr func = dyn_cast(maybe_func); if (func == nullptr) { MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << maybe_func->ToString() << "."; @@ -321,28 +315,28 @@ EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const Abs } void AnalysisEngine::ClearEvaluatorCache() { - for (std::pair element : evaluators_) { + for (auto &element : evaluators_) { EvaluatorPtr evaluator = element.second; MS_EXCEPTION_IF_NULL(evaluator); - MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map()); - evaluator->evaluator_cache_map()->clear(); + MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr()); + evaluator->evaluator_cache_mgr()->Clear(); } for (auto &element : prim_constructors_) { EvaluatorPtr evaluator = element.second; MS_EXCEPTION_IF_NULL(evaluator); - MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map()); - evaluator->evaluator_cache_map()->clear(); + MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr()); + evaluator->evaluator_cache_mgr()->Clear(); } for (auto &element : prim_py_evaluators_) { EvaluatorPtr evaluator = element.second; MS_EXCEPTION_IF_NULL(evaluator); - MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map()); - evaluator->evaluator_cache_map()->clear(); + MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr()); + evaluator->evaluator_cache_mgr()->Clear(); } } void AnalysisEngine::Clear() { - analysis_cache_.Clear(); + AnalysisResultCacheMgr::GetInstance().Clear(); anfnode_config_map_.clear(); eval_trace_.clear(); evaluators_.clear(); @@ -517,6 +511,11 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { if (func->tracking_id() != nullptr) { MS_LOG(DEBUG) << "The tracking_id: " << 
func->tracking_id()->DebugString(); } + MS_EXCEPTION_IF_NULL(func); + + // protect the constructors + static std::recursive_mutex constructors_mutex; + // std::lock_guard lock(constructors_mutex); if (func->tracking_id() == nullptr || func->isa() || func->isa()) { EvaluatorPtr evaluator = _GetEvaluatorFor(func); @@ -570,7 +569,16 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector MS_EXCEPTION_IF_NULL(eval); return eval->Run(shared_from_this(), args_conf_list, out_conf); } +#if !(defined _WIN32 || defined _WIN64) + static bool enable_singleThread = (common::GetEnv("ENV_SINGLE_EVAL") == "1"); + if (enable_singleThread) { + return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); + } else { + return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list); + } +#else return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); +#endif } void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { @@ -623,17 +631,17 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vectorToString() << "check undetermined."; auto &alternate_evaluator = multi_poss_[u_eval.evaluator_]; - auto &eval_cache = alternate_evaluator->evaluator_cache_map(); + auto eval_cache = alternate_evaluator->evaluator_cache_mgr(); const auto &alt_eval_args = EvaluatorArgs(alternate_evaluator, args_spec_list); if ((!undetermined_evals.count(alt_eval_args)) && - (((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) || - (eval_cache->find(args_spec_list) == eval_cache->end()))) { + (((!continued_evals_.count(u_eval)) && (eval_cache->GetValue(args_spec_list) != nullptr)) || + (eval_cache->GetValue(args_spec_list) == nullptr))) { MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "has undetermined."; has_undetermined = true; break; } } - if (has_undetermined == false) { + if (!has_undetermined) { MS_LOG(DEBUG) << eval->ToString() << "has no undetermined."; *continue_flag = true; 
return latest_entry; @@ -675,7 +683,7 @@ std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBa } EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node) { - if (out_specs.size() == 0) { + if (out_specs.empty()) { MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; } @@ -710,6 +718,135 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_ return std::make_shared(joined_spec, std::make_shared()); } +bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) { + if (abstract->isa()) { + return true; + } + if (abstract->isa()) { + auto elements = abstract->cast()->elements(); + if (std::any_of(elements.begin(), elements.end(), + [](const AbstractBasePtr &item) { return item->isa(); })) { + return true; + } + } + return false; +} + +EvalResultPtr ExecEvaluetor(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, + AnfNodeConfigPtr out_conf, std::string caller, AsyncEvalResultPtr async_result_branch, + AsyncEvalResultPtr async_result_main) { + // TiggerToken_scoped_acquire tigger_token_acquire; + py::gil_scoped_acquire pyGuard; + + EvalResultPtr result = nullptr; + try { + AnalysisResultCacheMgr::UpdateCaller(caller); + result = eval->Run(engine, args_conf_list, out_conf); + MS_EXCEPTION_IF_NULL(result); + async_result_branch->JoinResult(result); + async_result_main->JoinResult(result); + MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString() + << " asyncResult address = " << async_result_branch.get() + << " value = " << async_result_branch->TryGetResult()->abstract()->ToString(); + auto broadAbstract = result->abstract()->Broaden(); + auto broadEvalResult = std::make_shared(broadAbstract, nullptr); + AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadEvalResult); + } catch (const std::exception &e) { + std::ostringstream oss; + trace::GetEvalStackInfo(oss); + if (!oss.str().empty()) { + 
MS_LOG(ERROR) << oss.str(); + } + auto abstractErrPtr = std::make_shared(std::make_shared(oss.str()), out_conf->node()); + async_result_main->JoinResult(std::make_shared(abstractErrPtr, nullptr)); + StaticAnalysisException::Instance().SetException(); + } + return result; +} + +EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector &evaluators, + const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list) { + // TiggerToken_scoped_release tigger_token_release; + pybind11::gil_scoped_release release; + // Wait for the switch node to finish. + MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString(); + auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf); + if (eval_result == nullptr) { + MS_LOG(DEBUG) << GetInferThread() << "async : Init switch " << out_conf->ToString(); + AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf); + } else { + if (eval_result->isa()) { + MS_LOG(EXCEPTION) << "Eval " << out_conf->node()->ToString() << " time out." + << "please check the code if there are recursive functions."; + } + return eval_result; + } + + // Eval result of the branches and main. 
+ AsyncEvalResultPtr asyncResult0 = std::make_shared(); + AsyncEvalResultPtr asyncResult1 = std::make_shared(); + AsyncEvalResultPtr asyncResult_main = std::make_shared(); + SetUndeterminedFlag(evaluators[0]); + SetUndeterminedFlag(evaluators[1]); + std::string threadId = AnalysisResultCacheMgr::GetThreadid(); + + MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString(); + auto future0 = std::async(std::launch::async, ExecEvaluetor, evaluators[0], shared_from_this(), args_conf_list, + out_conf, threadId, asyncResult0, asyncResult_main); + + MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString(); + auto future1 = std::async(std::launch::async, ExecEvaluetor, evaluators[1], shared_from_this(), args_conf_list, + out_conf, threadId, asyncResult1, asyncResult_main); + + // Wait for async threads to finish. + AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1)); + + MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString() + << " or " << evaluators[1]->ToString(); + auto branchResult = asyncResult_main->GetResult(); + if (branchResult == nullptr || branchResult->isa()) { + MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString() + << "please check the code if there are recursive functions."; + } + if (branchResult->isa()) { + MS_LOG(EXCEPTION) << "async " << out_conf->node()->ToString() << " threw exception."; + } + + AbstractBasePtrList out_specs; + if (NeedWaitForTwoBranches(branchResult->abstract())) { + MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[0]->ToString(); + auto result0 = asyncResult0->GetResult(); + if (result0->isa()) { + MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << "is time out." + << " Please check the code if there is recursive function."; + } + out_specs.push_back(result0->abstract()); + + MS_LOG(DEBUG) << GetInferThread() << "async . 
waiting for " << evaluators[1]->ToString(); + auto result1 = asyncResult1->GetResult(); + if (result1->isa()) { + MS_LOG(EXCEPTION) << "Eval " << evaluators[1]->ToString() << "is time out." + << " Please check the code if there is recursive function."; + } + out_specs.push_back(result1->abstract()); + } else { + if (asyncResult0->TryGetResult()) { + MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[0]->ToString() + << " value0=" << asyncResult0->GetResult()->abstract()->ToString(); + out_specs.push_back(asyncResult0->GetResult()->abstract()); + } + + if (asyncResult1->TryGetResult()) { + MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[1]->ToString() + << " value1=" << asyncResult1->GetResult()->abstract()->ToString(); + out_specs.push_back(asyncResult1->GetResult()->abstract()); + } + } + + return ProcessEvalResults(out_specs, out_conf->node()); +} + EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { @@ -726,9 +863,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorObtainEvalResult()->abstract(); }); - for (auto eval : evaluators) { + for (const auto &eval : evaluators) { SetUndeterminedFlag(eval); - const auto current_inf = EvaluatorArgs(eval, args_spec_list); MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. 
@@ -757,11 +893,11 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorevaluator_) { MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString(); - auto latest_entry_eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); - MS_EXCEPTION_IF_NULL(latest_entry_eval_result->abstract()); + auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); + MS_EXCEPTION_IF_NULL(eval_result->abstract()); MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString() - << " return out_spec: " << latest_entry_eval_result->abstract()->ToString(); - return latest_entry_eval_result; + << " return out_spec: " << eval_result->abstract()->ToString(); + return eval_result; } } } @@ -815,8 +951,9 @@ AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &cont if (value->isa()) { auto prim = value->cast(); return MakeAbstractClosure(prim, anf_node); + } else { + return value->ToAbstract(); } - return value->ToAbstract(); } AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 9d00eebb2ad..02e58d125ad 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -28,6 +28,7 @@ #include #include #include +#include #ifdef DEBUG #include @@ -154,19 +155,6 @@ class VirtualConfig : public Config { AbstractBasePtr abstract_; }; -// AnalysisCache -class AnalysisCache { - public: - AnalysisCache() = default; - ~AnalysisCache() = default; - void Clear() { analysis_cache_map_.clear(); } - void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg); - EvalResultPtr GetValue(const AnfNodeConfigPtr &conf); - - private: - std::unordered_map analysis_cache_map_; -}; - using PrimEvaluatorMap = std::unordered_map; using 
AnfNodeConfigMap = std::unordered_map; @@ -183,12 +171,38 @@ struct PartialAppHasher { return h1 ^ h2; } }; + +// Should compare Args based on value other than pointer; +struct EvaluatorArgs { + EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {} + bool operator==(const EvaluatorArgs &other) const { + if (evaluator_ != other.evaluator_) { + return false; + } + if (AbstractBasePtrListDeepEqual(args_, other.args_)) { + return true; + } + return false; + } + bool operator!=(const EvaluatorArgs &other) { return !(*this == other); } + + EvaluatorPtr evaluator_; + AbstractBasePtrList args_; +}; +using EvalTraceRevIter = std::list::reverse_iterator; +struct EvaluatorArgsHasher { + std::size_t operator()(const EvaluatorArgs &eval_args) const { + return hash_combine(std::hash{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_)); + } +}; +struct EvaluatorArgsEqual { + bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; } +}; + class AnalysisEngine : public std::enable_shared_from_this { public: AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) - : analysis_cache_(AnalysisCache()), - prim_constructors_(prim_evaluator_map), - func_graph_manager_(func_graph_manager) { + : prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) { function_call_depth_ = 0; function_call_max_depth_ = 0; stack_frame_depth_ = 0; @@ -202,11 +216,7 @@ class AnalysisEngine : public std::enable_shared_from_this { // func_graph: The func_graph to analyze. // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. 
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); - void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { - MS_EXCEPTION_IF_NULL(conf); - MS_EXCEPTION_IF_NULL(result); - analysis_cache_.set_value(conf, result); - } + void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result); EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf); // Return the Evaluator for the given function. EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); @@ -218,7 +228,6 @@ class AnalysisEngine : public std::enable_shared_from_this { EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); void Clear(); void ClearEvaluatorCache(); - AnalysisCache &analysis_cache() { return analysis_cache_; } AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) { return std::make_shared(shared_from_this(), node, context); } @@ -239,7 +248,6 @@ class AnalysisEngine : public std::enable_shared_from_this { EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf); const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } - AnalysisCache analysis_cache_; std::unordered_map prim_py_evaluators_; void ResetFunctionCallDepth() { @@ -249,7 +257,7 @@ class AnalysisEngine : public std::enable_shared_from_this { void IncreaseFunctionCallDepth() { function_call_depth_++; if (function_call_max_depth_ < function_call_depth_) { - function_call_max_depth_ = function_call_depth_; + function_call_max_depth_ = function_call_depth_.load(); } } void DecreaseFunctionCallDepth() { @@ -268,7 +276,7 @@ class AnalysisEngine : public std::enable_shared_from_this { void IncreaseStackFrameDepth() { stack_frame_depth_++; if (stack_frame_max_depth_ < stack_frame_depth_) { - stack_frame_max_depth_ = stack_frame_depth_; + stack_frame_max_depth_ = stack_frame_depth_.load(); } } void 
DecreaseStackFrameDepth() { @@ -279,39 +287,10 @@ class AnalysisEngine : public std::enable_shared_from_this { } size_t stack_frame_depth() const { return stack_frame_depth_; } size_t stack_frame_max_depth() const { return stack_frame_max_depth_; } - void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf); - bool enable_recursive_eval() const { return enable_recursive_eval_; } private: - // Should compare Args based on value other than pointer; - struct EvaluatorArgs { - EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {} - bool operator==(const EvaluatorArgs &other) const { - if (evaluator_ != other.evaluator_) { - return false; - } - if (AbstractBasePtrListDeepEqual(args_, other.args_)) { - return true; - } - return false; - } - bool operator!=(const EvaluatorArgs &other) { return !(*this == other); } - - EvaluatorPtr evaluator_; - AbstractBasePtrList args_; - }; - using EvalTraceRevIter = std::list::reverse_iterator; - struct EvaluatorArgsHasher { - std::size_t operator()(const EvaluatorArgs &eval_args) const { - return hash_combine(std::hash{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_)); - } - }; - struct EvaluatorArgsEqual { - bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; } - }; - void SetUndeterminedFlag(const EvaluatorPtr &evaluator); EvaluatorPtr HandleNestedRecursion(const std::vector &evaluators, const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, @@ -323,6 +302,7 @@ class AnalysisEngine : public std::enable_shared_from_this { std::unordered_map evaluators_; std::unordered_map, EvaluatorPtr, PartialAppHasher> constructors_app_; + AnfNodeConfigMap anfnode_config_map_; // Use a list to trace multiple evaluators. 
std::list eval_trace_; @@ -337,15 +317,18 @@ class AnalysisEngine : public std::enable_shared_from_this { const ConfigPtrList &args_conf_list); EvalResultPtr ExecuteMultipleEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list); + EvalResultPtr ExecuteMultipleEvaluatorsMultiThread(const std::vector &evaluators, + const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list); // Record current depth of function call stack, including `stack_frame_depth_`. - size_t function_call_depth_; - size_t function_call_max_depth_; + std::atomic_long function_call_depth_; + std::atomic_long function_call_max_depth_; // Record current depth of stack frames call. - size_t stack_frame_depth_; - size_t stack_frame_max_depth_; + std::atomic_long stack_frame_depth_; + std::atomic_long stack_frame_max_depth_; - size_t forward_count_; + std::atomic_long forward_count_; bool enable_recursive_eval_; diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index 8d2bb730ca3..7ed9356d88b 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -162,6 +162,10 @@ AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { if (context->get_param(MS_CTX_GRAD_FOR_SCALAR) || config == kBroadenScalarParameterOnly) { return AbstractBase::Broaden(config); } else { + auto type = this->BuildType()->type_id(); + if (type < kNumberTypeBegin || type > kNumberTypeEnd) { + return AbstractBase::Broaden(config); + } return Clone(); } } @@ -1085,6 +1089,27 @@ std::string AbstractNull::ToString() const { return buffer.str(); } +bool AbstractTimeOut::operator==(const AbstractTimeOut &) const { return true; } + +bool AbstractTimeOut::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + if (other.isa()) { + auto other_none = static_cast(&other); + return *this == *other_none; + } else { + return false; + } +} + 
+std::string AbstractTimeOut::ToString() const { + std::ostringstream buffer; + buffer << "AbstractTimeOut " + << "(Value: Null)"; + return buffer.str(); +} + bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; } bool AbstractEllipsis::operator==(const AbstractBase &other) const { diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index f59afcc3c0d..0580a60f101 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -312,7 +312,7 @@ class AbstractTensor : public AbstractUndetermined { AbstractBasePtr Clone() const override; AbstractBasePtr Broaden(uint8_t config = 0) const override; AbstractBasePtr BroadenWithShape() const; - AbstractBasePtr Join(const AbstractBasePtr &other); + AbstractBasePtr Join(const AbstractBasePtr &other) override; bool operator==(const AbstractTensor &other) const; bool operator==(const AbstractBase &other) const override; std::string ToString() const override; @@ -565,6 +565,21 @@ class AbstractNull : public AbstractBase { }; using AbstractNullPtr = std::shared_ptr; +// the timeout state value for variable, which means the variable is not assigned because it is timeout +class AbstractTimeOut : public AbstractBase { + public: + AbstractTimeOut() : AbstractBase(kNull) { set_type(std::make_shared()); } + ~AbstractTimeOut() override = default; + MS_DECLARE_PARENT(AbstractTimeOut, AbstractBase) + + TypePtr BuildType() const override { return std::make_shared(); } + bool operator==(const AbstractTimeOut &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override { return std::make_shared(); } + std::string ToString() const override; +}; +using AbstractTimeOutPtr = std::shared_ptr; + class AbstractEllipsis : public AbstractBase { public: AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared()); } diff --git a/mindspore/core/utils/info.cc 
b/mindspore/core/utils/info.cc index 925b328809c..1908469d337 100644 --- a/mindspore/core/utils/info.cc +++ b/mindspore/core/utils/info.cc @@ -230,7 +230,7 @@ DebugInfoPtr TraceManager::GetParseOrResolveDebugInfo() { return TraceManager::p void TraceManager::ClearParseOrResolveDebugInfo() { TraceManager::parse_or_resolve_debug_info_ = nullptr; } -std::stack TraceManager::trace_context_stack_; +thread_local std::stack TraceManager::trace_context_stack_; -DebugInfoPtr TraceManager::parse_or_resolve_debug_info_ = nullptr; +thread_local DebugInfoPtr TraceManager::parse_or_resolve_debug_info_ = nullptr; } // namespace mindspore diff --git a/mindspore/core/utils/info.h b/mindspore/core/utils/info.h index ced6aab8765..4c57cc53c29 100644 --- a/mindspore/core/utils/info.h +++ b/mindspore/core/utils/info.h @@ -79,8 +79,8 @@ class TraceManager { static void ClearParseOrResolveDebugInfo(); static DebugInfoPtr GetParseOrResolveDebugInfo(); - static std::stack trace_context_stack_; - static DebugInfoPtr parse_or_resolve_debug_info_; + thread_local static std::stack trace_context_stack_; + thread_local static DebugInfoPtr parse_or_resolve_debug_info_; }; class TraceGuard { diff --git a/mindspore/core/utils/ms_exception.h b/mindspore/core/utils/ms_exception.h index 0e8462b1dac..76168012fa1 100644 --- a/mindspore/core/utils/ms_exception.h +++ b/mindspore/core/utils/ms_exception.h @@ -58,6 +58,45 @@ class MsException { ExceptionListener *listener_{nullptr}; std::exception_ptr exception_ptr_{nullptr}; }; + +class StaticAnalysisException { + public: + static StaticAnalysisException &Instance() { + static StaticAnalysisException instance; + return instance; + } + + void ClearException() { exception_ptr_ = nullptr; } + + bool HasException() { return exception_ptr_ != nullptr; } + + void SetException() { + if (exception_ptr_ != nullptr) { + return; + } + exception_ptr_ = std::current_exception(); + } + + void SetAndRethrowException() { + SetException(); + 
std::rethrow_exception(std::current_exception()); + } + + void CheckException() { + if (exception_ptr_ != nullptr) { + auto tmp_exception_ptr = exception_ptr_; + exception_ptr_ = nullptr; + std::rethrow_exception(tmp_exception_ptr); + } + } + + private: + StaticAnalysisException() = default; + ~StaticAnalysisException() = default; + DISABLE_COPY_AND_ASSIGN(StaticAnalysisException) + + std::exception_ptr exception_ptr_{nullptr}; +}; } // namespace mindspore #endif // MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_ diff --git a/tests/ut/cpp/abstract/abstract_test.cc b/tests/ut/cpp/abstract/abstract_test.cc index bbcfa660115..5a14d19a6b3 100644 --- a/tests/ut/cpp/abstract/abstract_test.cc +++ b/tests/ut/cpp/abstract/abstract_test.cc @@ -18,6 +18,7 @@ #include "common/common_test.h" +#include "pybind11/pybind11.h" #include "pipeline/jit/static_analysis/static_analysis.h" #include "abstract/utils.h" #include "pipeline/jit/static_analysis/prim.h" @@ -37,6 +38,12 @@ class TestAbstract : public UT::Common { }; TEST_F(TestAbstract, TestParseDataClass) { + // Check initialization before callback to Python. + if (Py_IsInitialized() == 0) { + Py_Initialize(); + } + PyEval_InitThreads(); + py::object fn = parse::python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "TestFoo"); ClassPtr cls_ptr = parse::ParseDataClass(fn);