forked from mindspore-Ecosystem/mindspore
infer_optv3:infer optimize to find exit of branches and resolve the
endless issue.
This commit is contained in:
parent
45902803b2
commit
fc8ec7fc49
|
@ -34,6 +34,7 @@
|
|||
#include "debug/anf_ir_utils.h"
|
||||
#include "debug/common.h"
|
||||
#include "pipeline/jit/static_analysis/evaluator.h"
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
|
@ -177,7 +178,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
|
|||
|
||||
MS_EXCEPTION_IF_NULL(engine_);
|
||||
auto cfg = engine_->MakeConfig(node, cur_ctx_);
|
||||
auto ret = engine_->analysis_cache().GetValue(cfg);
|
||||
auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg);
|
||||
if (ret == nullptr) {
|
||||
return "Undefined";
|
||||
}
|
||||
|
@ -190,7 +191,7 @@ AbstractBasePtr AnalyzedFuncGraphExporter::GetNodeAbstract(const AnfNodePtr &nod
|
|||
}
|
||||
MS_EXCEPTION_IF_NULL(engine_);
|
||||
auto cfg = engine_->MakeConfig(node, cur_ctx_);
|
||||
auto ret = engine_->analysis_cache().GetValue(cfg);
|
||||
auto ret = abstract::AnalysisResultCacheMgr::GetInstance().GetValue(cfg);
|
||||
return ret == nullptr ? nullptr : ret->abstract();
|
||||
}
|
||||
|
||||
|
@ -548,9 +549,9 @@ void GetEvalStackInfo(std::ostringstream &oss) {
|
|||
}
|
||||
|
||||
// trace the graph evaluator stack
|
||||
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
|
||||
thread_local static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
|
||||
// trace the cnode infer debug info
|
||||
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
|
||||
thread_local static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
|
||||
|
||||
void TraceGraphEvalEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) {
|
||||
if (eval == nullptr) {
|
||||
|
|
|
@ -106,7 +106,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
prim::kPrimMirrorMiniStep);
|
||||
mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
|
||||
"mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
|
||||
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual add", prim::kPrimVirtualAdd);
|
||||
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd);
|
||||
check_bprop_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
|
||||
reset_defer_inline_ =
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "pipeline/jit/pass.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "frontend/optimizer/ad/dfunctor.h"
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/dump_proto.h"
|
||||
#include "debug/anf_ir_utils.h"
|
||||
|
@ -728,10 +729,12 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
|
|||
MS_LOG(DEBUG) << PrintArgs(args);
|
||||
ret_value = CompileInner(obj, args, phase, use_vm);
|
||||
} catch (const py::error_already_set &ex) {
|
||||
// print function call stack info before release
|
||||
std::string exception_info = GetCompileExceptionInfo();
|
||||
if (!exception_info.empty()) {
|
||||
MS_LOG(ERROR) << exception_info;
|
||||
if (!StaticAnalysisException::Instance().HasException()) {
|
||||
// print function call stack info before release
|
||||
std::string exception_info = GetCompileExceptionInfo();
|
||||
if (!exception_info.empty()) {
|
||||
MS_LOG(ERROR) << exception_info;
|
||||
}
|
||||
}
|
||||
ReleaseResource(phase);
|
||||
|
||||
|
@ -1281,6 +1284,7 @@ void ClearResAtexit() {
|
|||
|
||||
ReleaseGeTsd();
|
||||
parse::python_adapter::ResetPythonScope();
|
||||
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
Debugger::GetInstance()->Reset();
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,223 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
#include <chrono>
|
||||
#include "utils/symbolic.h"
|
||||
#include "debug/common.h"
|
||||
#include "pipeline/jit/base.h"
|
||||
#include "utils/utils.h"
|
||||
#include "abstract/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
EvalResultPtr AsyncEvalResult::TryGetResult(int ms) {
|
||||
if (result_ != nullptr || ms == 0) {
|
||||
return result_;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto time = std::chrono::microseconds(ms);
|
||||
// Wait for ms.
|
||||
(void)condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });
|
||||
return result_;
|
||||
}
|
||||
|
||||
EvalResultPtr AsyncEvalResult::GetResult() {
|
||||
if (result_ != nullptr) {
|
||||
return result_;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto time = std::chrono::seconds(kInferTimeout);
|
||||
(void)condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });
|
||||
return result_;
|
||||
}
|
||||
|
||||
std::string AsyncEvalResult::ToString() {
|
||||
std::ostringstream buffer;
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
buffer << (result_ == nullptr ? "NOT SET" : result_->abstract()->ToString());
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
void AsyncEvalResult::JoinResult(const EvalResultPtr &result) {
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
result_ = result;
|
||||
}
|
||||
condition_var_.notify_all();
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::Clear() {
|
||||
cache_.clear();
|
||||
switch_cache_.clear();
|
||||
todo_.clear();
|
||||
}
|
||||
|
||||
AnalysisResultCacheMgr &AnalysisResultCacheMgr::GetInstance() {
|
||||
static AnalysisResultCacheMgr instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::DumpCache(const std::string &filename) {
|
||||
auto path = pipeline::GetSaveGraphsPathName(Common::AddId(filename, ".cache"));
|
||||
|
||||
auto realpath = Common::GetRealPath(path);
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed. path=" << path;
|
||||
return;
|
||||
}
|
||||
ChangeFileMode(realpath.value(), S_IRWXU);
|
||||
std::ofstream fout(realpath.value());
|
||||
if (!fout.is_open()) {
|
||||
MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!";
|
||||
return;
|
||||
}
|
||||
fout << cache_.dump();
|
||||
fout.close();
|
||||
// Set file mode to read only by user
|
||||
ChangeFileMode(realpath.value(), S_IRUSR);
|
||||
}
|
||||
|
||||
thread_local static std::string local_threadid;
|
||||
void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
|
||||
std::ostringstream buffer;
|
||||
buffer << caller << "." << std::this_thread::get_id();
|
||||
local_threadid = buffer.str();
|
||||
}
|
||||
std::mutex AnalysisResultCacheMgr::tiggerToken_;
|
||||
std::string &AnalysisResultCacheMgr::GetThreadid() { return local_threadid; }
|
||||
|
||||
void AnalysisResultCacheMgr::PushTowait(const std::shared_future<EvalResultPtr> &future0,
|
||||
const std::shared_future<EvalResultPtr> &future1) {
|
||||
std::lock_guard<std::recursive_mutex> lock(lock_);
|
||||
waiting_.push_back(future0);
|
||||
waiting_.push_back(future1);
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
|
||||
std::lock_guard<std::recursive_mutex> lock(lock_);
|
||||
todo_.push_back(conf);
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
|
||||
std::lock_guard<std::recursive_mutex> lock(lock_);
|
||||
AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
|
||||
if (async_eval_result == nullptr) {
|
||||
async_eval_result = std::make_shared<AsyncEvalResult>();
|
||||
switch_cache_.set(conf, async_eval_result);
|
||||
}
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
|
||||
AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
|
||||
// Conf has been visited and set value.
|
||||
if (async_eval_result != nullptr) {
|
||||
// Maybe blocked for waiting. AsyncEvalResult maybe null, if time out.
|
||||
auto result = async_eval_result->GetResult();
|
||||
if (result == nullptr) {
|
||||
result = std::make_shared<EvalResult>(std::make_shared<AbstractTimeOut>(), nullptr);
|
||||
MS_LOG(ERROR) << "AsyncEvalResult for NodeConfig " << conf->ToString() << " is nullptr, maybe timeout.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const EvalResultPtr arg) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
if (arg == nullptr || arg->abstract() == nullptr) {
|
||||
MS_LOG(WARNING) << conf->ToString() << " value is nullptr";
|
||||
}
|
||||
std::lock_guard<std::recursive_mutex> lock(lock_);
|
||||
AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
|
||||
if (async_eval_result == nullptr) {
|
||||
MS_LOG(EXCEPTION) << conf->ToString() << " Not key.";
|
||||
async_eval_result = std::make_shared<AsyncEvalResult>();
|
||||
async_eval_result->JoinResult(arg);
|
||||
switch_cache_.set(conf, async_eval_result);
|
||||
} else {
|
||||
auto ab1 = async_eval_result->TryGetResult();
|
||||
AbstractBasePtrList absList;
|
||||
if (ab1 != nullptr) {
|
||||
absList.push_back(arg->abstract());
|
||||
absList.push_back(ab1->abstract());
|
||||
// Join two branches's result
|
||||
auto joined_spec = AbstractJoin(absList);
|
||||
MS_EXCEPTION_IF_NULL(joined_spec);
|
||||
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
|
||||
auto joined_result = std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||
async_eval_result->JoinResult(joined_result);
|
||||
if (joined_result != ab1) {
|
||||
PushTodo(conf);
|
||||
}
|
||||
} else {
|
||||
async_eval_result->JoinResult(arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::Todo() {
|
||||
while (true) {
|
||||
AnfNodeConfigPtr conf;
|
||||
lock_.lock();
|
||||
if (!todo_.empty()) {
|
||||
conf = todo_.front();
|
||||
} else {
|
||||
lock_.unlock();
|
||||
break;
|
||||
}
|
||||
todo_.pop_front();
|
||||
lock_.unlock();
|
||||
if (!(*GetValue(conf)->abstract() == *GetSwitchValue(conf)->abstract())) {
|
||||
MS_LOG(WARNING) << " Switch Value is not eq. "
|
||||
<< " switchCache: " << GetSwitchValue(conf)->abstract()->ToString()
|
||||
<< " globleCache: " << GetValue(conf)->abstract()->ToString() << "\t\tConf: " << conf->ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::Wait() {
|
||||
while (true) {
|
||||
std::shared_future<EvalResultPtr> future;
|
||||
lock_.lock();
|
||||
if (!waiting_.empty()) {
|
||||
future = std::move(waiting_.front());
|
||||
} else {
|
||||
lock_.unlock();
|
||||
break;
|
||||
}
|
||||
|
||||
waiting_.pop_front();
|
||||
lock_.unlock();
|
||||
future.wait();
|
||||
}
|
||||
if (IS_OUTPUT_ON(DEBUG)) {
|
||||
Todo();
|
||||
}
|
||||
}
|
||||
|
||||
std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
|
||||
std::ostringstream buffer;
|
||||
buffer << "(";
|
||||
for (const auto &item : args_spec_list) {
|
||||
buffer << item->ToString() << " # ";
|
||||
}
|
||||
buffer << " )";
|
||||
return buffer.str();
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,192 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
|
||||
#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include <future>
|
||||
#include <thread>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <list>
|
||||
#include <fstream>
|
||||
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
constexpr size_t kInferTimeout = 60;
|
||||
|
||||
template <typename KeyType, typename ValueType, typename CacheType>
|
||||
class MultiThreadCache {
|
||||
public:
|
||||
using iterator = typename CacheType::iterator;
|
||||
using const_iterator = typename CacheType::const_iterator;
|
||||
|
||||
ValueType get(const KeyType &key) {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
auto it = cache_.find(key);
|
||||
if (it != cache_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void set(const KeyType &key, const ValueType &data) {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
cache_[key] = data;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
size_t size() { return cache_.size(); }
|
||||
|
||||
bool empty() { return size() == 0; }
|
||||
|
||||
std::string dump() {
|
||||
std::ostringstream buf;
|
||||
for (auto &item : cache_) {
|
||||
buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
|
||||
}
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
iterator begin() { return cache_.begin(); }
|
||||
iterator end() { return cache_.end(); }
|
||||
|
||||
const_iterator begin() const { return cache_.cbegin(); }
|
||||
const_iterator end() const { return cache_.cend(); }
|
||||
|
||||
const_iterator cbegin() const { return cache_.cbegin(); }
|
||||
const_iterator cend() const { return cache_.cend(); }
|
||||
|
||||
private:
|
||||
std::mutex lock_;
|
||||
CacheType cache_;
|
||||
};
|
||||
|
||||
class AsyncEvalResult;
|
||||
using AsyncEvalResultPtr = std::shared_ptr<AsyncEvalResult>;
|
||||
|
||||
using EvaluatorCacheMap =
|
||||
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||
using EvalResultCache = MultiThreadCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
|
||||
|
||||
class AsyncEvalResult {
|
||||
public:
|
||||
AsyncEvalResult() = default;
|
||||
~AsyncEvalResult() = default;
|
||||
// wait
|
||||
EvalResultPtr GetResult();
|
||||
// not wait
|
||||
EvalResultPtr TryGetResult(int ms = 0);
|
||||
void JoinResult(const EvalResultPtr &result);
|
||||
std::string ToString();
|
||||
|
||||
private:
|
||||
EvalResultPtr result_{nullptr};
|
||||
std::mutex lock_;
|
||||
std::condition_variable condition_var_;
|
||||
};
|
||||
|
||||
class EvaluatorCacheMgr {
|
||||
public:
|
||||
EvaluatorCacheMgr() = default;
|
||||
~EvaluatorCacheMgr() = default;
|
||||
|
||||
void Clear() { eval_result_cache_.clear(); }
|
||||
EvalResultCache &GetCache() { return eval_result_cache_; }
|
||||
EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
|
||||
void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
|
||||
size_t GetSize() { return eval_result_cache_.size(); }
|
||||
|
||||
private:
|
||||
EvalResultCache eval_result_cache_;
|
||||
};
|
||||
|
||||
// AnalysisCache
|
||||
class AnalysisResultCacheMgr {
|
||||
public:
|
||||
~AnalysisResultCacheMgr() = default;
|
||||
AnalysisResultCacheMgr(const AnalysisResultCacheMgr &) = delete;
|
||||
AnalysisResultCacheMgr &operator=(const AnalysisResultCacheMgr &) = delete;
|
||||
static AnalysisResultCacheMgr &GetInstance();
|
||||
static std::mutex &tiggerToken() { return tiggerToken_; }
|
||||
void Clear();
|
||||
|
||||
using AnalysisConfigAsyncResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, AsyncEvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigAsyncResultCache =
|
||||
MultiThreadCache<AnfNodeConfigPtr, AsyncEvalResultPtr, AnalysisConfigAsyncResultMap>;
|
||||
|
||||
using AnalysisConfigResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigResultCache = MultiThreadCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
|
||||
|
||||
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
|
||||
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
|
||||
|
||||
// Dump all the conf and result
|
||||
void DumpCache(const std::string &filename);
|
||||
// Wait for async Eval(conf) to finish.
|
||||
void Wait();
|
||||
void PushTowait(const std::shared_future<EvalResultPtr> &future0, const std::shared_future<EvalResultPtr> &future1);
|
||||
void PushTodo(const AnfNodeConfigPtr &conf);
|
||||
void Todo();
|
||||
static void UpdateCaller(const std::string &caller);
|
||||
static std::string &GetThreadid();
|
||||
|
||||
void InitSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
EvalResultPtr GetSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
void SetSwitchValue(const AnfNodeConfigPtr &conf, const EvalResultPtr vale);
|
||||
|
||||
private:
|
||||
AnalysisResultCacheMgr() = default;
|
||||
|
||||
static std::mutex tiggerToken_;
|
||||
std::recursive_mutex lock_;
|
||||
std::list<std::shared_future<EvalResultPtr>> waiting_;
|
||||
std::list<AnfNodeConfigPtr> todo_;
|
||||
|
||||
AnalysisConfigResultCache cache_;
|
||||
AnalysisConfigAsyncResultCache switch_cache_;
|
||||
};
|
||||
|
||||
class TiggerToken_scoped_release {
|
||||
public:
|
||||
TiggerToken_scoped_release() { AnalysisResultCacheMgr::tiggerToken().unlock(); }
|
||||
~TiggerToken_scoped_release() { AnalysisResultCacheMgr::tiggerToken().lock(); }
|
||||
};
|
||||
class TiggerToken_scoped_acquire {
|
||||
public:
|
||||
TiggerToken_scoped_acquire() { AnalysisResultCacheMgr::tiggerToken().lock(); }
|
||||
~TiggerToken_scoped_acquire() { AnalysisResultCacheMgr::tiggerToken().unlock(); }
|
||||
};
|
||||
std::string ArgsToString(const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisResultCacheMgr::GetThreadid() + ":"; }
|
||||
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
|
|
@ -25,6 +25,7 @@
|
|||
#include "debug/trace.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "pipeline/jit/static_analysis/stack_frame.h"
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -80,10 +81,9 @@ void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, co
|
|||
// Increase & Check the func graph call depth.
|
||||
engine->IncreaseFunctionCallDepth();
|
||||
engine->IncreaseStackFrameDepth();
|
||||
if (engine->function_call_depth() - engine->stack_frame_depth() >
|
||||
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
|
||||
const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
|
||||
if (engine->function_call_depth() > max_depth) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
|
||||
<< ", (function call depth: " << engine->function_call_depth()
|
||||
<< ", simulate call depth: " << engine->stack_frame_depth()
|
||||
<< "), please call 'context.set_context(max_call_depth=value)' to adjust this value.";
|
||||
|
@ -116,7 +116,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
|
|||
auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context_, parent_context_);
|
||||
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
|
||||
stack_frames.push(current_stack_frame);
|
||||
while (1) {
|
||||
while (true) {
|
||||
current_stack_frame = stack_frames.top();
|
||||
if (current_stack_frame->Done()) {
|
||||
MS_EXCEPTION_IF_NULL(res_base);
|
||||
|
@ -135,8 +135,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
|
|||
// Save func graph eval result for specialize.
|
||||
auto evaluator = current_stack_frame->evaluator();
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
EvaluatorCacheMap &evaluator_cache_map = *evaluator->evaluator_cache_map();
|
||||
evaluator_cache_map[current_stack_frame->args_abs_list()] = eval_result;
|
||||
evaluator->evaluator_cache_mgr()->SetValue(current_stack_frame->args_abs_list(), eval_result);
|
||||
|
||||
// Leave current func graph.
|
||||
LeaveStackFrame(engine, current_stack_frame);
|
||||
|
@ -167,7 +166,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
|
|||
|
||||
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
|
||||
const AnfNodePtr &func_node = fg->get_return();
|
||||
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
|
||||
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
|
||||
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
|
@ -180,20 +179,26 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
|
|||
<< ", node_conf: " << node_conf->ToString();
|
||||
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
res_base = node_eval_result->abstract();
|
||||
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg << "/" << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << res_base->ToString();
|
||||
MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << res_base->ToString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(res_base);
|
||||
return res_base;
|
||||
}
|
||||
|
||||
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
|
||||
auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
|
||||
if (eval_result != nullptr) {
|
||||
MS_LOG(ERROR) << ToString() << ArgsToString(args_abs_list) << " entered again. There is something wrong.";
|
||||
return eval_result;
|
||||
} else {
|
||||
MS_LOG(DEBUG) << ToString() << " entered first.";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
engine->IncreaseFunctionCallDepth();
|
||||
if (engine->function_call_depth() - engine->stack_frame_depth() >
|
||||
MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
|
||||
const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
|
||||
if (engine->function_call_depth() > max_depth) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
|
||||
<< ", (function call depth: " << engine->function_call_depth()
|
||||
<< ", simulate call depth: " << engine->stack_frame_depth()
|
||||
<< "), please call 'context.set_context(max_call_depth=value)' to adjust this value.";
|
||||
|
@ -212,6 +217,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
<< args_abs_list.size() << ".";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(parent_context_);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "@" << fg->ToString() << ArgsToString(args_abs_list) << " { ";
|
||||
if (parent_context_->func_graph() != nullptr) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisResultCacheMgr::GetThreadid() << ":"
|
||||
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
|
||||
<< fg->ToString() << "();";
|
||||
}
|
||||
|
||||
context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list);
|
||||
const auto ¶meters = fg->parameters();
|
||||
for (size_t i = 0; i < nargs; i++) {
|
||||
|
@ -219,6 +231,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
const auto &node = parameters[i];
|
||||
AnfNodeConfigPtr conf = engine->MakeConfig(node, context_);
|
||||
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
|
||||
MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
|
||||
<< ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString()
|
||||
|
@ -238,11 +251,14 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
res_base = std::make_shared<AbstractUndetermined>();
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "} //" << fg->ToString() << " = " << res_base->ToString();
|
||||
engine->DecreaseFunctionCallDepth();
|
||||
MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
|
||||
<< "), leave, function call depth: " << engine->function_call_depth() << " - "
|
||||
<< engine->stack_frame_depth();
|
||||
return std::make_shared<EvalResult>(res_base, nullptr);
|
||||
auto res = std::make_shared<EvalResult>(res_base, nullptr);
|
||||
evaluator_cache_mgr_->SetValue(args_abs_list, res);
|
||||
return res;
|
||||
}
|
||||
|
||||
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
|
||||
|
@ -280,6 +296,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
|
||||
return args_spec_list;
|
||||
}
|
||||
|
||||
if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
|
||||
if (parent_context_) {
|
||||
MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
|
||||
|
@ -307,7 +324,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
return joined_args_spec_list_1;
|
||||
}
|
||||
}
|
||||
if (trace_.size() != 0) {
|
||||
if (!trace_.empty()) {
|
||||
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
|
||||
MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back());
|
||||
// Join the last eval arguments and current arguments to check if there are loop variant.
|
||||
|
@ -336,27 +353,26 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
|
||||
FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
|
||||
auto iter = func_graph_cache_.find(args_spec_list);
|
||||
FuncGraphPtr ret = nullptr;
|
||||
FuncGraphPtr res;
|
||||
if (iter == func_graph_cache_.end()) {
|
||||
auto fg = func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
TraceGuard guard(std::make_shared<TraceEvaluatorGenGraph>(fg->debug_info()));
|
||||
FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list);
|
||||
func_graph_cache_[args_spec_list] = generated_graph;
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
engine->func_graph_manager()->AddFuncGraph(generated_graph);
|
||||
ret = generated_graph;
|
||||
res = generated_graph;
|
||||
} else {
|
||||
ret = iter->second;
|
||||
res = iter->second;
|
||||
}
|
||||
|
||||
// For the top graph, if it is replaced by generated graph, update the top graph to the new one.
|
||||
if (parse::Parser::GetTopFuncGraph() == func_graph()) {
|
||||
if (ret != func_graph()) {
|
||||
parse::Parser::UpdateTopFuncGraph(ret);
|
||||
if (res != func_graph()) {
|
||||
parse::Parser::UpdateTopFuncGraph(res);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
return res;
|
||||
}
|
||||
|
||||
FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
|
||||
|
@ -366,7 +382,7 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
|
|||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(meta_func_graph_);
|
||||
FuncGraphPtr generated_func_graph = nullptr;
|
||||
FuncGraphPtr generated_func_graph;
|
||||
if (this->bound_node() != nullptr) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceGenMetaFuncGraph>(bound_node()->debug_info()));
|
||||
generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list);
|
||||
|
@ -381,8 +397,18 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
|
|||
return cloned_func_graph;
|
||||
}
|
||||
|
||||
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) {
|
||||
const std::string &evaluator_name = ToString();
|
||||
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
const string evaluator_name = ToString();
|
||||
std::unique_lock<std::recursive_timed_mutex> eval_lock(eval_loc_, std::try_to_lock);
|
||||
if (!eval_lock.owns_lock()) {
|
||||
auto py_tstate = PyEval_SaveThread();
|
||||
eval_lock.try_lock_for(std::chrono::seconds(kInferTimeout));
|
||||
PyEval_RestoreThread(py_tstate);
|
||||
if (!eval_lock.owns_lock()) {
|
||||
MS_LOG(EXCEPTION) << "It is timeout to run " << ToString();
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
|
@ -394,30 +420,30 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
|
|||
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
|
||||
trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf);
|
||||
MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
|
||||
auto iter = evaluator_cache_map_->find(args_spec_list);
|
||||
if (iter == evaluator_cache_map_->end()) {
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
|
||||
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
|
||||
if (eval_result == nullptr) {
|
||||
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
|
||||
EvalResultPtr ret = Eval(engine, args_spec_list);
|
||||
if (ret->abstract() == nullptr) {
|
||||
EvalResultPtr res = Eval(engine, args_spec_list);
|
||||
if (res->abstract() == nullptr) {
|
||||
EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
||||
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
|
||||
}
|
||||
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << ".";
|
||||
(*evaluator_cache_map_)[args_spec_list] = ret;
|
||||
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << res->abstract()->ToString() << ".";
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, res);
|
||||
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
||||
return ret;
|
||||
return res;
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
MS_EXCEPTION_IF_NULL(iter->second->abstract());
|
||||
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << ".";
|
||||
MS_EXCEPTION_IF_NULL(eval_result);
|
||||
MS_EXCEPTION_IF_NULL(eval_result->abstract());
|
||||
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << eval_result->abstract()->ToString() << ".";
|
||||
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
||||
return iter->second;
|
||||
return eval_result;
|
||||
}
|
||||
}
|
||||
|
||||
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr) {
|
||||
const AnfNodeConfigPtr &) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
auto is_py_eval = (identifier_ == "PythonPrimEvaluator");
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
|
@ -432,58 +458,57 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
}
|
||||
return abstract;
|
||||
});
|
||||
EvalResultPtr ret = EvalPrim(engine, args_spec_list);
|
||||
return ret;
|
||||
return EvalPrim(engine, args_spec_list);
|
||||
}
|
||||
|
||||
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr out_conf) {
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
return conf->ObtainEvalResult()->abstract();
|
||||
});
|
||||
if (args_conf_list.size() == 0) {
|
||||
if (args_conf_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Size should greater than 0";
|
||||
}
|
||||
EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
|
||||
EvalResultPtr res = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
|
||||
// No need to cache.
|
||||
return ret;
|
||||
return res;
|
||||
}
|
||||
|
||||
EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
|
||||
EvalResultPtr ret = EvalPrim(args_conf_list);
|
||||
return ret;
|
||||
EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &) {
|
||||
return EvalPrim(args_conf_list);
|
||||
}
|
||||
|
||||
EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr out_conf) {
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
return conf->ObtainEvalResult()->abstract();
|
||||
});
|
||||
EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
|
||||
EvalResultPtr res = sub_evaluator_->Run(engine, args_conf_list, out_conf);
|
||||
// Don't lookup from cache, as different out_conf with same node but different context
|
||||
// may add different entry to anfnode_config_map_, like getattr primitive.
|
||||
(*evaluator_cache_map_)[args_spec_list] = ret;
|
||||
return ret;
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, res);
|
||||
return res;
|
||||
}
|
||||
|
||||
EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr out_conf) {
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
return conf->ObtainEvalResult()->abstract();
|
||||
});
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
|
||||
auto iter = evaluator_cache_map_->find(args_spec_list);
|
||||
if (iter != evaluator_cache_map_->end()) {
|
||||
return iter->second;
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
|
||||
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
|
||||
if (eval_result != nullptr) {
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
ConfigPtrList partial_args_conf_list;
|
||||
|
@ -493,23 +518,22 @@ EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtr
|
|||
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
|
||||
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
||||
EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
|
||||
|
||||
(*evaluator_cache_map_)[args_spec_list] = ret;
|
||||
return ret;
|
||||
EvalResultPtr res = evaluator_->Run(engine, partial_args_conf_list, out_conf);
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, res);
|
||||
return res;
|
||||
}
|
||||
|
||||
EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
|
||||
EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
return conf->ObtainEvalResult()->abstract();
|
||||
});
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
|
||||
auto iter = evaluator_cache_map_->find(args_spec_list);
|
||||
if (iter != evaluator_cache_map_->end()) {
|
||||
return iter->second;
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
|
||||
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
|
||||
if (eval_result != nullptr) {
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
// Call the original evaluator, get the result: y = f(x)
|
||||
|
@ -536,9 +560,9 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
|
|||
// J(f)(J(x)) return a tuple (y, bprop_f)
|
||||
AbstractBasePtrList jargs = {result->abstract(), bprop};
|
||||
AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
|
||||
auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
|
||||
(*evaluator_cache_map_)[args_spec_list] = infer_reuslt;
|
||||
return infer_reuslt;
|
||||
auto res = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, res);
|
||||
return res;
|
||||
}
|
||||
|
||||
EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
|
||||
|
@ -553,5 +577,16 @@ EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrLis
|
|||
}
|
||||
return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
auto result = this->Run(engine, args_conf_list, out_conf);
|
||||
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
pybind11::gil_scoped_release release;
|
||||
AnalysisResultCacheMgr::GetInstance().Wait();
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,23 +26,22 @@
|
|||
#include <stack>
|
||||
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
using EvaluatorCacheMap =
|
||||
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||
using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>;
|
||||
|
||||
using EvaluatorCacheMgrPtr = std::shared_ptr<EvaluatorCacheMgr>;
|
||||
using EvaluatorAttrMap =
|
||||
std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||
using EvaluatorAttrMapPtr = std::shared_ptr<EvaluatorAttrMap>;
|
||||
using EvaluatorAttrCache = MultiThreadCache<AbstractBasePtrList, AttrValueMapPtr, EvaluatorAttrMap>;
|
||||
using EvaluatorAttrCachePtr = std::shared_ptr<EvaluatorAttrCache>;
|
||||
|
||||
class Evaluator : public Base {
|
||||
public:
|
||||
explicit Evaluator(const std::string &id)
|
||||
: evaluator_cache_map_(std::make_shared<EvaluatorCacheMap>()),
|
||||
attr_cache_(std::make_shared<EvaluatorAttrMap>()),
|
||||
: evaluator_cache_mgr_(std::make_shared<EvaluatorCacheMgr>()),
|
||||
attr_cache_(std::make_shared<EvaluatorAttrCache>()),
|
||||
identifier_(id) {}
|
||||
~Evaluator() override = default;
|
||||
MS_DECLARE_PARENT(Evaluator, Base);
|
||||
|
@ -50,9 +49,12 @@ class Evaluator : public Base {
|
|||
// difference between Run() and Eval():
|
||||
// Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
|
||||
// Run() will modify cache_ member, so it cannot marked as const;
|
||||
virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
|
||||
virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf);
|
||||
|
||||
virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
|
||||
virtual EvalResultPtr SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf);
|
||||
|
||||
virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
|
||||
|
||||
|
@ -87,14 +89,16 @@ class Evaluator : public Base {
|
|||
|
||||
virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
|
||||
|
||||
EvaluatorCacheMapPtr &evaluator_cache_map() { return evaluator_cache_map_; }
|
||||
EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; }
|
||||
EvaluatorCacheMgrPtr evaluator_cache_mgr() const { return evaluator_cache_mgr_; }
|
||||
EvaluatorAttrCachePtr attr_cache() const { return attr_cache_; }
|
||||
|
||||
EvaluatorCacheMgrPtr evaluator_cache_mgr_;
|
||||
EvaluatorAttrCachePtr attr_cache_;
|
||||
|
||||
EvaluatorCacheMapPtr evaluator_cache_map_;
|
||||
EvaluatorAttrMapPtr attr_cache_;
|
||||
std::string identifier_;
|
||||
|
||||
AnfNodeWeakPtr bound_node_;
|
||||
std::recursive_timed_mutex eval_loc_;
|
||||
};
|
||||
|
||||
class PrimEvaluator : public Evaluator {
|
||||
|
@ -112,7 +116,8 @@ class TrivialPrimEvaluator : public PrimEvaluator {
|
|||
explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
||||
~TrivialPrimEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) final;
|
||||
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
|
||||
};
|
||||
|
||||
|
@ -121,7 +126,8 @@ class TransitionPrimEvaluator : public PrimEvaluator {
|
|||
explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
||||
~TransitionPrimEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) final;
|
||||
// Parameter in_conf0 : the first element in args_conf_list;
|
||||
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
|
||||
|
@ -132,7 +138,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
|
|||
explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
||||
~SymbolicPrimEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) final;
|
||||
virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
|
||||
};
|
||||
|
||||
|
@ -173,7 +180,8 @@ class TrackedEvaluator : public Evaluator {
|
|||
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
||||
}
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }
|
||||
|
||||
private:
|
||||
|
@ -211,9 +219,9 @@ class BaseFuncGraphEvaluator : public Evaluator {
|
|||
AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
|
||||
// Add functions for stack frame routine.
|
||||
AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
|
||||
void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
||||
const StackFramePtr &new_stack_frame);
|
||||
void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame);
|
||||
static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
|
||||
const StackFramePtr &new_stack_frame);
|
||||
static void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame);
|
||||
|
||||
AnalysisContextPtr context_;
|
||||
};
|
||||
|
@ -287,7 +295,8 @@ class PartialAppEvaluator : public Evaluator {
|
|||
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
||||
}
|
||||
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
||||
|
||||
private:
|
||||
|
@ -333,7 +342,8 @@ class JEvaluator : public Evaluator {
|
|||
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
||||
}
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
||||
|
||||
private:
|
||||
|
|
|
@ -51,7 +51,7 @@ std::unordered_set<std::string> prims_to_skip_undetermined_infer{
|
|||
"MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
|
||||
|
||||
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr out_conf) {
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
|
||||
|
@ -136,7 +136,7 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
|
|||
}
|
||||
|
||||
EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr out_conf) {
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
}
|
||||
|
@ -238,7 +238,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
|
|||
}
|
||||
|
||||
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr out_conf) {
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||
|
@ -603,7 +603,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
return ret_abstract;
|
||||
}
|
||||
}
|
||||
|
||||
if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
|
||||
return EvalPyCheckPrim(engine, args);
|
||||
}
|
||||
|
@ -640,10 +639,12 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
|
|||
}
|
||||
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
|
||||
|
||||
const auto &iter = evaluator_cache_map_->find(args);
|
||||
if (iter != evaluator_cache_map_->end()) {
|
||||
return iter->second;
|
||||
const auto eval_result = evaluator_cache_mgr_->GetValue(args);
|
||||
if (eval_result != nullptr) {
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
auto py_args = PreparePyInputs(prim_py_, args);
|
||||
prim_py_->BeginRecordAddAttr();
|
||||
py::dict output = prim_py_->RunInfer(py_args);
|
||||
|
@ -653,7 +654,7 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
|
|||
auto res_spec = PyInferRes2Abstract(prim_py_, output);
|
||||
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
|
||||
auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
|
||||
(*evaluator_cache_map_)[args] = infer_result;
|
||||
evaluator_cache_mgr_->SetValue(args, infer_result);
|
||||
return infer_result;
|
||||
}
|
||||
|
||||
|
@ -1016,6 +1017,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
return nullptr;
|
||||
}
|
||||
AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
|
||||
|
||||
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
|
||||
if (ref_abs == nullptr) {
|
||||
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
|
||||
|
@ -1082,7 +1084,7 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
|
|||
}
|
||||
// don't lookup from cache, as different out_conf with same node but different context
|
||||
// may add different entry to anfnode_config_map, like getattr primitive;
|
||||
(*evaluator_cache_map_)[args_spec_list] = ret;
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, ret);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
@ -1169,7 +1171,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
|
||||
AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
|
||||
auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
(*evaluator_cache_map_)[args_spec_list] = infer_result;
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
|
||||
return infer_result;
|
||||
}
|
||||
|
||||
|
@ -1197,7 +1199,7 @@ class PartialEvaluator : public Evaluator {
|
|||
PartialEvaluator() : Evaluator("PartialEvaluator") {}
|
||||
~PartialEvaluator() override = default;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
AnfNodeConfigPtr out_conf = nullptr) override {
|
||||
const AnfNodeConfigPtr &out_conf = nullptr) override {
|
||||
if (args_conf_list.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "Args size should be greater than 0";
|
||||
}
|
||||
|
@ -1214,7 +1216,7 @@ class PartialEvaluator : public Evaluator {
|
|||
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
|
||||
<< " as func is: " << arg0_value->ToString();
|
||||
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
(*evaluator_cache_map_)[args_spec_list] = eval_result;
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
|
||||
return eval_result;
|
||||
}
|
||||
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
|
||||
|
@ -1248,7 +1250,7 @@ class PartialEvaluator : public Evaluator {
|
|||
|
||||
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
|
||||
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
(*evaluator_cache_map_)[args_spec_list] = eval_result;
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
|
|
|
@ -71,7 +71,7 @@ class DoSignatureEvaluator : public Evaluator {
|
|||
explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
|
||||
~DoSignatureEvaluator() override = default;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
|
||||
AnfNodeConfigPtr out_config = nullptr) override;
|
||||
const AnfNodeConfigPtr &out_config = nullptr) override;
|
||||
|
||||
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
||||
|
@ -86,7 +86,7 @@ class UnpackGraphEvaluator : public Evaluator {
|
|||
explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
|
||||
~UnpackGraphEvaluator() override = default;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
|
||||
AnfNodeConfigPtr out_config = nullptr) override;
|
||||
const AnfNodeConfigPtr &out_config = nullptr) override;
|
||||
|
||||
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
||||
|
@ -102,7 +102,7 @@ class MixedPrecisionCastEvaluator : public Evaluator {
|
|||
: Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {}
|
||||
~MixedPrecisionCastEvaluator() override = default;
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
|
||||
AnfNodeConfigPtr out_config = nullptr) override;
|
||||
const AnfNodeConfigPtr &out_config = nullptr) override;
|
||||
|
||||
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
||||
|
|
|
@ -494,11 +494,11 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
|
|||
return wrapped_node;
|
||||
}
|
||||
|
||||
const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
|
||||
const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
|
||||
auto cache_iter = evalcaches_.find(eval);
|
||||
if (cache_iter == evalcaches_.end()) {
|
||||
evalcaches_[eval] = eval->evaluator_cache_map();
|
||||
return eval->evaluator_cache_map();
|
||||
evalcaches_[eval] = eval->evaluator_cache_mgr();
|
||||
return eval->evaluator_cache_mgr();
|
||||
}
|
||||
return cache_iter->second;
|
||||
}
|
||||
|
@ -509,7 +509,8 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
|
|||
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
|
||||
EvalResultPtr ret = nullptr;
|
||||
AbstractBasePtrList broaded_argvals;
|
||||
for (auto &argvals_map : *evalcaches_[eval]) {
|
||||
EvalResultCache &cache = evalcaches_[eval]->GetCache();
|
||||
for (auto &argvals_map : cache) {
|
||||
auto argvals = argvals_map.first;
|
||||
broaded_argvals.clear();
|
||||
|
||||
|
@ -524,11 +525,10 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
|
|||
(void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
|
||||
[](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
|
||||
|
||||
// if broaden return null
|
||||
ret = eval->Run(engine_, args_conf_list, nullptr);
|
||||
EvaluatorCacheMapPtr real = std::make_shared<EvaluatorCacheMap>();
|
||||
|
||||
(*real)[broaded_argvals] = ret;
|
||||
// If broaden return null
|
||||
ret = eval->SingleRun(engine_, args_conf_list, nullptr);
|
||||
EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
|
||||
real->SetValue(broaded_argvals, ret);
|
||||
evalcaches_[eval] = real;
|
||||
return std::make_pair(broaded_argvals, ret->abstract());
|
||||
} else {
|
||||
|
@ -615,11 +615,12 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) {
|
||||
void DumpEvaluatorCache(const EvaluatorCacheMgrPtr &evaluator_cache_mgr, const AbstractBasePtrList &argvals) {
|
||||
MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
|
||||
int64_t i = 0;
|
||||
for (const auto &item : evaluator_cache_map) {
|
||||
MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
|
||||
const EvalResultCache &map = evaluator_cache_mgr->GetCache();
|
||||
for (const auto &item : map) {
|
||||
MS_LOG(DEBUG) << "evaluator_cache[" << i++ << "]: " << item.first;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -650,24 +651,24 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
|||
MS_EXCEPTION_IF_NULL(eval);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
||||
EvaluatorCacheMap &evaluator_cache_map = *eval->evaluator_cache_map();
|
||||
if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
|
||||
*result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract());
|
||||
EvaluatorCacheMgrPtr evaluator_cache_mgr = eval->evaluator_cache_mgr();
|
||||
auto data = evaluator_cache_mgr->GetValue(argvals);
|
||||
if (data != nullptr) {
|
||||
*result = std::make_pair(argvals, data->abstract());
|
||||
return kSpecializeSuccess;
|
||||
}
|
||||
DumpEvaluatorCache(evaluator_cache_map, argvals);
|
||||
DumpEvaluatorCache(evaluator_cache_mgr, argvals);
|
||||
|
||||
const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
|
||||
MS_EXCEPTION_IF_NULL(choices);
|
||||
|
||||
if (choices->count(argvals)) {
|
||||
*result = std::make_pair(argvals, (*choices)[argvals]->abstract());
|
||||
MS_EXCEPTION_IF_NULL(GetEvalCache(eval));
|
||||
EvalResultCache &choices = GetEvalCache(eval)->GetCache();
|
||||
if (choices.get(argvals) != nullptr) {
|
||||
*result = std::make_pair(argvals, GetEvalCache(eval)->GetValue(argvals)->abstract());
|
||||
return kSpecializeSuccess;
|
||||
} else if (choices->size() == 1) {
|
||||
} else if (choices.size() == 1) {
|
||||
MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
|
||||
*result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
|
||||
*result = std::make_pair(choices.begin()->first, choices.begin()->second->abstract());
|
||||
return kSpecializeSuccess;
|
||||
} else if (choices->empty()) {
|
||||
} else if (choices.empty()) {
|
||||
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
|
||||
<< func->type_name();
|
||||
return kSpecializeFindUniqueArgvalDead;
|
||||
|
|
|
@ -91,7 +91,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_node_;
|
||||
std::vector<AnfNodePtr> todo_;
|
||||
std::unordered_set<AnfNodePtr> marked_;
|
||||
std::unordered_map<EvaluatorPtr, EvaluatorCacheMapPtr> evalcaches_;
|
||||
std::unordered_map<EvaluatorPtr, EvaluatorCacheMgrPtr> evalcaches_;
|
||||
|
||||
void FirstPass();
|
||||
void SecondPass();
|
||||
|
@ -127,7 +127,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
const AbstractBasePtrList &argvals,
|
||||
std::pair<AbstractBasePtrList, AbstractBasePtr> *result);
|
||||
// Get cache, it may be eval's cache or cache built from broaded argument values.
|
||||
const EvaluatorCacheMapPtr &GetEvalCache(const EvaluatorPtr &eval);
|
||||
const EvaluatorCacheMgrPtr GetEvalCache(const EvaluatorPtr &eval);
|
||||
// Try to build unique argvals from the broaded arg vals if it is unique.
|
||||
std::pair<AbstractBasePtrList, AbstractBasePtr> BuildFromBroadedArgsVal(const EvaluatorPtr &eval);
|
||||
void UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "pipeline/jit/static_analysis/stack_frame.h"
|
||||
#include "debug/trace.h"
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -66,8 +67,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
|
|||
AbstractBasePtrList args_abs_list = GenerateArgsAbsList(engine, evaluator, current_cnode);
|
||||
|
||||
// Check if already evaluated before.
|
||||
EvaluatorCacheMap &evaluator_cache_map = *evaluator->evaluator_cache_map();
|
||||
if (evaluator_cache_map.find(args_abs_list) != evaluator_cache_map.end()) {
|
||||
if (evaluator->evaluator_cache_mgr()->GetValue(args_abs_list) != nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -128,6 +128,8 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
|
|||
<< ", current_context_: " << current_context_->ToString();
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(current_node, current_context_);
|
||||
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString()
|
||||
<< ") = " << node_eval_result->abstract()->ToString();
|
||||
return node_eval_result;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,20 +17,21 @@
|
|||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "utils/ms_exception.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/static_analysis/evaluator.h"
|
||||
#include "debug/trace.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -39,12 +40,9 @@ bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) {
|
|||
auto v = arg_spec->GetValueTrack();
|
||||
if (v->isa<SymbolicKeyInstance>()) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) {
|
||||
|
@ -54,36 +52,6 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
|
||||
MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
|
||||
<< ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString()
|
||||
<< ", Pointer: " << result->abstract().get();
|
||||
analysis_cache_map_[conf] = result;
|
||||
|
||||
// Set intermediate abstract value.
|
||||
if (IsIntermediateAbstract(result->abstract())) {
|
||||
if (conf->node()->intermediate_abstract() == nullptr) {
|
||||
conf->node()->set_intermediate_abstract(result->abstract());
|
||||
MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
|
||||
} else {
|
||||
auto old_spec = conf->node()->intermediate_abstract();
|
||||
auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
|
||||
conf->node()->set_intermediate_abstract(joined_spec);
|
||||
MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
|
||||
<< result->abstract()->ToString() << "\njoined_spec:\t"
|
||||
<< (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
|
||||
auto value = analysis_cache_map_.find(conf);
|
||||
if (value == analysis_cache_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return value->second;
|
||||
}
|
||||
|
||||
std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(conf->node());
|
||||
|
@ -105,6 +73,7 @@ bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeCon
|
|||
}
|
||||
|
||||
AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) {
|
||||
StaticAnalysisException::Instance().ClearException();
|
||||
ConfigPtrList args_conf_list;
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
|
||||
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
||||
|
@ -127,6 +96,11 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
|
|||
MS_EXCEPTION_IF_NULL(output_conf);
|
||||
result.inferred = output_conf->ObtainEvalResult();
|
||||
result.context = root_context;
|
||||
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
pybind11::gil_scoped_release release;
|
||||
AnalysisResultCacheMgr::GetInstance().Wait();
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -137,15 +111,35 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
|
|||
return eval->context();
|
||||
}
|
||||
|
||||
void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
|
||||
cache_mgr.SetValue(conf, result);
|
||||
|
||||
// Set intermediate abstract value.
|
||||
if (IsIntermediateAbstract(result->abstract())) {
|
||||
if (conf->node()->intermediate_abstract() == nullptr) {
|
||||
conf->node()->set_intermediate_abstract(result->abstract());
|
||||
MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
|
||||
} else {
|
||||
auto old_spec = conf->node()->intermediate_abstract();
|
||||
auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
|
||||
conf->node()->set_intermediate_abstract(joined_spec);
|
||||
MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
|
||||
<< result->abstract()->ToString() << "\njoined_spec:\t"
|
||||
<< (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
EvalResultPtr result = analysis_cache_.GetValue(conf);
|
||||
static AnalysisResultCacheMgr &cache_mgr = AnalysisResultCacheMgr::GetInstance();
|
||||
auto result = cache_mgr.GetValue(conf);
|
||||
if (result != nullptr) {
|
||||
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString()
|
||||
<< ", Value: " << result->abstract().get() << ", " << result->abstract()->ToString();
|
||||
return result;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
|
||||
result = Eval(conf);
|
||||
if (result == nullptr) {
|
||||
|
@ -187,8 +181,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
|||
trace::TraceEvalCNodeLeave();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString() << "(" << node->type_name()
|
||||
<< "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
|
||||
<< ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
<< "), fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph");
|
||||
}
|
||||
|
||||
#ifdef DEBUG
|
||||
|
@ -280,6 +273,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
MS_LOG(DEBUG) << "EvalCNode eval Undetermined";
|
||||
return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
|
||||
if (func == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << maybe_func->ToString() << ".";
|
||||
|
@ -321,28 +315,28 @@ EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const Abs
|
|||
}
|
||||
|
||||
void AnalysisEngine::ClearEvaluatorCache() {
|
||||
for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : evaluators_) {
|
||||
for (auto &element : evaluators_) {
|
||||
EvaluatorPtr evaluator = element.second;
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
|
||||
evaluator->evaluator_cache_map()->clear();
|
||||
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
|
||||
evaluator->evaluator_cache_mgr()->Clear();
|
||||
}
|
||||
for (auto &element : prim_constructors_) {
|
||||
EvaluatorPtr evaluator = element.second;
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
|
||||
evaluator->evaluator_cache_map()->clear();
|
||||
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
|
||||
evaluator->evaluator_cache_mgr()->Clear();
|
||||
}
|
||||
for (auto &element : prim_py_evaluators_) {
|
||||
EvaluatorPtr evaluator = element.second;
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
|
||||
evaluator->evaluator_cache_map()->clear();
|
||||
MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_mgr());
|
||||
evaluator->evaluator_cache_mgr()->Clear();
|
||||
}
|
||||
}
|
||||
|
||||
void AnalysisEngine::Clear() {
|
||||
analysis_cache_.Clear();
|
||||
AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
anfnode_config_map_.clear();
|
||||
eval_trace_.clear();
|
||||
evaluators_.clear();
|
||||
|
@ -517,6 +511,11 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
|||
if (func->tracking_id() != nullptr) {
|
||||
MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
|
||||
// protect the constructors
|
||||
static std::recursive_mutex constructors_mutex;
|
||||
// std::lock_guard<std::recursive_mutex> lock(constructors_mutex);
|
||||
if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
|
||||
func->isa<abstract::FuncGraphAbstractClosure>()) {
|
||||
EvaluatorPtr evaluator = _GetEvaluatorFor(func);
|
||||
|
@ -570,7 +569,16 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
|
|||
MS_EXCEPTION_IF_NULL(eval);
|
||||
return eval->Run(shared_from_this(), args_conf_list, out_conf);
|
||||
}
|
||||
#if !(defined _WIN32 || defined _WIN64)
|
||||
static bool enable_singleThread = (common::GetEnv("ENV_SINGLE_EVAL") == "1");
|
||||
if (enable_singleThread) {
|
||||
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
|
||||
} else {
|
||||
return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list);
|
||||
}
|
||||
#else
|
||||
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
|
||||
#endif
|
||||
}
|
||||
|
||||
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
||||
|
@ -623,17 +631,17 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt
|
|||
for (auto u_eval : undetermined_evals) {
|
||||
MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "check undetermined.";
|
||||
auto &alternate_evaluator = multi_poss_[u_eval.evaluator_];
|
||||
auto &eval_cache = alternate_evaluator->evaluator_cache_map();
|
||||
auto eval_cache = alternate_evaluator->evaluator_cache_mgr();
|
||||
const auto &alt_eval_args = EvaluatorArgs(alternate_evaluator, args_spec_list);
|
||||
if ((!undetermined_evals.count(alt_eval_args)) &&
|
||||
(((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) ||
|
||||
(eval_cache->find(args_spec_list) == eval_cache->end()))) {
|
||||
(((!continued_evals_.count(u_eval)) && (eval_cache->GetValue(args_spec_list) != nullptr)) ||
|
||||
(eval_cache->GetValue(args_spec_list) == nullptr))) {
|
||||
MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "has undetermined.";
|
||||
has_undetermined = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (has_undetermined == false) {
|
||||
if (!has_undetermined) {
|
||||
MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
|
||||
*continue_flag = true;
|
||||
return latest_entry;
|
||||
|
@ -675,7 +683,7 @@ std::string JoinBranchesFailedInfo(const AbstractBasePtr &spec, const AbstractBa
|
|||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node) {
|
||||
if (out_specs.size() == 0) {
|
||||
if (out_specs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
|
||||
}
|
||||
|
||||
|
@ -710,6 +718,135 @@ EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_
|
|||
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) {
|
||||
if (abstract->isa<AbstractFunction>()) {
|
||||
return true;
|
||||
}
|
||||
if (abstract->isa<AbstractSequeue>()) {
|
||||
auto elements = abstract->cast<AbstractSequeuePtr>()->elements();
|
||||
if (std::any_of(elements.begin(), elements.end(),
|
||||
[](const AbstractBasePtr &item) { return item->isa<AbstractFunction>(); })) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
EvalResultPtr ExecEvaluetor(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list,
|
||||
AnfNodeConfigPtr out_conf, std::string caller, AsyncEvalResultPtr async_result_branch,
|
||||
AsyncEvalResultPtr async_result_main) {
|
||||
// TiggerToken_scoped_acquire tigger_token_acquire;
|
||||
py::gil_scoped_acquire pyGuard;
|
||||
|
||||
EvalResultPtr result = nullptr;
|
||||
try {
|
||||
AnalysisResultCacheMgr::UpdateCaller(caller);
|
||||
result = eval->Run(engine, args_conf_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
async_result_branch->JoinResult(result);
|
||||
async_result_main->JoinResult(result);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
|
||||
<< " asyncResult address = " << async_result_branch.get()
|
||||
<< " value = " << async_result_branch->TryGetResult()->abstract()->ToString();
|
||||
auto broadAbstract = result->abstract()->Broaden();
|
||||
auto broadEvalResult = std::make_shared<EvalResult>(broadAbstract, nullptr);
|
||||
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadEvalResult);
|
||||
} catch (const std::exception &e) {
|
||||
std::ostringstream oss;
|
||||
trace::GetEvalStackInfo(oss);
|
||||
if (!oss.str().empty()) {
|
||||
MS_LOG(ERROR) << oss.str();
|
||||
}
|
||||
auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>(oss.str()), out_conf->node());
|
||||
async_result_main->JoinResult(std::make_shared<EvalResult>(abstractErrPtr, nullptr));
|
||||
StaticAnalysisException::Instance().SetException();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
|
||||
const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list) {
|
||||
// TiggerToken_scoped_release tigger_token_release;
|
||||
pybind11::gil_scoped_release release;
|
||||
// Wait for the switch node to finish.
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString();
|
||||
auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
|
||||
if (eval_result == nullptr) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : Init switch " << out_conf->ToString();
|
||||
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
|
||||
} else {
|
||||
if (eval_result->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << out_conf->node()->ToString() << " time out."
|
||||
<< "please check the code if there are recursive functions.";
|
||||
}
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
// Eval result of the branches and main.
|
||||
AsyncEvalResultPtr asyncResult0 = std::make_shared<AsyncEvalResult>();
|
||||
AsyncEvalResultPtr asyncResult1 = std::make_shared<AsyncEvalResult>();
|
||||
AsyncEvalResultPtr asyncResult_main = std::make_shared<AsyncEvalResult>();
|
||||
SetUndeterminedFlag(evaluators[0]);
|
||||
SetUndeterminedFlag(evaluators[1]);
|
||||
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
|
||||
auto future0 = std::async(std::launch::async, ExecEvaluetor, evaluators[0], shared_from_this(), args_conf_list,
|
||||
out_conf, threadId, asyncResult0, asyncResult_main);
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
|
||||
auto future1 = std::async(std::launch::async, ExecEvaluetor, evaluators[1], shared_from_this(), args_conf_list,
|
||||
out_conf, threadId, asyncResult1, asyncResult_main);
|
||||
|
||||
// Wait for async threads to finish.
|
||||
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1));
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
|
||||
<< " or " << evaluators[1]->ToString();
|
||||
auto branchResult = asyncResult_main->GetResult();
|
||||
if (branchResult == nullptr || branchResult->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString()
|
||||
<< "please check the code if there are recursive functions.";
|
||||
}
|
||||
if (branchResult->isa<AbstractError>()) {
|
||||
MS_LOG(EXCEPTION) << "async " << out_conf->node()->ToString() << " threw exception.";
|
||||
}
|
||||
|
||||
AbstractBasePtrList out_specs;
|
||||
if (NeedWaitForTwoBranches(branchResult->abstract())) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[0]->ToString();
|
||||
auto result0 = asyncResult0->GetResult();
|
||||
if (result0->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << "is time out."
|
||||
<< " Please check the code if there is recursive function.";
|
||||
}
|
||||
out_specs.push_back(result0->abstract());
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[1]->ToString();
|
||||
auto result1 = asyncResult1->GetResult();
|
||||
if (result1->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << evaluators[1]->ToString() << "is time out."
|
||||
<< " Please check the code if there is recursive function.";
|
||||
}
|
||||
out_specs.push_back(result1->abstract());
|
||||
} else {
|
||||
if (asyncResult0->TryGetResult()) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[0]->ToString()
|
||||
<< " value0=" << asyncResult0->GetResult()->abstract()->ToString();
|
||||
out_specs.push_back(asyncResult0->GetResult()->abstract());
|
||||
}
|
||||
|
||||
if (asyncResult1->TryGetResult()) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async . waiting for " << evaluators[1]->ToString()
|
||||
<< " value1=" << asyncResult1->GetResult()->abstract()->ToString();
|
||||
out_specs.push_back(asyncResult1->GetResult()->abstract());
|
||||
}
|
||||
}
|
||||
|
||||
return ProcessEvalResults(out_specs, out_conf->node());
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
||||
const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list) {
|
||||
|
@ -726,9 +863,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
MS_EXCEPTION_IF_NULL(conf);
|
||||
return conf->ObtainEvalResult()->abstract();
|
||||
});
|
||||
for (auto eval : evaluators) {
|
||||
for (const auto &eval : evaluators) {
|
||||
SetUndeterminedFlag(eval);
|
||||
|
||||
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
|
||||
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
|
||||
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
|
||||
|
@ -757,11 +893,11 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
// Try to travel the latest undetermined.
|
||||
if (latest_entry != eval_trace_.rbegin()->evaluator_) {
|
||||
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString();
|
||||
auto latest_entry_eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(latest_entry_eval_result->abstract());
|
||||
auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(eval_result->abstract());
|
||||
MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString()
|
||||
<< " return out_spec: " << latest_entry_eval_result->abstract()->ToString();
|
||||
return latest_entry_eval_result;
|
||||
<< " return out_spec: " << eval_result->abstract()->ToString();
|
||||
return eval_result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -815,8 +951,9 @@ AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &cont
|
|||
if (value->isa<Primitive>()) {
|
||||
auto prim = value->cast<PrimitivePtr>();
|
||||
return MakeAbstractClosure(prim, anf_node);
|
||||
} else {
|
||||
return value->ToAbstract();
|
||||
}
|
||||
return value->ToAbstract();
|
||||
}
|
||||
|
||||
AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <mutex>
|
||||
|
||||
#ifdef DEBUG
|
||||
#include <stack>
|
||||
|
@ -154,19 +155,6 @@ class VirtualConfig : public Config {
|
|||
AbstractBasePtr abstract_;
|
||||
};
|
||||
|
||||
// AnalysisCache
|
||||
class AnalysisCache {
|
||||
public:
|
||||
AnalysisCache() = default;
|
||||
~AnalysisCache() = default;
|
||||
void Clear() { analysis_cache_map_.clear(); }
|
||||
void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg);
|
||||
EvalResultPtr GetValue(const AnfNodeConfigPtr &conf);
|
||||
|
||||
private:
|
||||
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> analysis_cache_map_;
|
||||
};
|
||||
|
||||
using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
|
||||
using AnfNodeConfigMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
|
@ -183,12 +171,38 @@ struct PartialAppHasher {
|
|||
return h1 ^ h2;
|
||||
}
|
||||
};
|
||||
|
||||
// Should compare Args based on value other than pointer;
|
||||
struct EvaluatorArgs {
|
||||
EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {}
|
||||
bool operator==(const EvaluatorArgs &other) const {
|
||||
if (evaluator_ != other.evaluator_) {
|
||||
return false;
|
||||
}
|
||||
if (AbstractBasePtrListDeepEqual(args_, other.args_)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool operator!=(const EvaluatorArgs &other) { return !(*this == other); }
|
||||
|
||||
EvaluatorPtr evaluator_;
|
||||
AbstractBasePtrList args_;
|
||||
};
|
||||
using EvalTraceRevIter = std::list<EvaluatorArgs>::reverse_iterator;
|
||||
struct EvaluatorArgsHasher {
|
||||
std::size_t operator()(const EvaluatorArgs &eval_args) const {
|
||||
return hash_combine(std::hash<EvaluatorPtr>{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_));
|
||||
}
|
||||
};
|
||||
struct EvaluatorArgsEqual {
|
||||
bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; }
|
||||
};
|
||||
|
||||
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
||||
public:
|
||||
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
|
||||
: analysis_cache_(AnalysisCache()),
|
||||
prim_constructors_(prim_evaluator_map),
|
||||
func_graph_manager_(func_graph_manager) {
|
||||
: prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {
|
||||
function_call_depth_ = 0;
|
||||
function_call_max_depth_ = 0;
|
||||
stack_frame_depth_ = 0;
|
||||
|
@ -202,11 +216,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
// func_graph: The func_graph to analyze.
|
||||
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
|
||||
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
|
||||
void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
analysis_cache_.set_value(conf, result);
|
||||
}
|
||||
void SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result);
|
||||
EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf);
|
||||
// Return the Evaluator for the given function.
|
||||
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
|
||||
|
@ -218,7 +228,6 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
|
||||
void Clear();
|
||||
void ClearEvaluatorCache();
|
||||
AnalysisCache &analysis_cache() { return analysis_cache_; }
|
||||
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) {
|
||||
return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context);
|
||||
}
|
||||
|
@ -239,7 +248,6 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf);
|
||||
const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
|
||||
|
||||
AnalysisCache analysis_cache_;
|
||||
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
|
||||
|
||||
void ResetFunctionCallDepth() {
|
||||
|
@ -249,7 +257,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
void IncreaseFunctionCallDepth() {
|
||||
function_call_depth_++;
|
||||
if (function_call_max_depth_ < function_call_depth_) {
|
||||
function_call_max_depth_ = function_call_depth_;
|
||||
function_call_max_depth_ = function_call_depth_.load();
|
||||
}
|
||||
}
|
||||
void DecreaseFunctionCallDepth() {
|
||||
|
@ -268,7 +276,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
void IncreaseStackFrameDepth() {
|
||||
stack_frame_depth_++;
|
||||
if (stack_frame_max_depth_ < stack_frame_depth_) {
|
||||
stack_frame_max_depth_ = stack_frame_depth_;
|
||||
stack_frame_max_depth_ = stack_frame_depth_.load();
|
||||
}
|
||||
}
|
||||
void DecreaseStackFrameDepth() {
|
||||
|
@ -279,39 +287,10 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
}
|
||||
size_t stack_frame_depth() const { return stack_frame_depth_; }
|
||||
size_t stack_frame_max_depth() const { return stack_frame_max_depth_; }
|
||||
|
||||
void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf);
|
||||
|
||||
bool enable_recursive_eval() const { return enable_recursive_eval_; }
|
||||
|
||||
private:
|
||||
// Should compare Args based on value other than pointer;
|
||||
struct EvaluatorArgs {
|
||||
EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {}
|
||||
bool operator==(const EvaluatorArgs &other) const {
|
||||
if (evaluator_ != other.evaluator_) {
|
||||
return false;
|
||||
}
|
||||
if (AbstractBasePtrListDeepEqual(args_, other.args_)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool operator!=(const EvaluatorArgs &other) { return !(*this == other); }
|
||||
|
||||
EvaluatorPtr evaluator_;
|
||||
AbstractBasePtrList args_;
|
||||
};
|
||||
using EvalTraceRevIter = std::list<EvaluatorArgs>::reverse_iterator;
|
||||
struct EvaluatorArgsHasher {
|
||||
std::size_t operator()(const EvaluatorArgs &eval_args) const {
|
||||
return hash_combine(std::hash<EvaluatorPtr>{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_));
|
||||
}
|
||||
};
|
||||
struct EvaluatorArgsEqual {
|
||||
bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; }
|
||||
};
|
||||
|
||||
void SetUndeterminedFlag(const EvaluatorPtr &evaluator);
|
||||
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
|
||||
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
|
||||
|
@ -323,6 +302,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> evaluators_;
|
||||
std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
|
||||
constructors_app_;
|
||||
|
||||
AnfNodeConfigMap anfnode_config_map_;
|
||||
// Use a list to trace multiple evaluators.
|
||||
std::list<EvaluatorArgs> eval_trace_;
|
||||
|
@ -337,15 +317,18 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
const ConfigPtrList &args_conf_list);
|
||||
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list);
|
||||
EvalResultPtr ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
|
||||
const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list);
|
||||
// Record current depth of function call stack, including `stack_frame_depth_`.
|
||||
size_t function_call_depth_;
|
||||
size_t function_call_max_depth_;
|
||||
std::atomic_long function_call_depth_;
|
||||
std::atomic_long function_call_max_depth_;
|
||||
|
||||
// Record current depth of stack frames call.
|
||||
size_t stack_frame_depth_;
|
||||
size_t stack_frame_max_depth_;
|
||||
std::atomic_long stack_frame_depth_;
|
||||
std::atomic_long stack_frame_max_depth_;
|
||||
|
||||
size_t forward_count_;
|
||||
std::atomic_long forward_count_;
|
||||
|
||||
bool enable_recursive_eval_;
|
||||
|
||||
|
|
|
@ -162,6 +162,10 @@ AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const {
|
|||
if (context->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) || config == kBroadenScalarParameterOnly) {
|
||||
return AbstractBase::Broaden(config);
|
||||
} else {
|
||||
auto type = this->BuildType()->type_id();
|
||||
if (type < kNumberTypeBegin || type > kNumberTypeEnd) {
|
||||
return AbstractBase::Broaden(config);
|
||||
}
|
||||
return Clone();
|
||||
}
|
||||
}
|
||||
|
@ -1085,6 +1089,27 @@ std::string AbstractNull::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
bool AbstractTimeOut::operator==(const AbstractTimeOut &) const { return true; }
|
||||
|
||||
bool AbstractTimeOut::operator==(const AbstractBase &other) const {
|
||||
if (&other == this) {
|
||||
return true;
|
||||
}
|
||||
if (other.isa<AbstractTimeOut>()) {
|
||||
auto other_none = static_cast<const AbstractTimeOut *>(&other);
|
||||
return *this == *other_none;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::string AbstractTimeOut::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << "AbstractTimeOut "
|
||||
<< "(Value: Null)";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
|
||||
|
||||
bool AbstractEllipsis::operator==(const AbstractBase &other) const {
|
||||
|
|
|
@ -312,7 +312,7 @@ class AbstractTensor : public AbstractUndetermined {
|
|||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden(uint8_t config = 0) const override;
|
||||
AbstractBasePtr BroadenWithShape() const;
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other);
|
||||
AbstractBasePtr Join(const AbstractBasePtr &other) override;
|
||||
bool operator==(const AbstractTensor &other) const;
|
||||
bool operator==(const AbstractBase &other) const override;
|
||||
std::string ToString() const override;
|
||||
|
@ -565,6 +565,21 @@ class AbstractNull : public AbstractBase {
|
|||
};
|
||||
using AbstractNullPtr = std::shared_ptr<AbstractNull>;
|
||||
|
||||
// the timeout state value for variable, which means the variable is not assigned because it is timeout
|
||||
class AbstractTimeOut : public AbstractBase {
|
||||
public:
|
||||
AbstractTimeOut() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
|
||||
~AbstractTimeOut() override = default;
|
||||
MS_DECLARE_PARENT(AbstractTimeOut, AbstractBase)
|
||||
|
||||
TypePtr BuildType() const override { return std::make_shared<TypeNull>(); }
|
||||
bool operator==(const AbstractTimeOut &other) const;
|
||||
bool operator==(const AbstractBase &other) const override;
|
||||
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTimeOut>(); }
|
||||
std::string ToString() const override;
|
||||
};
|
||||
using AbstractTimeOutPtr = std::shared_ptr<AbstractTimeOut>;
|
||||
|
||||
class AbstractEllipsis : public AbstractBase {
|
||||
public:
|
||||
AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<TypeEllipsis>()); }
|
||||
|
|
|
@ -230,7 +230,7 @@ DebugInfoPtr TraceManager::GetParseOrResolveDebugInfo() { return TraceManager::p
|
|||
|
||||
void TraceManager::ClearParseOrResolveDebugInfo() { TraceManager::parse_or_resolve_debug_info_ = nullptr; }
|
||||
|
||||
std::stack<TraceContextPtr> TraceManager::trace_context_stack_;
|
||||
thread_local std::stack<TraceContextPtr> TraceManager::trace_context_stack_;
|
||||
|
||||
DebugInfoPtr TraceManager::parse_or_resolve_debug_info_ = nullptr;
|
||||
thread_local DebugInfoPtr TraceManager::parse_or_resolve_debug_info_ = nullptr;
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -79,8 +79,8 @@ class TraceManager {
|
|||
static void ClearParseOrResolveDebugInfo();
|
||||
static DebugInfoPtr GetParseOrResolveDebugInfo();
|
||||
|
||||
static std::stack<TraceContextPtr> trace_context_stack_;
|
||||
static DebugInfoPtr parse_or_resolve_debug_info_;
|
||||
thread_local static std::stack<TraceContextPtr> trace_context_stack_;
|
||||
thread_local static DebugInfoPtr parse_or_resolve_debug_info_;
|
||||
};
|
||||
|
||||
class TraceGuard {
|
||||
|
|
|
@ -58,6 +58,45 @@ class MsException {
|
|||
ExceptionListener *listener_{nullptr};
|
||||
std::exception_ptr exception_ptr_{nullptr};
|
||||
};
|
||||
|
||||
class StaticAnalysisException {
|
||||
public:
|
||||
static StaticAnalysisException &Instance() {
|
||||
static StaticAnalysisException instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void ClearException() { exception_ptr_ = nullptr; }
|
||||
|
||||
bool HasException() { return exception_ptr_ != nullptr; }
|
||||
|
||||
void SetException() {
|
||||
if (exception_ptr_ != nullptr) {
|
||||
return;
|
||||
}
|
||||
exception_ptr_ = std::current_exception();
|
||||
}
|
||||
|
||||
void SetAndRethrowException() {
|
||||
SetException();
|
||||
std::rethrow_exception(std::current_exception());
|
||||
}
|
||||
|
||||
void CheckException() {
|
||||
if (exception_ptr_ != nullptr) {
|
||||
auto tmp_exception_ptr = exception_ptr_;
|
||||
exception_ptr_ = nullptr;
|
||||
std::rethrow_exception(tmp_exception_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
StaticAnalysisException() = default;
|
||||
~StaticAnalysisException() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(StaticAnalysisException)
|
||||
|
||||
std::exception_ptr exception_ptr_{nullptr};
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include "common/common_test.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
|
@ -37,6 +38,12 @@ class TestAbstract : public UT::Common {
|
|||
};
|
||||
|
||||
TEST_F(TestAbstract, TestParseDataClass) {
|
||||
// Check initialization before callback to Python.
|
||||
if (Py_IsInitialized() == 0) {
|
||||
Py_Initialize();
|
||||
}
|
||||
PyEval_InitThreads();
|
||||
|
||||
py::object fn = parse::python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "TestFoo");
|
||||
|
||||
ClassPtr cls_ptr = parse::ParseDataClass(fn);
|
||||
|
|
Loading…
Reference in New Issue