infer_optv3:infer optimize to find exit of branches and resolve the

endless issue.
This commit is contained in:
lanzhineng 2021-06-20 19:57:10 +08:00
parent 45902803b2
commit fc8ec7fc49
20 changed files with 945 additions and 269 deletions

View File

@ -34,6 +34,7 @@
#include "debug/anf_ir_utils.h"
#include "debug/common.h"
#include "pipeline/jit/static_analysis/evaluator.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
#include "utils/log_adapter.h"
#include "abstract/abstract_value.h"
@ -177,7 +178,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(engine_);
auto cfg = engine_->MakeConfig(node, cur_ctx_);
auto ret = engine_->analysis_cache().GetValue(cfg);
auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg);
if (ret == nullptr) {
return "Undefined";
}
@ -190,7 +191,7 @@ AbstractBasePtr AnalyzedFuncGraphExporter::GetNodeAbstract(const AnfNodePtr &nod
}
MS_EXCEPTION_IF_NULL(engine_);
auto cfg = engine_->MakeConfig(node, cur_ctx_);
auto ret = engine_->analysis_cache().GetValue(cfg);
auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg);
return ret == nullptr ? nullptr : ret->abstract();
}
@ -548,9 +549,9 @@ void GetEvalStackInfo(std::ostringstream &oss) {
}
// trace the graph evaluator stack
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
thread_local static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
// trace the cnode infer debug info
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
thread_local static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
void TraceGraphEvalEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) {
if (eval == nullptr) {

View File

@ -106,7 +106,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
prim::kPrimMirrorMiniStep);
mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
"mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual add", prim::kPrimVirtualAdd);
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =

View File

@ -30,6 +30,7 @@
#include "pipeline/jit/pass.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"
#include "debug/anf_ir_utils.h"
@ -728,10 +729,12 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
MS_LOG(DEBUG) << PrintArgs(args);
ret_value = CompileInner(obj, args, phase, use_vm);
} catch (const py::error_already_set &ex) {
// print function call stack info before release
std::string exception_info = GetCompileExceptionInfo();
if (!exception_info.empty()) {
MS_LOG(ERROR) << exception_info;
if (!StaticAnalysisException::Instance().HasException()) {
// print function call stack info before release
std::string exception_info = GetCompileExceptionInfo();
if (!exception_info.empty()) {
MS_LOG(ERROR) << exception_info;
}
}
ReleaseResource(phase);
@ -1281,6 +1284,7 @@ void ClearResAtexit() {
ReleaseGeTsd();
parse::python_adapter::ResetPythonScope();
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
#ifdef ENABLE_DEBUGGER
Debugger::GetInstance()->Reset();
#endif

View File

@ -0,0 +1,223 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pipeline/jit/static_analysis/async_eval_result.h"
#include <chrono>
#include "utils/symbolic.h"
#include "debug/common.h"
#include "pipeline/jit/base.h"
#include "utils/utils.h"
#include "abstract/utils.h"
namespace mindspore {
namespace abstract {
EvalResultPtr AsyncEvalResult::TryGetResult(int ms) {
if (result_ != nullptr || ms == 0) {
return result_;
}
std::unique_lock<std::mutex> lock(lock_);
auto time = std::chrono::microseconds(ms);
// Wait for ms.
(void)condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });
return result_;
}
EvalResultPtr AsyncEvalResult::GetResult() {
if (result_ != nullptr) {
return result_;
}
std::unique_lock<std::mutex> lock(lock_);
auto time = std::chrono::seconds(kInferTimeout);
(void)condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });
return result_;
}
std::string AsyncEvalResult::ToString() {
std::ostringstream buffer;
std::lock_guard<std::mutex> lock(lock_);
buffer << (result_ == nullptr ? "NOT SET" : result_->abstract()->ToString());
return buffer.str();
}
void AsyncEvalResult::JoinResult(const EvalResultPtr &result) {
MS_EXCEPTION_IF_NULL(result);
{
std::lock_guard<std::mutex> lock(lock_);
result_ = result;
}
condition_var_.notify_all();
}
void AnalysisResultCacheMgr::Clear() {
cache_.clear();
switch_cache_.clear();
todo_.clear();
}
AnalysisResultCacheMgr &AnalysisResultCacheMgr::GetInstance() {
static AnalysisResultCacheMgr instance;
return instance;
}
void AnalysisResultCacheMgr::DumpCache(const std::string &filename) {
auto path = pipeline::GetSaveGraphsPathName(Common::AddId(filename, ".cache"));
auto realpath = Common::GetRealPath(path);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed. path=" << path;
return;
}
ChangeFileMode(realpath.value(), S_IRWXU);
std::ofstream fout(realpath.value());
if (!fout.is_open()) {
MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!";
return;
}
fout << cache_.dump();
fout.close();
// Set file mode to read only by user
ChangeFileMode(realpath.value(), S_IRUSR);
}
thread_local static std::string local_threadid;
void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
std::ostringstream buffer;
buffer << caller << "." << std::this_thread::get_id();
local_threadid = buffer.str();
}
std::mutex AnalysisResultCacheMgr::tiggerToken_;
std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; }
void AnalysisResultCacheMgr::PushTowait(const std::shared_future<EvalResultPtr> &future0,
const std::shared_future<EvalResultPtr> &future1) {
std::lock_guard<std::recursive_mutex> lock(lock_);
waiting_.push_back(future0);
waiting_.push_back(future1);
}
void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
std::lock_guard<std::recursive_mutex> lock(lock_);
todo_.push_back(conf);
}
void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
std::lock_guard<std::recursive_mutex> lock(lock_);
AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
if (async_eval_result == nullptr) {
async_eval_result = std::make_shared<AsyncEvalResult>();
switch_cache_.set(conf, async_eval_result);
}
}
EvalResultPtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
// Conf has been visited and set value.
if (async_eval_result != nullptr) {
// Maybe blocked for waiting. AsyncEvalResult maybe null, if time out.
auto result = async_eval_result->GetResult();
if (result == nullptr) {
result = std::make_shared<EvalResult>(std::make_shared<AbstractTimeOut>(), nullptr);
MS_LOG(ERROR) << "AsyncEvalResult for NodeConfig " << conf->ToString() << " is nullptr, maybe timeout.";
}
return result;
}
return nullptr;
}
void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const EvalResultPtr arg) {
MS_EXCEPTION_IF_NULL(conf);
if (arg == nullptr || arg->abstract() == nullptr) {
MS_LOG(WARNING) << conf->ToString() << " value is nullptr";
}
std::lock_guard<std::recursive_mutex> lock(lock_);
AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
if (async_eval_result == nullptr) {
MS_LOG(EXCEPTION) << conf->ToString() << " Not key.";
async_eval_result = std::make_shared<AsyncEvalResult>();
async_eval_result->JoinResult(arg);
switch_cache_.set(conf, async_eval_result);
} else {
auto ab1 = async_eval_result->TryGetResult();
AbstractBasePtrList absList;
if (ab1 != nullptr) {
absList.push_back(arg->abstract());
absList.push_back(ab1->abstract());
// Join two branches's result
auto joined_spec = AbstractJoin(absList);
MS_EXCEPTION_IF_NULL(joined_spec);
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
auto joined_result = std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
async_eval_result->JoinResult(joined_result);
if (joined_result != ab1) {
PushTodo(conf);
}
} else {
async_eval_result->JoinResult(arg);
}
}
}
void AnalysisResultCacheMgr::Todo() {
while (true) {
AnfNodeConfigPtr conf;
lock_.lock();
if (!todo_.empty()) {
conf = todo_.front();
} else {
lock_.unlock();
break;
}
todo_.pop_front();
lock_.unlock();
if (!(*GetValue(conf)->abstract() == *GetSwitchValue(conf)->abstract())) {
MS_LOG(WARNING) << " Switch Value is not eq. "
<< " switchCache: " << GetSwitchValue(conf)->abstract()->ToString()
<< " globleCache: " << GetValue(conf)->abstract()->ToString() << "\t\tConf: " << conf->ToString();
}
}
}
void AnalysisResultCacheMgr::Wait() {
while (true) {
std::shared_future<EvalResultPtr> future;
lock_.lock();
if (!waiting_.empty()) {
future = std::move(waiting_.front());
} else {
lock_.unlock();
break;
}
waiting_.pop_front();
lock_.unlock();
future.wait();
}
if (IS_OUTPUT_ON(DEBUG)) {
Todo();
}
}
std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
std::ostringstream buffer;
buffer << "(";
for (const auto &item : args_spec_list) {
buffer << item->ToString() << " # ";
}
buffer << " )";
return buffer.str();
}
} // namespace abstract
} // namespace mindspore

