infer_opt:fixed eval path
This commit is contained in:
parent
662f8cb377
commit
05db2d6e56
|
@ -24,49 +24,39 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
// Returns the currently stored result, optionally waiting a bounded time for it.
// With ms == 0 the stored value (possibly nullptr) is returned immediately.
// Otherwise the call blocks until result_ is non-null or the timeout elapses,
// then returns whatever is stored (still nullptr on timeout).
EvalResultPtr AsyncEvalResult::TryGetResult(int ms) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (ms == 0) {
|
||||
return result_;
|
||||
}
|
||||
// NOTE: despite the parameter name, the value is interpreted as microseconds.
auto time = std::chrono::microseconds(ms);
|
||||
// Wait for at most `time`; the wait_for return value is deliberately ignored
// because result_ is returned either way.
|
||||
(void)condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });
|
||||
return result_;
|
||||
}
|
||||
HealthPointMgr HealthPointMgr::instance_;
|
||||
|
||||
EvalResultPtr AsyncEvalResult::GetResult() {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (result_ != nullptr) {
|
||||
return result_;
|
||||
// Calls SetRunable() on every entry currently held in asyncAbstractList_,
// while holding lock_. The list itself is not modified here.
void HealthPointMgr::HandleException() {
|
||||
std::lock_guard<std::recursive_mutex> lock(lock_);
|
||||
for (auto &item : asyncAbstractList_) {
|
||||
item->SetRunable();
|
||||
}
|
||||
}
|
||||
void HealthPointMgr::SetNextRunable() {
|
||||
std::lock_guard<std::recursive_mutex> lock(lock_);
|
||||
if (asyncAbstractList_.empty()) {
|
||||
MS_LOG(DEBUG) << "The Health List is empty. ";
|
||||
return;
|
||||
}
|
||||
// Check if enter endless loop
|
||||
HealthPointScopedDrop health_point_check;
|
||||
auto time = std::chrono::seconds(kInferTimeout);
|
||||
auto cond = condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });
|
||||
if (cond) {
|
||||
return result_;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Timeout!";
|
||||
return std::make_shared<EvalResult>(std::make_shared<AbstractTimeOut>(), nullptr);
|
||||
auto it = std::find_if(asyncAbstractList_.begin(), asyncAbstractList_.end(),
|
||||
[](const auto &item) { return item->HasResult(); });
|
||||
if (it == asyncAbstractList_.end()) {
|
||||
// Enter endless loop if there is not ready result.
|
||||
MS_LOG(EXCEPTION) << "Enter endless loop. Please check the code. point = "
|
||||
<< " point:" << HealthPointMgr::GetInstance().point()
|
||||
<< " Called times : " << asyncAbstractList_.front()->count();
|
||||
}
|
||||
asyncAbstractList_.insert(asyncAbstractList_.end(), asyncAbstractList_.begin(), it);
|
||||
asyncAbstractList_.erase(asyncAbstractList_.begin(), it);
|
||||
|
||||
MS_LOG(DEBUG) << asyncAbstractList_.front().get() << " The Health Point is " << point_
|
||||
<< " Called times : " << asyncAbstractList_.front()->count();
|
||||
asyncAbstractList_.front()->SetRunable();
|
||||
asyncAbstractList_.pop_front();
|
||||
}
|
||||
|
||||
// Returns a printable form of the stored result: the literal "NOT SET" when no
// result has been stored yet, otherwise the ToString() of the result's abstract.
std::string AsyncEvalResult::ToString() {
|
||||
std::ostringstream buffer;
|
||||
// result_ is read under lock_ so the null check and the dereference see the same value.
std::lock_guard<std::mutex> lock(lock_);
|
||||
buffer << (result_ == nullptr ? "NOT SET" : result_->abstract()->ToString());
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
// Stores `result` as the current value (overwriting any previous one) and
// wakes every thread waiting on condition_var_.
// `result` is checked with MS_EXCEPTION_IF_NULL first (presumably raises on nullptr — macro not shown here).
void AsyncEvalResult::JoinResult(const EvalResultPtr &result) {
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
// Inner scope: the lock is released before notifying the waiters.
{
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
result_ = result;
|
||||
}
|
||||
condition_var_.notify_all();
|
||||
}
|
||||
AnalysisResultCacheMgr AnalysisResultCacheMgr::instance_;
|
||||
|
||||
void AnalysisResultCacheMgr::Clear() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
|
@ -75,30 +65,6 @@ void AnalysisResultCacheMgr::Clear() {
|
|||
todo_.clear();
|
||||
}
|
||||
|
||||
// Returns a reference to the single AnalysisResultCacheMgr, held as a
// function-local static that is constructed on the first call.
AnalysisResultCacheMgr &AnalysisResultCacheMgr::GetInstance() {
|
||||
static AnalysisResultCacheMgr instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Writes cache_.dump() to a file derived from `filename` with a ".cache" suffix.
// The target path comes from pipeline::GetSaveGraphsPathName / Common::GetRealPath.
// Failures (no real path, file cannot be opened) are logged at ERROR level and
// the function returns without writing; nothing is reported to the caller.
void AnalysisResultCacheMgr::DumpCache(const std::string &filename) {
|
||||
auto path = pipeline::GetSaveGraphsPathName(Common::AddId(filename, ".cache"));
|
||||
auto realpath = Common::GetRealPath(path);
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed. path=" << path;
|
||||
return;
|
||||
}
|
||||
// Give the owner full permissions before opening — presumably so a previous
// read-only dump (see the S_IRUSR call below) can be overwritten; TODO confirm.
ChangeFileMode(realpath.value(), S_IRWXU);
|
||||
std::ofstream fout(realpath.value());
|
||||
if (!fout.is_open()) {
|
||||
MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!";
|
||||
return;
|
||||
}
|
||||
fout << cache_.dump();
|
||||
fout.close();
|
||||
// Set file mode to read only by user
|
||||
ChangeFileMode(realpath.value(), S_IRUSR);
|
||||
}
|
||||
|
||||
thread_local static std::string local_threadid;
|
||||
void AnalysisResultCacheMgr::UpdateCaller(const std::string &caller) {
|
||||
std::ostringstream buffer;
|
||||
|
@ -121,23 +87,35 @@ void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
|
|||
|
||||
void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
|
||||
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
|
||||
if (async_eval_result == nullptr) {
|
||||
async_eval_result = std::make_shared<AsyncEvalResult>();
|
||||
async_eval_result = std::make_shared<AsyncAbstract>();
|
||||
switch_cache_.set(conf, async_eval_result);
|
||||
}
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
|
||||
AbstractBasePtr AnalysisResultCacheMgr::TryGetSwitchValue(const AnfNodeConfigPtr &conf) {
|
||||
// don't call lock_.lock(). switch_cache is protected. and it waits for result.
|
||||
AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
|
||||
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
|
||||
// Conf has been visited and set value.
|
||||
if (async_eval_result != nullptr) {
|
||||
// Maybe blocked for waiting. AsyncEvalResult maybe null, if time out.
|
||||
return async_eval_result->TryGetResult();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
// don't call lock_.lock(). switch_cache is protected. and it waits for result.
|
||||
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
|
||||
// Conf has been visited and set value.
|
||||
if (async_eval_result != nullptr) {
|
||||
// Add to schedule
|
||||
HealthPointMgr::GetInstance().PushBack(async_eval_result);
|
||||
// Maybe blocked for waiting. AsyncAbstract maybe null, if time out.
|
||||
auto result = async_eval_result->GetResult();
|
||||
if (result == nullptr) {
|
||||
result = std::make_shared<EvalResult>(std::make_shared<AbstractTimeOut>(), nullptr);
|
||||
MS_LOG(ERROR) << "AsyncEvalResult for NodeConfig " << conf->node()->ToString() << " is nullptr, maybe timeout.";
|
||||
result = std::make_shared<AbstractTimeOut>();
|
||||
MS_LOG(ERROR) << "AsyncAbstract for NodeConfig " << conf->node()->ToString() << " is nullptr, maybe timeout.";
|
||||
MS_LOG(ERROR) << "detail:" << conf->ToString();
|
||||
}
|
||||
return result;
|
||||
|
@ -145,26 +123,26 @@ EvalResultPtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &con
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const EvalResultPtr arg) {
|
||||
void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr arg) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
if (arg == nullptr || arg->abstract() == nullptr) {
|
||||
if (arg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr";
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
AsyncEvalResultPtr async_eval_result = switch_cache_.get(conf);
|
||||
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
|
||||
if (async_eval_result == nullptr) {
|
||||
async_eval_result = std::make_shared<AsyncEvalResult>();
|
||||
async_eval_result = std::make_shared<AsyncAbstract>();
|
||||
async_eval_result->JoinResult(arg);
|
||||
switch_cache_.set(conf, async_eval_result);
|
||||
} else {
|
||||
auto ab1 = async_eval_result->TryGetResult();
|
||||
AbstractBasePtrList absList;
|
||||
if (ab1 != nullptr) {
|
||||
absList.push_back(arg->abstract());
|
||||
absList.push_back(ab1->abstract());
|
||||
absList.push_back(arg);
|
||||
absList.push_back(ab1);
|
||||
// Join two branches's result
|
||||
auto joined_result = AnalysisEngine::ProcessEvalResults(absList, conf->node());
|
||||
async_eval_result->JoinResult(joined_result);
|
||||
async_eval_result->JoinResult(joined_result->abstract());
|
||||
if (!(*joined_result == *ab1)) {
|
||||
PushTodo(conf);
|
||||
}
|
||||
|
@ -179,9 +157,9 @@ void AnalysisResultCacheMgr::Todo() {
|
|||
while (!todo_.empty()) {
|
||||
AnfNodeConfigPtr conf = todo_.front();
|
||||
todo_.pop_front();
|
||||
if (!(*GetValue(conf)->abstract() == *GetSwitchValue(conf)->abstract())) {
|
||||
if (!(*GetValue(conf)->abstract() == *TryGetSwitchValue(conf))) {
|
||||
MS_LOG(WARNING) << " Switch Value is not eq. "
|
||||
<< " switchCache: " << GetSwitchValue(conf)->abstract()->ToString()
|
||||
<< " switchCache: " << TryGetSwitchValue(conf)->ToString()
|
||||
<< " globleCache: " << GetValue(conf)->abstract()->ToString() << "\t\tConf: " << conf->ToString();
|
||||
}
|
||||
}
|
||||
|
@ -189,6 +167,8 @@ void AnalysisResultCacheMgr::Todo() {
|
|||
|
||||
void AnalysisResultCacheMgr::Wait() {
|
||||
py::gil_scoped_release infer_gil_release;
|
||||
// Check all the async to finish.
|
||||
HealthPointScopedDrop hpCheck;
|
||||
while (true) {
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
lock_.lock();
|
||||
|
@ -208,11 +188,6 @@ void AnalysisResultCacheMgr::Wait() {
|
|||
}
|
||||
}
|
||||
|
||||
// Returns a reference to the single HealthPointMgr, held as a function-local
// static that is constructed on the first call.
HealthPointMgr &HealthPointMgr::GetInstance() {
|
||||
static HealthPointMgr instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
|
||||
std::ostringstream buffer;
|
||||
buffer << "(";
|
||||
|
|
|
@ -33,42 +33,52 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
constexpr size_t kInferTimeout = 1800; // 60*30 30min, next pr will change the solution of endless.
|
||||
constexpr size_t kInferTryTimeout = 3; // 3 microsecond.
|
||||
|
||||
class AsyncAbstract;
|
||||
using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
|
||||
class HealthPointMgr {
|
||||
public:
|
||||
~HealthPointMgr() = default;
|
||||
HealthPointMgr(const HealthPointMgr &) = delete;
|
||||
HealthPointMgr &operator=(const HealthPointMgr &) = delete;
|
||||
static HealthPointMgr &GetInstance();
|
||||
static HealthPointMgr &GetInstance() { return instance_; }
|
||||
void SetNextRunable();
|
||||
|
||||
void DropPoint() {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto time = std::chrono::microseconds(kInferTryTimeout);
|
||||
auto cond = condition_var_.wait_for(lock, time, [this] { return point_ > 1; });
|
||||
if (cond) {
|
||||
--point_;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Enter endless loop. Please check the code. ";
|
||||
void CheckPoint() {
|
||||
MS_LOG(DEBUG) << "The Health Point is " << point_;
|
||||
if (point_ == 0) {
|
||||
SetNextRunable();
|
||||
} else if (point_ < 0) {
|
||||
MS_LOG(EXCEPTION) << "There is something wrong.";
|
||||
}
|
||||
}
|
||||
|
||||
void DropPoint() {
|
||||
std::lock_guard<std::recursive_mutex> lock(lock_);
|
||||
--point_;
|
||||
CheckPoint();
|
||||
}
|
||||
|
||||
void HandleException();
|
||||
|
||||
void AddPoint() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
++point_;
|
||||
}
|
||||
condition_var_.notify_all();
|
||||
std::lock_guard<std::recursive_mutex> lock(lock_);
|
||||
++point_;
|
||||
}
|
||||
|
||||
int point() { return point_; }
|
||||
|
||||
// Appends `base` to the end of asyncAbstractList_ while holding lock_.
// No null or duplicate check is performed on `base`.
void PushBack(const AsyncAbstractPtr &base) {
|
||||
std::lock_guard<std::recursive_mutex> lock(lock_);
|
||||
asyncAbstractList_.push_back(base);
|
||||
}
|
||||
|
||||
private:
|
||||
HealthPointMgr() = default;
|
||||
static HealthPointMgr instance_;
|
||||
int point_{1};
|
||||
std::mutex lock_;
|
||||
std::condition_variable condition_var_;
|
||||
std::recursive_mutex lock_;
|
||||
std::list<AsyncAbstractPtr> asyncAbstractList_;
|
||||
};
|
||||
|
||||
class HealthPointScopedDrop {
|
||||
|
@ -171,70 +181,65 @@ class NormalCache {
|
|||
CacheType cache_;
|
||||
};
|
||||
|
||||
class AsyncEvalResult;
|
||||
using AsyncEvalResultPtr = std::shared_ptr<AsyncEvalResult>;
|
||||
|
||||
using EvaluatorCacheMap =
|
||||
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||
using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
|
||||
|
||||
class AsyncEvalResult {
|
||||
class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
|
||||
public:
|
||||
AsyncEvalResult() = default;
|
||||
~AsyncEvalResult() = default;
|
||||
AsyncAbstract() = default;
|
||||
~AsyncAbstract() = default;
|
||||
// Wait
|
||||
EvalResultPtr GetResult();
|
||||
// Not wait
|
||||
EvalResultPtr TryGetResult(int ms = 0);
|
||||
void JoinResult(const EvalResultPtr &result);
|
||||
std::string ToString();
|
||||
|
||||
private:
|
||||
EvalResultPtr result_{nullptr};
|
||||
std::mutex lock_;
|
||||
std::condition_variable condition_var_;
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
class AsyncResult {
|
||||
public:
|
||||
AsyncResult() = default;
|
||||
~AsyncResult() = default;
|
||||
// Wait
|
||||
Type GetResult() {
|
||||
AbstractBasePtr GetResult() {
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (result_ != nullptr) {
|
||||
return result_;
|
||||
}
|
||||
auto time = std::chrono::seconds(kInferTimeout);
|
||||
// Check if enter endless loop
|
||||
HealthPointScopedDrop health_point_check;
|
||||
auto cond = condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });
|
||||
if (cond) {
|
||||
return result_;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Timeout!";
|
||||
return nullptr;
|
||||
while (true) {
|
||||
++count_;
|
||||
// The point should be dropped if it can't run. It will be added when it can run.
|
||||
bool hasDropPoint = false;
|
||||
if (!runable_) {
|
||||
HealthPointMgr::GetInstance().DropPoint();
|
||||
hasDropPoint = true;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << this << " ranable: " << runable_ << " result: " << (result_ ? result_.get() : 0);
|
||||
condition_var_.wait(lock, [this] { return runable_; });
|
||||
MS_LOG(DEBUG) << this << " continue ranable: " << runable_ << " result: " << (result_ ? result_.get() : 0);
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
runable_ = false;
|
||||
if (result_ != nullptr) {
|
||||
if (hasDropPoint) {
|
||||
HealthPointMgr::GetInstance().AddPoint();
|
||||
}
|
||||
MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0);
|
||||
return result_;
|
||||
}
|
||||
// Push to list
|
||||
HealthPointMgr::GetInstance().PushBack(shared_from_this());
|
||||
if (hasDropPoint) {
|
||||
HealthPointMgr::GetInstance().AddPoint();
|
||||
}
|
||||
// Notify the next asyncAbastract to run.
|
||||
HealthPointMgr::GetInstance().SetNextRunable();
|
||||
MS_LOG(DEBUG) << this << " SetNextRunable "
|
||||
<< " ranable: " << runable_ << " result: " << (result_ ? result_.get() : 0)
|
||||
<< " point:" << HealthPointMgr::GetInstance().point();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
// Marks this object as runnable and wakes one thread waiting on condition_var_.
// NOTE(review): runable_ is written here without acquiring any mutex in this
// body; confirm whether callers are expected to provide the synchronization.
void SetRunable() {
|
||||
MS_LOG(DEBUG) << this << " Runable.";
|
||||
runable_ = true;
|
||||
condition_var_.notify_one();
|
||||
}
|
||||
int count() { return count_; }
|
||||
|
||||
bool HasResult() { return result_ != nullptr; }
|
||||
// Not wait
|
||||
Type TryGetResult(int ms = 0) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
if (ms == 0) {
|
||||
return result_;
|
||||
}
|
||||
auto time = std::chrono::microseconds(ms);
|
||||
// Wait for ms.
|
||||
(void)condition_var_.wait_for(lock, time, [this] { return result_ != nullptr; });
|
||||
AbstractBasePtr TryGetResult() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
return result_;
|
||||
}
|
||||
void JoinResult(const Type &result) {
|
||||
void JoinResult(const AbstractBasePtr &result) {
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
result_ = result;
|
||||
}
|
||||
condition_var_.notify_all();
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
result_ = result;
|
||||
}
|
||||
std::string ToString() {
|
||||
std::ostringstream buffer;
|
||||
|
@ -244,13 +249,16 @@ class AsyncResult {
|
|||
}
|
||||
|
||||
private:
|
||||
Type result_{nullptr};
|
||||
std::mutex lock_;
|
||||
std::condition_variable condition_var_;
|
||||
bool runable_{false};
|
||||
int count_{0};
|
||||
AbstractBasePtr result_{nullptr};
|
||||
};
|
||||
|
||||
using AsyncAbstractResult = AsyncResult<AbstractBasePtr>;
|
||||
using AsyncAbstractResultPtr = std::shared_ptr<AsyncAbstractResult>;
|
||||
using EvaluatorCacheMap =
|
||||
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||
using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
|
||||
|
||||
class EvaluatorCacheMgr {
|
||||
public:
|
||||
|
@ -273,23 +281,10 @@ class AnalysisResultCacheMgr {
|
|||
~AnalysisResultCacheMgr() = default;
|
||||
AnalysisResultCacheMgr(const AnalysisResultCacheMgr &) = delete;
|
||||
AnalysisResultCacheMgr &operator=(const AnalysisResultCacheMgr &) = delete;
|
||||
static AnalysisResultCacheMgr &GetInstance();
|
||||
static AnalysisResultCacheMgr &GetInstance() { return instance_; }
|
||||
void Clear();
|
||||
|
||||
using AnalysisConfigAsyncResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, AsyncEvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigAsyncResultCache =
|
||||
MultiThreadCache<AnfNodeConfigPtr, AsyncEvalResultPtr, AnalysisConfigAsyncResultMap>;
|
||||
|
||||
using AnalysisConfigResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
|
||||
|
||||
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
|
||||
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
|
||||
|
||||
// Dump all the conf and result
|
||||
void DumpCache(const std::string &filename);
|
||||
// Wait for async Eval(conf) to finish.
|
||||
void Wait();
|
||||
void PushTowait(std::future<void> &&future0, std::future<void> &&future1);
|
||||
|
@ -297,13 +292,23 @@ class AnalysisResultCacheMgr {
|
|||
void Todo();
|
||||
static void UpdateCaller(const std::string &caller);
|
||||
static std::string &GetThreadid();
|
||||
|
||||
void InitSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
EvalResultPtr GetSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
void SetSwitchValue(const AnfNodeConfigPtr &conf, const EvalResultPtr vale);
|
||||
AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
AbstractBasePtr TryGetSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr vale);
|
||||
|
||||
private:
|
||||
using AnalysisConfigAsyncResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, AsyncAbstractPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigAsyncResultCache =
|
||||
MultiThreadCache<AnfNodeConfigPtr, AsyncAbstractPtr, AnalysisConfigAsyncResultMap>;
|
||||
|
||||
using AnalysisConfigResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
|
||||
|
||||
AnalysisResultCacheMgr() = default;
|
||||
static AnalysisResultCacheMgr instance_;
|
||||
std::mutex lock_;
|
||||
std::list<std::future<void>> waiting_;
|
||||
std::mutex todo_lock_;
|
||||
|
|
|
@ -628,12 +628,12 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
|
|||
#endif
|
||||
}
|
||||
|
||||
bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
|
||||
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
|
||||
static std::mutex fg_lock;
|
||||
std::lock_guard<std::mutex> infer_lock(fg_lock);
|
||||
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
|
||||
if (fg_eval == nullptr) {
|
||||
return false;
|
||||
return;
|
||||
}
|
||||
|
||||
auto fg = fg_eval->func_graph();
|
||||
|
@ -644,16 +644,14 @@ bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const Fu
|
|||
if (fg_parent != nullptr) {
|
||||
fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString() << " for fg: " << fg->ToString();
|
||||
return true;
|
||||
return;
|
||||
} else if (possible_parent_fg != nullptr) {
|
||||
possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString() << " for fg: " << fg->ToString();
|
||||
return true;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "cannot find parent for fg: " << fg->ToString();
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
|
||||
|
@ -791,20 +789,17 @@ bool NeedWaitForTwoBranches(const AbstractBasePtr &abstract) {
|
|||
}
|
||||
|
||||
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
|
||||
std::string caller, AsyncAbstractResultPtr async_result_branch,
|
||||
AsyncAbstractResultPtr async_result_main, bool first, AsyncAbstractResultPtr async_first_Result) {
|
||||
std::string caller, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
|
||||
AsyncAbstractPtr async_run_flag) {
|
||||
AnalysisResultCacheMgr::UpdateCaller(caller);
|
||||
// Wait for the first fg to run
|
||||
if (!first) {
|
||||
(void)async_first_Result->GetResult();
|
||||
}
|
||||
try {
|
||||
// Wait for Signal to run
|
||||
MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " waiting.";
|
||||
(void)async_run_flag->GetResult();
|
||||
MS_LOG(DEBUG) << async_run_flag.get() << " " << eval->ToString() << " running.";
|
||||
|
||||
// Acquire GIL
|
||||
py::gil_scoped_acquire pyGuard;
|
||||
// Notify the second fg to go
|
||||
if (first) {
|
||||
async_first_Result->JoinResult(std::make_shared<AbstractScalar>(1));
|
||||
}
|
||||
trace::ClearTraceStack();
|
||||
auto result = eval->Run(engine, args_conf_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
@ -813,13 +808,14 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
|
|||
// Broaden the result of switch(c,t,f)()
|
||||
auto broadAbstract = result->abstract()->Broaden();
|
||||
// Let main thread to continue.
|
||||
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf,
|
||||
std::make_shared<EvalResult>(broadAbstract, nullptr));
|
||||
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadAbstract);
|
||||
async_result_branch->JoinResult(broadAbstract);
|
||||
async_result_main->JoinResult(broadAbstract);
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
|
||||
<< " asyncResult address = " << async_result_branch.get()
|
||||
<< " value = " << async_result_branch->TryGetResult()->ToString();
|
||||
// Decrease infer thread.
|
||||
HealthPointMgr::GetInstance().DropPoint();
|
||||
} catch (const std::exception &e) {
|
||||
std::ostringstream oss;
|
||||
oss << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
|
||||
|
@ -828,13 +824,11 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
|
|||
MS_LOG(ERROR) << oss.str();
|
||||
}
|
||||
auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>(oss.str()), out_conf->node());
|
||||
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf,
|
||||
std::make_shared<EvalResult>(abstractErrPtr, nullptr));
|
||||
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr);
|
||||
async_result_main->JoinResult(abstractErrPtr);
|
||||
StaticAnalysisException::Instance().SetException();
|
||||
HealthPointMgr::GetInstance().HandleException();
|
||||
}
|
||||
// Decrease infer thread.
|
||||
HealthPointMgr::GetInstance().DropPoint();
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
|
||||
|
@ -843,54 +837,64 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|||
// Release GIL;
|
||||
py::gil_scoped_release infer_gil_release;
|
||||
|
||||
// Wait for the switch node to finish.
|
||||
// Wait for the last switch node to finish.
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString();
|
||||
auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
|
||||
if (eval_result == nullptr) {
|
||||
MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
|
||||
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
|
||||
} else {
|
||||
if (eval_result->abstract()->isa<AbstractTimeOut>()) {
|
||||
if (eval_result->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << out_conf->node()->ToString() << " time out."
|
||||
<< " Please check the code if there are recursive functions.";
|
||||
}
|
||||
if (eval_result->abstract()->isa<AbstractError>()) {
|
||||
if (eval_result->isa<AbstractError>()) {
|
||||
MS_LOG(DEBUG) << "Eval " << out_conf->node()->ToString() << " threw exception.";
|
||||
StaticAnalysisException::Instance().CheckException();
|
||||
}
|
||||
return eval_result;
|
||||
return std::make_shared<EvalResult>(eval_result, nullptr);
|
||||
}
|
||||
|
||||
// Eval result of the branches and main.
|
||||
AsyncAbstractResultPtr asyncResult_main = std::make_shared<AsyncAbstractResult>();
|
||||
AsyncAbstractResultPtr asyncResult0 = std::make_shared<AsyncAbstractResult>();
|
||||
AsyncAbstractResultPtr asyncResult1 = std::make_shared<AsyncAbstractResult>();
|
||||
AsyncAbstractResultPtr asyncFirstRunResult = std::make_shared<AsyncAbstractResult>();
|
||||
AsyncAbstractPtr asyncResult_main = std::make_shared<AsyncAbstract>();
|
||||
AsyncAbstractPtr asyncResult0 = std::make_shared<AsyncAbstract>();
|
||||
AsyncAbstractPtr asyncResult1 = std::make_shared<AsyncAbstract>();
|
||||
|
||||
// Control which thread to run.
|
||||
AsyncAbstractPtr asyncRun0 = std::make_shared<AsyncAbstract>();
|
||||
AsyncAbstractPtr asyncRun1 = std::make_shared<AsyncAbstract>();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||
auto possible_parent_fg = out_conf->node()->func_graph();
|
||||
bool firstRun = !SetUndeterminedFlag(evaluators[0], possible_parent_fg);
|
||||
(void)SetUndeterminedFlag(evaluators[1], possible_parent_fg);
|
||||
SetUndeterminedFlag(evaluators[0], possible_parent_fg);
|
||||
SetUndeterminedFlag(evaluators[1], possible_parent_fg);
|
||||
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
|
||||
// Add point to infer thread
|
||||
HealthPointMgr::GetInstance().AddPoint();
|
||||
auto future0 = std::async(std::launch::async, ExecEvaluator, evaluators[0], shared_from_this(), args_conf_list,
|
||||
out_conf, threadId, asyncResult0, asyncResult_main, firstRun, asyncFirstRunResult);
|
||||
out_conf, threadId, asyncResult0, asyncResult_main, asyncRun0);
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[1]->ToString();
|
||||
// Add point to infer thread
|
||||
HealthPointMgr::GetInstance().AddPoint();
|
||||
auto future1 = std::async(std::launch::async, ExecEvaluator, evaluators[1], shared_from_this(), args_conf_list,
|
||||
out_conf, threadId, asyncResult1, asyncResult_main, !firstRun, asyncFirstRunResult);
|
||||
out_conf, threadId, asyncResult1, asyncResult_main, asyncRun1);
|
||||
|
||||
// Wait for async threads to finish.
|
||||
AnalysisResultCacheMgr::GetInstance().PushTowait(std::move(future0), std::move(future1));
|
||||
// Push to list of running loop
|
||||
asyncRun0->JoinResult(std::make_shared<AbstractScalar>(0));
|
||||
asyncRun1->JoinResult(std::make_shared<AbstractScalar>(0));
|
||||
// Run order
|
||||
HealthPointMgr::GetInstance().PushBack(asyncRun0); // First order
|
||||
HealthPointMgr::GetInstance().PushBack(asyncRun1); // Second order
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
|
||||
<< " or " << evaluators[1]->ToString();
|
||||
HealthPointMgr::GetInstance().PushBack(asyncResult_main); // Third order
|
||||
auto branchResult = asyncResult_main->GetResult();
|
||||
if (branchResult == nullptr || branchResult->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Can't finish " << evaluators[0]->ToString() << " or " << evaluators[1]->ToString()
|
||||
|
@ -906,6 +910,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|||
AbstractBasePtrList out_specs;
|
||||
if (NeedWaitForTwoBranches(branchResult)) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString();
|
||||
// The asyncRun0 will eval asyncResult0
|
||||
HealthPointMgr::GetInstance().PushBack(asyncResult0);
|
||||
auto result0 = asyncResult0->GetResult();
|
||||
if (result0 == nullptr || result0->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << evaluators[0]->ToString() << " is time out."
|
||||
|
@ -914,6 +920,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|||
out_specs.push_back(result0);
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString();
|
||||
// The asyncRun1 will eval asyncResult1
|
||||
HealthPointMgr::GetInstance().PushBack(asyncResult1);
|
||||
auto result1 = asyncResult1->GetResult();
|
||||
if (result1 == nullptr || result1->isa<AbstractTimeOut>()) {
|
||||
MS_LOG(EXCEPTION) << "Eval " << evaluators[1]->ToString() << " is time out."
|
||||
|
@ -921,15 +929,24 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|||
}
|
||||
out_specs.push_back(result1);
|
||||
} else {
|
||||
if (asyncResult0->TryGetResult((HealthPointMgr::GetInstance().point() - 1) * kInferTryTimeout)) {
|
||||
// Next time to get the result of branches.
|
||||
HealthPointMgr::GetInstance().PushBack(asyncResult_main);
|
||||
(void)asyncResult_main->GetResult();
|
||||
|
||||
// Don't use GetResult
|
||||
auto value0 = asyncResult0->TryGetResult();
|
||||
if (value0) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[0]->ToString()
|
||||
<< " value0=" << asyncResult0->GetResult()->ToString();
|
||||
out_specs.push_back(asyncResult0->GetResult());
|
||||
<< " value0=" << value0->ToString();
|
||||
out_specs.push_back(value0);
|
||||
}
|
||||
if (asyncResult1->TryGetResult((HealthPointMgr::GetInstance().point() - 1) * kInferTryTimeout)) {
|
||||
|
||||
// Don't use GetResult
|
||||
auto value1 = asyncResult1->TryGetResult();
|
||||
if (value1) {
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[1]->ToString()
|
||||
<< " value1=" << asyncResult1->GetResult()->ToString();
|
||||
out_specs.push_back(asyncResult1->GetResult());
|
||||
<< " value1=" << value1->ToString();
|
||||
out_specs.push_back(value1);
|
||||
}
|
||||
}
|
||||
return ProcessEvalResults(out_specs, out_conf->node());
|
||||
|
|
|
@ -266,7 +266,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
|
||||
|
||||
private:
|
||||
bool SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg);
|
||||
void SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg);
|
||||
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
|
||||
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
|
||||
bool *continue_flag);
|
||||
|
|
Loading…
Reference in New Issue