View File

@ -0,0 +1,192 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
#include <iostream>
#include <utility>
#include <future>
#include <thread>
#include <memory>
#include <unordered_map>
#include <vector>
#include <string>
#include <functional>
#include <list>
#include <fstream>
#include "pipeline/jit/static_analysis/static_analysis.h"
namespace mindspore {
namespace abstract {
constexpr size_t kInferTimeout = 60;
template <typename KeyType, typename ValueType, typename CacheType>
class MultiThreadCache {
public:
using iterator = typename CacheType::iterator;
using const_iterator = typename CacheType::const_iterator;
ValueType get(const KeyType &key) {
std::lock_guard<std::mutex> lock(lock_);
auto it = cache_.find(key);
if (it != cache_.end()) {
return it->second;
}
return nullptr;
}
void set(const KeyType &key, const ValueType &data) {
std::lock_guard<std::mutex> lock(lock_);
cache_[key] = data;
}
void clear() {
std::lock_guard<std::mutex> lock(lock_);
cache_.clear();
}
size_t size() { return cache_.size(); }
bool empty() { return size() == 0; }
std::string dump() {
std::ostringstream buf;
for (auto &item : cache_) {
buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
}
return buf.str();
}
iterator begin() { return cache_.begin(); }
iterator end() { return cache_.end(); }
const_iterator begin() const { return cache_.cbegin(); }
const_iterator end() const { return cache_.cend(); }
const_iterator cbegin() const { return cache_.cbegin(); }
const_iterator cend() const { return cache_.cend(); }
private:
std::mutex lock_;
CacheType cache_;
};
class AsyncEvalResult;
using AsyncEvalResultPtr = std::shared_ptr<AsyncEvalResult>;
using EvaluatorCacheMap =
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvalResultCache = MultiThreadCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
class AsyncEvalResult {
public:
AsyncEvalResult() = default;
~AsyncEvalResult() = default;
// wait
EvalResultPtr GetResult();
// not wait
EvalResultPtr TryGetResult(int ms = 0);
void JoinResult(const EvalResultPtr &result);
std::string ToString();
private:
EvalResultPtr result_{nullptr};
std::mutex lock_;
std::condition_variable condition_var_;
};
class EvaluatorCacheMgr {
public:
EvaluatorCacheMgr() = default;
~EvaluatorCacheMgr() = default;
void Clear() { eval_result_cache_.clear(); }
EvalResultCache &GetCache() { return eval_result_cache_; }
EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
size_t GetSize() { return eval_result_cache_.size(); }
private:
EvalResultCache eval_result_cache_;
};
// AnalysisCache
class AnalysisResultCacheMgr {
public:
~AnalysisResultCacheMgr() = default;
AnalysisResultCacheMgr(const AnalysisResultCacheMgr &) = delete;
AnalysisResultCacheMgr &operator=(const AnalysisResultCacheMgr &) = delete;
static AnalysisResultCacheMgr &GetInstance();
static std::mutex &tiggerToken() { return tiggerToken_; }
void Clear();
using AnalysisConfigAsyncResultMap =
std::unordered_map<AnfNodeConfigPtr, AsyncEvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
using AnalysisConfigAsyncResultCache =
MultiThreadCache<AnfNodeConfigPtr, AsyncEvalResultPtr, AnalysisConfigAsyncResultMap>;
using AnalysisConfigResultMap =
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
using AnalysisConfigResultCache = MultiThreadCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
// Dump all the conf and result
void DumpCache(const std::string &filename);
// Wait for async Eval(conf) to finish.
void Wait();
void PushTowait(const std::shared_future<EvalResultPtr> &future0, const std::shared_future<EvalResultPtr> &future1);
void PushTodo(const AnfNodeConfigPtr &conf);
void Todo();
static void UpdateCaller(const std::string &caller);
static std::string &GetThreadid();
void InitSwitchValue(const AnfNodeConfigPtr &conf);
EvalResultPtr GetSwitchValue(const AnfNodeConfigPtr &conf);
void SetSwitchValue(const AnfNodeConfigPtr &conf, const EvalResultPtr vale);
private:
AnalysisResultCacheMgr() = default;
static std::mutex tiggerToken_;
std::recursive_mutex lock_;
std::list<std::shared_future<EvalResultPtr>> waiting_;
std::list<AnfNodeConfigPtr> todo_;
AnalysisConfigResultCache cache_;
AnalysisConfigAsyncResultCache switch_cache_;
};
class TiggerToken_scoped_release {
public:
TiggerToken_scoped_release() { AnalysisResultCacheMgr::tiggerToken().unlock(); }
~TiggerToken_scoped_release() { AnalysisResultCacheMgr::tiggerToken().lock(); }
};
class TiggerToken_scoped_acquire {
public:
TiggerToken_scoped_acquire() { AnalysisResultCacheMgr::tiggerToken().lock(); }
~TiggerToken_scoped_acquire() { AnalysisResultCacheMgr::tiggerToken().unlock(); }
};
std::string ArgsToString(const AbstractBasePtrList &args_spec_list);
inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisResultCacheMgr::GetThreadid() + ":"; }
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_

View File

@ -25,6 +25,7 @@
#include "debug/trace.h"
#include "utils/ms_context.h"
#include "pipeline/jit/static_analysis/stack_frame.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
namespace mindspore {
namespace abstract {
@ -80,10 +81,9 @@ void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, co
// Increase & Check the func graph call depth.
engine->IncreaseFunctionCallDepth();
engine->IncreaseStackFrameDepth();
if (engine->function_call_depth() - engine->stack_frame_depth() >
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
if (engine->function_call_depth() > max_depth) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
<< ", (function call depth: " << engine->function_call_depth()
<< ", simulate call depth: " << engine->stack_frame_depth()
<< "), please call 'context.set_context(max_call_depth=value)' to adjust this value.";
@ -116,7 +116,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context_, parent_context_);
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
stack_frames.push(current_stack_frame);
while (1) {
while (true) {
current_stack_frame = stack_frames.top();
if (current_stack_frame->Done()) {
MS_EXCEPTION_IF_NULL(res_base);
@ -135,8 +135,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
// Save func graph eval result for specialize.
auto evaluator = current_stack_frame->evaluator();
MS_EXCEPTION_IF_NULL(evaluator);
EvaluatorCacheMap &evaluator_cache_map = *evaluator->evaluator_cache_map();
evaluator_cache_map[current_stack_frame->args_abs_list()] = eval_result;
evaluator->evaluator_cache_mgr()->SetValue(current_stack_frame->args_abs_list(), eval_result);
// Leave current func graph.
LeaveStackFrame(engine, current_stack_frame);
@ -167,7 +166,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
const AnfNodePtr &func_node = fg->get_return();
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
return EXCLUDE;
}
@ -180,20 +179,26 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
<< ", node_conf: " << node_conf->ToString();
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
res_base = node_eval_result->abstract();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << res_base->ToString();
MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << res_base->ToString();
}
MS_EXCEPTION_IF_NULL(res_base);
return res_base;
}
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
if (eval_result != nullptr) {
MS_LOG(ERROR) << ToString() << ArgsToString(args_abs_list) << " entered again. There is something wrong.";
return eval_result;
} else {
MS_LOG(DEBUG) << ToString() << " entered first.";
}
MS_EXCEPTION_IF_NULL(engine);
engine->IncreaseFunctionCallDepth();
if (engine->function_call_depth() - engine->stack_frame_depth() >
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
if (engine->function_call_depth() > max_depth) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
<< ", (function call depth: " << engine->function_call_depth()
<< ", simulate call depth: " << engine->stack_frame_depth()
<< "), please call 'context.set_context(max_call_depth=value)' to adjust this value.";
@ -212,6 +217,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< args_abs_list.size() << ".";
}
MS_EXCEPTION_IF_NULL(parent_context_);
MS_LOG(DEBUG) << GetInferThread() << "@" << fg->ToString() << ArgsToString(args_abs_list) << " { ";
if (parent_context_->func_graph() != nullptr) {
MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisResultCacheMgr::GetThreadid() << ":"
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
<< fg->ToString() << "();";
}
context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list);
const auto &parameters = fg->parameters();
for (size_t i = 0; i < nargs; i++) {
@ -219,6 +231,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, context_);
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString();
}
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
<< ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString()
@ -238,11 +251,14 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
res_base = std::make_shared<AbstractUndetermined>();
}
MS_LOG(DEBUG) << GetInferThread() << "} //" << fg->ToString() << " = " << res_base->ToString();
engine->DecreaseFunctionCallDepth();
MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
<< "), leave, function call depth: " << engine->function_call_depth() << " - "
<< engine->stack_frame_depth();
return std::make_shared<EvalResult>(res_base, nullptr);
auto res = std::make_shared<EvalResult>(res_base, nullptr);
evaluator_cache_mgr_->SetValue(args_abs_list, res);
return res;
}
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
@ -280,6 +296,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
return args_spec_list;
}
if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
if (parent_context_) {
MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
@ -307,7 +324,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
return joined_args_spec_list_1;
}
}
if (trace_.size() != 0) {
if (!trace_.empty()) {
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back());
// Join the last eval arguments and current arguments to check if there are loop variant.
@ -336,27 +353,26 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
auto iter = func_graph_cache_.find(args_spec_list);
FuncGraphPtr ret = nullptr;
FuncGraphPtr res;
if (iter == func_graph_cache_.end()) {
auto fg = func_graph();
MS_EXCEPTION_IF_NULL(fg);
TraceGuard guard(std::make_shared<TraceEvaluatorGenGraph>(fg->debug_info()));
FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list);
func_graph_cache_[args_spec_list] = generated_graph;
MS_EXCEPTION_IF_NULL(engine);
engine->func_graph_manager()->AddFuncGraph(generated_graph);
ret = generated_graph;
res = generated_graph;
} else {
ret = iter->second;
res = iter->second;
}
// For the top graph, if it is replaced by generated graph, update the top graph to the new one.
if (parse::Parser::GetTopFuncGraph() == func_graph()) {
if (ret != func_graph()) {
parse::Parser::UpdateTopFuncGraph(ret);
if (res != func_graph()) {
parse::Parser::UpdateTopFuncGraph(res);
}
}
return ret;
return res;
}
FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
@ -366,7 +382,7 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
}
MS_EXCEPTION_IF_NULL(meta_func_graph_);
FuncGraphPtr generated_func_graph = nullptr;
FuncGraphPtr generated_func_graph;
if (this->bound_node() != nullptr) {
TraceGuard trace_guard(std::make_shared<TraceGenMetaFuncGraph>(bound_node()->debug_info()));
generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
@ -381,8 +397,18 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
return cloned_func_graph;
}
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) {
const std::string &evaluator_name = ToString();
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
const string evaluator_name = ToString();
std::unique_lock<std::recursive_timed_mutex> eval_lock(eval_loc_, std::try_to_lock);
if (!eval_lock.owns_lock()) {
auto py_tstate = PyEval_SaveThread();
eval_lock.try_lock_for(std::chrono::seconds(kInferTimeout));
PyEval_RestoreThread(py_tstate);
if (!eval_lock.owns_lock()) {
MS_LOG(EXCEPTION) << "It is timeout to run " << ToString();
}
}
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
@ -394,30 +420,30 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf);
MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
auto iter = evaluator_cache_map_->find(args_spec_list);
if (iter == evaluator_cache_map_->end()) {
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
if (eval_result == nullptr) {
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
EvalResultPtr ret = Eval(engine, args_spec_list);
if (ret->abstract() == nullptr) {
EvalResultPtr res = Eval(engine, args_spec_list);
if (res->abstract() == nullptr) {
EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
}
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << ".";
(*evaluator_cache_map_)[args_spec_list] = ret;
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << res->abstract()->ToString() << ".";
evaluator_cache_mgr_->SetValue(args_spec_list, res);
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return ret;
return res;
} else {
MS_EXCEPTION_IF_NULL(iter->second);
MS_EXCEPTION_IF_NULL(iter->second->abstract());
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << ".";
MS_EXCEPTION_IF_NULL(eval_result);
MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << eval_result->abstract()->ToString() << ".";
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return iter->second;
return eval_result;
}
}
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr) {
const AnfNodeConfigPtr &) {
AbstractBasePtrList args_spec_list;
auto is_py_eval = (identifier_ == "PythonPrimEvaluator");
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
@ -432,58 +458,57 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
}
return abstract;
});
EvalResultPtr ret = EvalPrim(engine, args_spec_list);
return ret;
return EvalPrim(engine, args_spec_list);
}
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
const AnfNodeConfigPtr &out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
if (args_conf_list.size() == 0) {
if (args_conf_list.empty()) {
MS_LOG(EXCEPTION) << "Size should greater than 0";
}
EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
EvalResultPtr res = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
// No need to cache.
return ret;
return res;
}
EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
EvalResultPtr ret = EvalPrim(args_conf_list);
return ret;
EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &) {
return EvalPrim(args_conf_list);
}
EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
const AnfNodeConfigPtr &out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
EvalResultPtr res = sub_evaluator_->Run(engine, args_conf_list, out_conf);
// Don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map_, like getattr primitive.
(*evaluator_cache_map_)[args_spec_list] = ret;
return ret;
evaluator_cache_mgr_->SetValue(args_spec_list, res);
return res;
}
EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
const AnfNodeConfigPtr &out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
auto iter = evaluator_cache_map_->find(args_spec_list);
if (iter != evaluator_cache_map_->end()) {
return iter->second;
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
if (eval_result != nullptr) {
return eval_result;
}
ConfigPtrList partial_args_conf_list;
@ -493,23 +518,22 @@ EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtr
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
(*evaluator_cache_map_)[args_spec_list] = ret;
return ret;
EvalResultPtr res = evaluator_->Run(engine, partial_args_conf_list, out_conf);
evaluator_cache_mgr_->SetValue(args_spec_list, res);
return res;
}
EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
auto iter = evaluator_cache_map_->find(args_spec_list);
if (iter != evaluator_cache_map_->end()) {
return iter->second;
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
if (eval_result != nullptr) {
return eval_result;
}
// Call the original evaluator, get the result: y = f(x)
@ -536,9 +560,9 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
// J(f)(J(x)) return a tuple (y, bprop_f)
AbstractBasePtrList jargs = {result->abstract(), bprop};
AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
(*evaluator_cache_map_)[args_spec_list] = infer_reuslt;
return infer_reuslt;
auto res = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, res);
return res;
}
EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
@ -553,5 +577,16 @@ EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrLis
}
return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
}
EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
auto result = this->Run(engine, args_conf_list, out_conf);
StaticAnalysisException::Instance().CheckException();
pybind11::gil_scoped_release release;
AnalysisResultCacheMgr::GetInstance().Wait();
StaticAnalysisException::Instance().CheckException();
return result;
}
} // namespace abstract
} // namespace mindspore

View File

@ -26,23 +26,22 @@
#include <stack>
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace abstract {
using EvaluatorCacheMap =
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>;
using EvaluatorCacheMgrPtr = std::shared_ptr<EvaluatorCacheMgr>;
using EvaluatorAttrMap =
std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvaluatorAttrMapPtr = std::shared_ptr<EvaluatorAttrMap>;
using EvaluatorAttrCache = MultiThreadCache<AbstractBasePtrList, AttrValueMapPtr, EvaluatorAttrMap>;
using EvaluatorAttrCachePtr = std::shared_ptr<EvaluatorAttrCache>;
class Evaluator : public Base {
public:
explicit Evaluator(const std::string &id)
: evaluator_cache_map_(std::make_shared<EvaluatorCacheMap>()),
attr_cache_(std::make_shared<EvaluatorAttrMap>()),
: evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()),
attr_cache_(std::make_shared<EvaluatorAttrCache>()),
identifier_(id) {}
~Evaluator() override = default;
MS_DECLARE_PARENT(Evaluator, Base);
@ -50,9 +49,12 @@ class Evaluator : public Base {
// difference between Run() and Eval():
// Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
// Run() will modify cache_ member, so it cannot marked as const;
virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf);
virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
virtual EvalResultPtr SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf);
virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
@ -87,14 +89,16 @@ class Evaluator : public Base {
virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
EvaluatorCacheMapPtr &evaluator_cache_map() { return evaluator_cache_map_; }
EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; }
EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; }
EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; }
EvaluatorCacheMgrPtr evaluator_cache_mgr_;
EvaluatorAttrCachePtr attr_cache_;
EvaluatorCacheMapPtr evaluator_cache_map_;
EvaluatorAttrMapPtr attr_cache_;
std::string identifier_;
AnfNodeWeakPtr bound_node_;
std::recursive_timed_mutex eval_loc_;
};
class PrimEvaluator : public Evaluator {
@ -112,7 +116,8 @@ class TrivialPrimEvaluator : public PrimEvaluator {
explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~TrivialPrimEvaluator() override = default;
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) final;
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
};
@ -121,7 +126,8 @@ class TransitionPrimEvaluator : public PrimEvaluator {
explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~TransitionPrimEvaluator() override = default;
MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) final;
// Parameter in_conf0 : the first element in args_conf_list;
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
@ -132,7 +138,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~SymbolicPrimEvaluator() override = default;
MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) final;
virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
};
@ -173,7 +180,8 @@ class TrackedEvaluator : public Evaluator {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
}
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }
private:
@ -211,9 +219,9 @@ class BaseFuncGraphEvaluator : public Evaluator {
AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
// Add functions for stack frame routine.
AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
const StackFramePtr &new_stack_frame);
void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame);
static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
const StackFramePtr &new_stack_frame);
static void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame);
AnalysisContextPtr context_;
};
@ -287,7 +295,8 @@ class PartialAppEvaluator : public Evaluator {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
private:
@ -333,7 +342,8 @@ class JEvaluator : public Evaluator {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
private:

View File

@ -51,7 +51,7 @@ std::unordered_set<std::string> prims_to_skip_undetermined_infer{
"MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
const AnfNodeConfigPtr &out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
@ -136,7 +136,7 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
}
EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
const AnfNodeConfigPtr &out_conf) {
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
@ -238,7 +238,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
}
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
const AnfNodeConfigPtr &out_conf) {
AbstractBasePtrList args_spec_list;
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
@ -603,7 +603,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
return ret_abstract;
}
}
if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
return EvalPyCheckPrim(engine, args);
}
@ -640,10 +639,12 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
}
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
const auto &iter = evaluator_cache_map_->find(args);
if (iter != evaluator_cache_map_->end()) {
return iter->second;
const auto eval_result = evaluator_cache_mgr_->GetValue(args);
if (eval_result != nullptr) {
return eval_result;
}
pybind11::gil_scoped_acquire gil;
auto py_args = PreparePyInputs(prim_py_, args);
prim_py_->BeginRecordAddAttr();
py::dict output = prim_py_->RunInfer(py_args);
@ -653,7 +654,7 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
auto res_spec = PyInferRes2Abstract(prim_py_, output);
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
(*evaluator_cache_map_)[args] = infer_result;
evaluator_cache_mgr_->SetValue(args, infer_result);
return infer_result;
}
@ -1016,6 +1017,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
return nullptr;
}
AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
if (ref_abs == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
@ -1082,7 +1084,7 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
}
// don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map, like getattr primitive;
(*evaluator_cache_map_)[args_spec_list] = ret;
evaluator_cache_mgr_->SetValue(args_spec_list, ret);
return ret;
}
};
@ -1169,7 +1171,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
(*evaluator_cache_map_)[args_spec_list] = infer_result;
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
return infer_result;
}
@ -1197,7 +1199,7 @@ class PartialEvaluator : public Evaluator {
PartialEvaluator() : Evaluator("PartialEvaluator") {}
~PartialEvaluator() override = default;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf = nullptr) override {
const AnfNodeConfigPtr &out_conf = nullptr) override {
if (args_conf_list.size() == 0) {
MS_LOG(EXCEPTION) << "Args size should be greater than 0";
}
@ -1214,7 +1216,7 @@ class PartialEvaluator : public Evaluator {
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
<< " as func is: " << arg0_value->ToString();
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
(*evaluator_cache_map_)[args_spec_list] = eval_result;
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
return eval_result;
}
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
@ -1248,7 +1250,7 @@ class PartialEvaluator : public Evaluator {
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
(*evaluator_cache_map_)[args_spec_list] = eval_result;
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
return eval_result;
}

View File

@ -71,7 +71,7 @@ class DoSignatureEvaluator : public Evaluator {
explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
~DoSignatureEvaluator() override = default;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override;
const AnfNodeConfigPtr &out_config = nullptr) override;
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
@ -86,7 +86,7 @@ class UnpackGraphEvaluator : public Evaluator {
explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
~UnpackGraphEvaluator() override = default;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override;
const AnfNodeConfigPtr &out_config = nullptr) override;
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
@ -102,7 +102,7 @@ class MixedPrecisionCastEvaluator : public Evaluator {
: Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {}
~MixedPrecisionCastEvaluator() override = default;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override;
const AnfNodeConfigPtr &out_config = nullptr) override;
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";

View File

@ -494,11 +494,11 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
return wrapped_node;
}
const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
auto cache_iter = evalcaches_.find(eval);
if (cache_iter == evalcaches_.end()) {
evalcaches_[eval] = eval->evaluator_cache_map();
return eval->evaluator_cache_map();
evalcaches_[eval] = eval->evaluator_cache_mgr();
return eval->evaluator_cache_mgr();
}
return cache_iter->second;
}
@ -509,7 +509,8 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
EvalResultPtr ret = nullptr;
AbstractBasePtrList broaded_argvals;
for (auto &argvals_map : *evalcaches_[eval]) {
EvalResultCache &cache = evalcaches_[eval]->GetCache();
for (auto &argvals_map : cache) {
auto argvals = argvals_map.first;
broaded_argvals.clear();
@ -524,11 +525,10 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
(void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
[](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
// if broaden return null
ret = eval->Run(engine_, args_conf_list, nullptr);
EvaluatorCacheMapPtr real = std::make_shared<EvaluatorCacheMap>();
(*real)[broaded_argvals] = ret;
// If broaden return null
ret = eval->SingleRun(engine_, args_conf_list, nullptr);
EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
real->SetValue(broaded_argvals, ret);
evalcaches_[eval] = real;
return std::make_pair(broaded_argvals, ret->abstract());
} else {
@ -615,11 +615,12 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
}
namespace {
void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) {
void DumpEvaluatorCache(const EvaluatorCacheMgrPtr &evaluator_cache_mgr, const AbstractBasePtrList &argvals) {
MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
int64_t i = 0;
for (const auto &item : evaluator_cache_map) {
MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
const EvalResultCache &map = evaluator_cache_mgr->GetCache();
for (const auto &item : map) {
MS_LOG(DEBUG) << "evaluator_cache[" << i++ << "]: " << item.first;
}
}
@ -650,24 +651,24 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
MS_EXCEPTION_IF_NULL(eval);
MS_EXCEPTION_IF_NULL(result);
EvaluatorCacheMap &evaluator_cache_map = *eval->evaluator_cache_map();
if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
*result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract());
EvaluatorCacheMgrPtr evaluator_cache_mgr = eval->evaluator_cache_mgr();
auto data = evaluator_cache_mgr->GetValue(argvals);
if (data != nullptr) {
*result = std::make_pair(argvals, data->abstract());
return kSpecializeSuccess;
}
DumpEvaluatorCache(evaluator_cache_map, argvals);
DumpEvaluatorCache(evaluator_cache_mgr, argvals);
const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
MS_EXCEPTION_IF_NULL(choices);
if (choices->count(argvals)) {
*result = std::make_pair(argvals, (*choices)[argvals]->abstract());
MS_EXCEPTION_IF_NULL(GetEvalCache(eval));
EvalResultCache &choices = GetEvalCache(eval)->GetCache();
if (choices.get(argvals) != nullptr) {
*result = std::make_pair(argvals, GetEvalCache(eval)->GetValue(argvals)->abstract());
return kSpecializeSuccess;
} else if (choices->size() == 1) {
} else if (choices.size() == 1) {
MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
*result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
*result = std::make_pair(choices.begin()->first, choices.begin()->second->abstract());
return kSpecializeSuccess;
} else if (choices->empty()) {
} else if (choices.empty()) {
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
<< func->type_name();
return kSpecializeFindUniqueArgvalDead;

View File

@ -91,7 +91,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_node_;
std::vector<AnfNodePtr> todo_;
std::unordered_set<AnfNodePtr> marked_;
std::unordered_map<EvaluatorPtr, EvaluatorCacheMapPtr> evalcaches_;
std::unordered_map<EvaluatorPtr, EvaluatorCacheMgrPtr> evalcaches_;
void FirstPass();
void SecondPass();
@ -127,7 +127,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
const AbstractBasePtrList &argvals,
std::pair<AbstractBasePtrList, AbstractBasePtr> *result);
// Get cache, it may be eval's cache or cache built from broaded argument values.
const EvaluatorCacheMapPtr &GetEvalCache(const EvaluatorPtr &eval);
const EvaluatorCacheMgrPtr GetEvalCache(const EvaluatorPtr &eval);
// Try to build unique argvals from the broaded arg vals if it is unique.
std::pair<AbstractBasePtrList, AbstractBasePtr> BuildFromBroadedArgsVal(const EvaluatorPtr &eval);
void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node);

View File

@ -16,6 +16,7 @@
#include "pipeline/jit/static_analysis/stack_frame.h"
#include "debug/trace.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
namespace mindspore {
namespace abstract {
@ -66,8 +67,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode);
// Check if already evaluated before.
EvaluatorCacheMap &evaluator_cache_map = *evaluator->evaluator_cache_map();
if (evaluator_cache_map.find(args_abs_list) != evaluator_cache_map.end()) {
if (evaluator->evaluator_cache_mgr()->GetValue(args_abs_list) != nullptr) {
return nullptr;
}
@ -128,6 +128,8 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
<< ", current_context_: " << current_context_->ToString();
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString()
<< ") = " << node_eval_result->abstract()->ToString();
return node_eval_result;
}

View File

@ -17,20 +17,21 @@
*/
#include "pipeline/jit/static_analysis/static_analysis.h"
#include <algorithm>
#include <set>
#include "abstract/abstract_value.h"
#include "abstract/utils.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "utils/symbolic.h"
#include "utils/ms_exception.h"
#include "ir/tensor.h"
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/static_analysis/evaluator.h"
#include "debug/trace.h"
#include "debug/anf_ir_dump.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
namespace mindspore {
namespace abstract {
@ -39,12 +40,9 @@ bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
auto v = arg_spec->GetValueTrack();
if (v->isa<SymbolicKeyInstance>()) {
return true;
} else {
return false;
}
} else {
return false;
}
return false;
}
AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
@ -54,36 +52,6 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase
return nullptr;
}
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
<< ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString()
<< ", Pointer: " << result->abstract().get();
analysis_cache_map_[conf] = result;
// Set intermediate abstract value.
if (IsIntermediateAbstract(result->abstract())) {
if (conf->node()->intermediate_abstract() == nullptr) {
conf->node()->set_intermediate_abstract(result->abstract());
MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
} else {
auto old_spec = conf->node()->intermediate_abstract();
auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
conf->node()->set_intermediate_abstract(joined_spec);
MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
<< result->abstract()->ToString() << "\njoined_spec:\t"
<< (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
}
}
}
EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
auto value = analysis_cache_map_.find(conf);
if (value == analysis_cache_map_.end()) {
return nullptr;
}
return value->second;
}
std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(conf->node());
@ -105,6 +73,7 @@ bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeCon
}
AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) {
StaticAnalysisException::Instance().ClearException();
ConfigPtrList args_conf_list;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
@ -127,6 +96,11 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
MS_EXCEPTION_IF_NULL(output_conf);
result.inferred = output_conf->ObtainEvalResult();
result.context = root_context;
StaticAnalysisException::Instance().CheckException();
pybind11::gil_scoped_release release;
AnalysisResultCacheMgr::GetInstance().Wait();
StaticAnalysisException::Instance().CheckException();
return result;
}
@ -137,15 +111,35 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
return eval->context();
}
void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(result);
static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
cache_mgr.SetValue(conf, result);
// Set intermediate abstract value.
if (IsIntermediateAbstract(result->abstract())) {
if (conf->node()->intermediate_abstract() == nullptr) {
conf->node()->set_intermediate_abstract(result->abstract());
MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
} else {
auto old_spec = conf->node()->intermediate_abstract();
auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
conf->node()->set_intermediate_abstract(joined_spec);
MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
<< result->abstract()->ToString() << "\njoined_spec:\t"
<< (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
}
}
}
EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
EvalResultPtr result = analysis_cache_.GetValue(conf);
static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
auto result = cache_mgr.GetValue(conf);
if (result != nullptr) {
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString()
<< ", Value: " << result->abstract().get() << ", " << result->abstract()->ToString();
return result;
}
MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
result = Eval(conf);
if (result == nullptr) {
@ -187,8 +181,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
trace::TraceEvalCNodeLeave();
} else {
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString() << "(" << node->type_name()
<< "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
<< ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
<< "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph");
}
#ifdef DEBUG
@ -280,6 +273,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
}
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
if (func == nullptr) {
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << maybe_func->ToString() << ".";
@ -321,28 +315,28 @@ EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const Abs
}
void AnalysisEngine::ClearEvaluatorCache() {
for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : evaluators_) {
for (auto &element : evaluators_) {
EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
evaluator->evaluator_cache_map()->clear();
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
evaluator->evaluator_cache_mgr()->Clear();
}
for (auto &element : prim_constructors_) {
EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
evaluator->evaluator_cache_map()->clear();
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
evaluator->evaluator_cache_mgr()->Clear();
}
for (auto &element : prim_py_evaluators_) {
EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
evaluator->evaluator_cache_map()->clear();
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
evaluator->evaluator_cache_mgr()->Clear();
}
}
void AnalysisEngine::Clear() {
analysis_cache_.Clear();
AnalysisResultCacheMgr::GetInstance().Clear();
anfnode_config_map_.clear();
eval_trace_.clear();
evaluators_.clear();
@ -517,6 +511,11 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
if (func->tracking_id() != nullptr) {
MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
}
MS_EXCEPTION_IF_NULL(func);
// protect the constructors
static std::recursive_mutex constructors_mutex;
// std::lock_guard<std::recursive_mutex> lock(constructors_mutex);
if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
func->isa<abstract::FuncGraphAbstractClosure>()) {
EvaluatorPtr evaluator = _GetEvaluatorFor(func);
@ -570,7 +569,16 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
MS_EXCEPTION_IF_NULL(eval);
return eval->Run(shared_from_this(), args_conf_list, out_conf);
}
#if !(defined _WIN32 || defined _WIN64)
static bool enable_singleThread = (common::GetEnv("ENV_SINGLE_EVAL") == "1");
if (enable_singleThread) {
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
} else {
return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list);
}
#else
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
#endif
}
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
@ -623,17 +631,17 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt
for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "check undetermined.";
auto &alternate_evaluator = multi_poss_[u_eval.evaluator_];
auto &eval_cache = alternate_evaluator->evaluator_cache_map();
auto eval_cache = alternate_evaluator->evaluator_cache_mgr();
const auto &alt_eval_args = EvaluatorArgs(alternate_evaluator, args_spec_list);
if ((!undetermined_evals.count(alt_eval_args)) &&
(((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) ||
(eval_cache->find(args_spec_list) == eval_cache->end()))) {
(((!continued_evals_.count(u_eval)) && (eval_cache->GetValue(args_spec_list) != nullptr)) ||
(eval_cache->GetValue(args_spec_list) == nullptr))) {
MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "has undetermined.";
has_undetermined = true;
break;
}
}
if (has_undetermined == false) {
if (!has_undetermined) {
MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
*continue_flag = true;
return latest_entry;
@ -675,7 +683,7 @@ std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBa
}
EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node) {
if (out_specs.size() == 0) {
if (out_specs.empty()) {
MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
}
@ -710,6 +718,135 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
}
bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) {
if (abstract->isa<AbstractFunction>()) {
return true;
}
if (abstract->isa<AbstractSequeue>()) {
auto elements = abstract->cast<AbstractSequeuePtr>()->elements();
if (std::any_of(elements.begin(), elements.end(),
[](const AbstractBasePtr &item) { return item->isa<AbstractFunction>(); })) {
return true;
}
}
return false;
}
EvalResultPtr ExecEvaluetor(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list,
AnfNodeConfigPtr out_conf, std::string caller, AsyncEvalResultPtr async_result_branch,
AsyncEvalResultPtr async_result_main) {
// TiggerToken_scoped_acquire tigger_token_acquire;
py::gil_scoped_acquire pyGuard;
EvalResultPtr result = nullptr;
try {
AnalysisResultCacheMgr::UpdateCaller(caller);
result = eval->Run(engine, args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(result);
async_result_branch->JoinResult(result);
async_result_main->JoinResult(result);
MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
<< " asyncResult address = " << async_result_branch.get()
<< " value = " << async_result_branch->TryGetResult()->abstract()->ToString();
auto broadAbstract = result->abstract()->Broaden();
auto broadEvalResult = std::make_shared<EvalResult>(broadAbstract, nullptr);
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadEvalResult);
} catch (const std::exception &e) {
std::ostringstream oss;
trace::GetEvalStackInfo(oss);
if (!oss.str().empty()) {
MS_LOG(ERROR) << oss.str();
}
auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>(oss.str()), out_conf->node());
async_result_main->JoinResult(std::make_shared<EvalResult>(abstractErrPtr, nullptr));
StaticAnalysisException::Instance().SetException();
}
return result;
}
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {
// TiggerToken_scoped_release tigger_token_release;
pybind11::gil_scoped_release release;
// Wait for the switch node to finish.
MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString();
auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
if (eval_result == nullptr) {
MS_LOG(DEBUG) << GetInferThread() << "async : Init switch " << out_conf->ToString();
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
} else {
if (eval_result->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Eval " << out_conf->node()->ToString() << " time out."
<< "please check the code if there are recursive functions.";
}
return eval_result;
}
// Eval result of the branches and main.
AsyncEvalResultPtr asyncResult0 = std::make_shared<AsyncEvalResult>();
AsyncEvalResultPtr asyncResult1 = std::make_shared<AsyncEvalResult>();
AsyncEvalResultPtr asyncResult_main = std::make_shared<AsyncEvalResult>();
SetUndeterminedFlag(evaluators[0]);
SetUndeterminedFlag(evaluators[1]);
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
auto future0 = std::async(std::launch::async, ExecEvaluetor, evaluators[0], shared_from_this(), args_conf_list,
out_conf, threadId, asyncResult0, asyncResult_main);
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
auto future1 = std::async(std::launch::async, ExecEvaluetor, evaluators[1], shared_from_this(), args_conf_list,
out_conf, threadId, asyncResult1, asyncResult_main);
// Wait for async threads to finish.
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1));
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
<< " or " << evaluators[1]->ToString();
auto branchResult = asyncResult_main->GetResult();
if (branchResult == nullptr || branchResult->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString()
<< "please check the code if there are recursive functions.";
}
if (branchResult->isa<AbstractError>()) {
MS_LOG(EXCEPTION) << "async " << out_conf->node()->ToString() << " threw exception.";
}
AbstractBasePtrList out_specs;
if (NeedWaitForTwoBranches(branchResult->abstract())) {
MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[0]->ToString();
auto result0 = asyncResult0->GetResult();
if (result0->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << "is time out."
<< " Please check the code if there is recursive function.";
}
out_specs.push_back(result0->abstract());
MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[1]->ToString();
auto result1 = asyncResult1->GetResult();
if (result1->isa<AbstractTimeOut>()) {
MS_LOG(EXCEPTION) << "Eval " << evaluators[1]->ToString() << "is time out."
<< " Please check the code if there is recursive function.";
}
out_specs.push_back(result1->abstract());
} else {
if (asyncResult0->TryGetResult()) {
MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[0]->ToString()
<< " value0=" << asyncResult0->GetResult()->abstract()->ToString();
out_specs.push_back(asyncResult0->GetResult()->abstract());
}
if (asyncResult1->TryGetResult()) {
MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[1]->ToString()
<< " value1=" << asyncResult1->GetResult()->abstract()->ToString();
out_specs.push_back(asyncResult1->GetResult()->abstract());
}
}
return ProcessEvalResults(out_specs, out_conf->node());
}
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {
@ -726,9 +863,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
for (auto eval : evaluators) {
for (const auto &eval : evaluators) {
SetUndeterminedFlag(eval);
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
@ -757,11 +893,11 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
// Try to travel the latest undetermined.
if (latest_entry != eval_trace_.rbegin()->evaluator_) {
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString();
auto latest_entry_eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(latest_entry_eval_result->abstract());
auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString()
<< " return out_spec: " << latest_entry_eval_result->abstract()->ToString();
return latest_entry_eval_result;
<< " return out_spec: " << eval_result->abstract()->ToString();
return eval_result;
}
}
}
@ -815,8 +951,9 @@ AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &cont
if (value->isa<Primitive>()) {
auto prim = value->cast<PrimitivePtr>();
return MakeAbstractClosure(prim, anf_node);
} else {
return value->ToAbstract();
}
return value->ToAbstract();
}
AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {

View File

@ -28,6 +28,7 @@
#include <map>
#include <set>
#include <unordered_set>
#include <mutex>
#ifdef DEBUG
#include <stack>
@ -154,19 +155,6 @@ class VirtualConfig : public Config {
AbstractBasePtr abstract_;
};
// AnalysisCache
class AnalysisCache {
public:
AnalysisCache() = default;
~AnalysisCache() = default;
void Clear() { analysis_cache_map_.clear(); }
void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg);
EvalResultPtr GetValue(const AnfNodeConfigPtr &conf);
private:
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> analysis_cache_map_;
};
using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
using AnfNodeConfigMap =
std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
@ -183,12 +171,38 @@ struct PartialAppHasher {
return h1 ^ h2;
}
};
// Should compare Args based on value other than pointer;
struct EvaluatorArgs {
EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {}
bool operator==(const EvaluatorArgs &other) const {
if (evaluator_ != other.evaluator_) {
return false;
}
if (AbstractBasePtrListDeepEqual(args_, other.args_)) {
return true;
}
return false;
}
bool operator!=(const EvaluatorArgs &other) { return !(*this == other); }
EvaluatorPtr evaluator_;
AbstractBasePtrList args_;
};
using EvalTraceRevIter = std::list<EvaluatorArgs>::reverse_iterator;
struct EvaluatorArgsHasher {
std::size_t operator()(const EvaluatorArgs &eval_args) const {
return hash_combine(std::hash<EvaluatorPtr>{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_));
}
};
struct EvaluatorArgsEqual {
bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; }
};
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
: analysis_cache_(AnalysisCache()),
prim_constructors_(prim_evaluator_map),
func_graph_manager_(func_graph_manager) {
: prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {
function_call_depth_ = 0;
function_call_max_depth_ = 0;
stack_frame_depth_ = 0;
@ -202,11 +216,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// func_graph: The func_graph to analyze.
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(result);
analysis_cache_.set_value(conf, result);
}
void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result);
EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf);
// Return the Evaluator for the given function.
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
@ -218,7 +228,6 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
void Clear();
void ClearEvaluatorCache();
AnalysisCache &analysis_cache() { return analysis_cache_; }
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) {
return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context);
}
@ -239,7 +248,6 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf);
const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
AnalysisCache analysis_cache_;
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
void ResetFunctionCallDepth() {
@ -249,7 +257,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
void IncreaseFunctionCallDepth() {
function_call_depth_++;
if (function_call_max_depth_ < function_call_depth_) {
function_call_max_depth_ = function_call_depth_;
function_call_max_depth_ = function_call_depth_.load();
}
}
void DecreaseFunctionCallDepth() {
@ -268,7 +276,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
void IncreaseStackFrameDepth() {
stack_frame_depth_++;
if (stack_frame_max_depth_ < stack_frame_depth_) {
stack_frame_max_depth_ = stack_frame_depth_;
stack_frame_max_depth_ = stack_frame_depth_.load();
}
}
void DecreaseStackFrameDepth() {
@ -279,39 +287,10 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
}
size_t stack_frame_depth() const { return stack_frame_depth_; }
size_t stack_frame_max_depth() const { return stack_frame_max_depth_; }
void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf);
bool enable_recursive_eval() const { return enable_recursive_eval_; }
private:
// Should compare Args based on value other than pointer;
struct EvaluatorArgs {
EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {}
bool operator==(const EvaluatorArgs &other) const {
if (evaluator_ != other.evaluator_) {
return false;
}
if (AbstractBasePtrListDeepEqual(args_, other.args_)) {
return true;
}
return false;
}
bool operator!=(const EvaluatorArgs &other) { return !(*this == other); }
EvaluatorPtr evaluator_;
AbstractBasePtrList args_;
};
using EvalTraceRevIter = std::list<EvaluatorArgs>::reverse_iterator;
struct EvaluatorArgsHasher {
std::size_t operator()(const EvaluatorArgs &eval_args) const {
return hash_combine(std::hash<EvaluatorPtr>{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_));
}
};
struct EvaluatorArgsEqual {
bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; }
};
void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
@ -323,6 +302,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> evaluators_;
std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
constructors_app_;
AnfNodeConfigMap anfnode_config_map_;
// Use a list to trace multiple evaluators.
std::list<EvaluatorArgs> eval_trace_;
@ -337,15 +317,18 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
const ConfigPtrList &args_conf_list);
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
EvalResultPtr ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
// Record current depth of function call stack, including `stack_frame_depth_`.
size_t function_call_depth_;
size_t function_call_max_depth_;
std::atomic_long function_call_depth_;
std::atomic_long function_call_max_depth_;
// Record current depth of stack frames call.
size_t stack_frame_depth_;
size_t stack_frame_max_depth_;
std::atomic_long stack_frame_depth_;
std::atomic_long stack_frame_max_depth_;
size_t forward_count_;
std::atomic_long forward_count_;
bool enable_recursive_eval_;

View File

@ -162,6 +162,10 @@ AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const {
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || config == kBroadenScalarParameterOnly) {
return AbstractBase::Broaden(config);
} else {
auto type = this->BuildType()->type_id();
if (type < kNumberTypeBegin || type > kNumberTypeEnd) {
return AbstractBase::Broaden(config);
}
return Clone();
}
}
@ -1085,6 +1089,27 @@ std::string AbstractNull::ToString() const {
return buffer.str();
}
bool AbstractTimeOut::operator==(const AbstractTimeOut &) const { return true; }
bool AbstractTimeOut::operator==(const AbstractBase &other) const {
if (&other == this) {
return true;
}
if (other.isa<AbstractTimeOut>()) {
auto other_none = static_cast<const AbstractTimeOut *>(&other);
return *this == *other_none;
} else {
return false;
}
}
std::string AbstractTimeOut::ToString() const {
std::ostringstream buffer;
buffer << "AbstractTimeOut "
<< "(Value: Null)";
return buffer.str();
}
bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
bool AbstractEllipsis::operator==(const AbstractBase &other) const {

View File

@ -312,7 +312,7 @@ class AbstractTensor : public AbstractUndetermined {
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden(uint8_t config = 0) const override;
AbstractBasePtr BroadenWithShape() const;
AbstractBasePtr Join(const AbstractBasePtr &other);
AbstractBasePtr Join(const AbstractBasePtr &other) override;
bool operator==(const AbstractTensor &other) const;
bool operator==(const AbstractBase &other) const override;
std::string ToString() const override;
@ -565,6 +565,21 @@ class AbstractNull : public AbstractBase {
};
using AbstractNullPtr = std::shared_ptr<AbstractNull>;
// the timeout state value for variable, which means the variable is not assigned because it is timeout
class AbstractTimeOut : public AbstractBase {
public:
AbstractTimeOut() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
~AbstractTimeOut() override = default;
MS_DECLARE_PARENT(AbstractTimeOut, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<TypeNull>(); }
bool operator==(const AbstractTimeOut &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTimeOut>(); }
std::string ToString() const override;
};
using AbstractTimeOutPtr = std::shared_ptr<AbstractTimeOut>;
class AbstractEllipsis : public AbstractBase {
public:
AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<TypeEllipsis>()); }

View File

@ -230,7 +230,7 @@ DebugInfoPtr TraceManager::GetParseOrResolveDebugInfo() { return TraceManager::p
void TraceManager::ClearParseOrResolveDebugInfo() { TraceManager::parse_or_resolve_debug_info_ = nullptr; }
std::stack<TraceContextPtr> TraceManager::trace_context_stack_;
thread_local std::stack<TraceContextPtr> TraceManager::trace_context_stack_;
DebugInfoPtr TraceManager::parse_or_resolve_debug_info_ = nullptr;
thread_local DebugInfoPtr TraceManager::parse_or_resolve_debug_info_ = nullptr;
} // namespace mindspore

View File

@ -79,8 +79,8 @@ class TraceManager {
static void ClearParseOrResolveDebugInfo();
static DebugInfoPtr GetParseOrResolveDebugInfo();
static std::stack<TraceContextPtr> trace_context_stack_;
static DebugInfoPtr parse_or_resolve_debug_info_;
thread_local static std::stack<TraceContextPtr> trace_context_stack_;
thread_local static DebugInfoPtr parse_or_resolve_debug_info_;
};
class TraceGuard {

View File

@ -58,6 +58,45 @@ class MsException {
ExceptionListener *listener_{nullptr};
std::exception_ptr exception_ptr_{nullptr};
};
class StaticAnalysisException {
public:
static StaticAnalysisException &Instance() {
static StaticAnalysisException instance;
return instance;
}
void ClearException() { exception_ptr_ = nullptr; }
bool HasException() { return exception_ptr_ != nullptr; }
void SetException() {
if (exception_ptr_ != nullptr) {
return;
}
exception_ptr_ = std::current_exception();
}
void SetAndRethrowException() {
SetException();
std::rethrow_exception(std::current_exception());
}
void CheckException() {
if (exception_ptr_ != nullptr) {
auto tmp_exception_ptr = exception_ptr_;
exception_ptr_ = nullptr;
std::rethrow_exception(tmp_exception_ptr);
}
}
private:
StaticAnalysisException() = default;
~StaticAnalysisException() = default;
DISABLE_COPY_AND_ASSIGN(StaticAnalysisException)
std::exception_ptr exception_ptr_{nullptr};
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_

View File

@ -18,6 +18,7 @@
#include "common/common_test.h"
#include "pybind11/pybind11.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "abstract/utils.h"
#include "pipeline/jit/static_analysis/prim.h"
@ -37,6 +38,12 @@ class TestAbstract : public UT::Common {
};
TEST_F(TestAbstract, TestParseDataClass) {
// Check initialization before callback to Python.
if (Py_IsInitialized() == 0) {
Py_Initialize();
}
PyEval_InitThreads();
py::object fn = parse::python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "TestFoo");
ClassPtr cls_ptr = parse::ParseDataClass(fn);