forked from mindspore-Ecosystem/mindspore

!20025 codex & support windows

Merge pull request !20025 from lanzhineng/r1.3
commit 7703c631b3
@@ -26,7 +26,7 @@ namespace abstract {
 HealthPointMgr HealthPointMgr::instance_;
 
 void HealthPointMgr::Clear() {
-  MS_LOG(DEBUG) << " Point = " << point_;
+  MS_LOG(DEBUG) << " Point: " << point_;
   point_ = 1;
 }
 
@@ -43,12 +43,12 @@ void HealthPointMgr::HandleException() {
   // Free all the locks. Let all the threads continue to run.
   std::lock_guard<std::recursive_mutex> lock(lock_);
   for (auto &item : asyncAbstractList_) {
-    item->SetRunable();
+    item->SetRunnable();
   }
   asyncAbstractList_.clear();
 }
 
-void HealthPointMgr::SetNextRunable() {
+void HealthPointMgr::SetNextRunnable() {
   std::lock_guard<std::recursive_mutex> lock(lock_);
   if (asyncAbstractList_.empty()) {
     MS_LOG(DEBUG) << "The Health List is empty. ";
@@ -67,7 +67,7 @@ void HealthPointMgr::SetNextRunable() {
   MS_LOG(DEBUG) << asyncAbstractList_.front().get() << " The Health Point is " << point_
                 << " Called times : " << asyncAbstractList_.front()->count();
-  asyncAbstractList_.front()->SetRunable();
+  asyncAbstractList_.front()->SetRunnable();
   asyncAbstractList_.pop_front();
 }
 
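Note: besides the Runable -> Runnable spelling fix, these two hunks show the scheduler's hand-off pattern: each waiter blocks on its own condition variable until the manager flips its runnable flag and wakes it. A minimal, self-contained sketch of that hand-off (Waiter is an illustrative stand-in, not the patched class):

    #include <condition_variable>
    #include <mutex>

    // Sketch of the one-thread-at-a-time hand-off behind SetNextRunnable().
    class Waiter {
     public:
      void Wait() {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return runnable_; });
        runnable_ = false;  // consume the wake-up, as AsyncAbstract::GetResult() does
      }
      void SetRunnable() {
        std::lock_guard<std::mutex> lock(mutex_);
        runnable_ = true;
        cv_.notify_one();  // wake exactly one blocked thread
      }

     private:
      std::mutex mutex_;
      std::condition_variable cv_;
      bool runnable_{false};
    };
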
@@ -144,13 +144,13 @@ AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &c
 void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
   MS_EXCEPTION_IF_NULL(conf);
   if (arg == nullptr) {
-    MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr";
+    MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr.";
   }
   std::lock_guard<std::mutex> lock(lock_);
   AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
   if (async_eval_result == nullptr) {
     async_eval_result = std::make_shared<AsyncAbstract>();
-    async_eval_result->JoinResult(arg);
+    async_eval_result->SetResult(arg);
     switch_cache_.set(conf, async_eval_result);
   } else {
     auto ab1 = async_eval_result->TryGetResult();
@@ -160,12 +160,12 @@ void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const
       absList.push_back(ab1);
       // Join two branches's result
       auto joined_result = AnalysisEngine::ProcessEvalResults(absList, conf->node());
-      async_eval_result->JoinResult(joined_result->abstract());
+      async_eval_result->SetResult(joined_result->abstract());
       if (!(*joined_result == *ab1)) {
         PushTodo(conf);
       }
     } else {
-      async_eval_result->JoinResult(arg);
+      async_eval_result->SetResult(arg);
     }
   }
 }
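Note: the JoinResult -> SetResult rename matches what the method actually does: it stores a value. The real join of two branch results happens separately, via AnalysisEngine::ProcessEvalResults, whose output is then stored. A hedged sketch of that distinction (AsyncSlot and the join callback are illustrative, not MindSpore types):

    #include <memory>
    #include <mutex>

    // SetResult stores; joining is a separate, caller-driven step.
    template <typename T>
    class AsyncSlot {
     public:
      void SetResult(std::shared_ptr<T> value) {
        std::lock_guard<std::mutex> lock(mutex_);
        result_ = std::move(value);  // plain store, no merging
      }
      template <typename JoinFn>
      void JoinInto(std::shared_ptr<T> value, JoinFn join) {
        std::lock_guard<std::mutex> lock(mutex_);
        // Merge with the existing result first, then store.
        result_ = result_ ? join(result_, value) : std::move(value);
      }

     private:
      std::mutex mutex_;
      std::shared_ptr<T> result_;
    };
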
@@ -176,11 +176,11 @@ void AnalysisResultCacheMgr::Todo() {
     AnfNodeConfigPtr conf = todo_.front();
     todo_.pop_front();
     if (GetValue(conf) == nullptr) {
-      MS_LOG(WARNING) << conf->node()->ToString() << " not in globleCache";
+      MS_LOG(WARNING) << conf->node()->ToString() << " not in globle cache.";
       continue;
     }
     if (TryGetSwitchValue(conf) == nullptr) {
-      MS_LOG(WARNING) << conf->node()->ToString() << " not in switchCache";
+      MS_LOG(WARNING) << conf->node()->ToString() << " not in switch cache.";
       continue;
     }
     if (!(*GetValue(conf)->abstract() == *TryGetSwitchValue(conf))) {

@@ -43,25 +43,25 @@ class HealthPointMgr {
   HealthPointMgr &operator=(const HealthPointMgr &) = delete;
   static HealthPointMgr &GetInstance() { return instance_; }
   void Clear();
-  void SetNextRunable();
+  void SetNextRunnable();
   void HandleException();
 
   void CheckPoint() {
     MS_LOG(DEBUG) << "The Health Point is " << point_;
     if (point_ == 0) {
-      SetNextRunable();
+      SetNextRunnable();
     } else if (point_ < 0) {
-      MS_LOG(WARNING) << "There is something wrong. point = " << point_;
+      MS_LOG(WARNING) << "There is something wrong. point: " << point_;
     }
   }
 
-  void DropPoint() {
+  void DecrPoint() {
     std::lock_guard<std::recursive_mutex> lock(lock_);
     --point_;
     CheckPoint();
   }
 
-  void AddPoint() {
+  void IncrPoint() {
     std::lock_guard<std::recursive_mutex> lock(lock_);
     ++point_;
   }
@@ -83,8 +83,8 @@ class HealthPointMgr {
 
 class HealthPointScopedDrop {
  public:
-  HealthPointScopedDrop() { HealthPointMgr::GetInstance().DropPoint(); }
-  ~HealthPointScopedDrop() { HealthPointMgr::GetInstance().AddPoint(); }
+  HealthPointScopedDrop() { HealthPointMgr::GetInstance().DecrPoint(); }
+  ~HealthPointScopedDrop() { HealthPointMgr::GetInstance().IncrPoint(); }
 };
 
 template <typename KeyType, typename ValueType, typename CacheType>
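Note: HealthPointScopedDrop is the RAII guard over the renamed pair: construction decrements the point (this thread may block), and the destructor increments it again on every exit path, including exceptions. Illustrative usage (the function body is an assumption, not patch code):

    // The guard keeps the health point balanced even on early return or throw.
    void WaitForPeerThread() {
      HealthPointScopedDrop scoped_drop;  // DecrPoint() here
      // ... block on a condition variable, future, etc. ...
    }                                     // IncrPoint() in the destructor
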
@@ -119,7 +119,7 @@ class MultiThreadCache {
   std::string dump() {
     std::ostringstream buf;
     for (auto &item : cache_) {
-      buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
+      buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
     }
     return buf.str();
   }
@@ -144,7 +144,7 @@ class NormalCache {
   using iterator = typename CacheType::iterator;
   using const_iterator = typename CacheType::const_iterator;
 
-  ValueType get(const KeyType &key) {
+  ValueType get(const KeyType &key) const {
     auto it = cache_.find(key);
     if (it != cache_.end()) {
       return it->second;
@@ -156,14 +156,14 @@ class NormalCache {
 
   void clear() { cache_.clear(); }
 
-  size_t size() { return cache_.size(); }
+  size_t size() const { return cache_.size(); }
 
-  bool empty() { return size() == 0; }
+  bool empty() const { return size() == 0; }
 
-  std::string dump() {
+  std::string dump() const {
     std::ostringstream buf;
     for (auto &item : cache_) {
-      buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
+      buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
     }
     return buf.str();
   }
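Note: adding const to the read-only accessors is what allows callers to hold a NormalCache through a const reference. A minimal illustration of why the qualifier matters (SketchCache is a stand-in):

    #include <cstddef>
    #include <map>
    #include <string>

    class SketchCache {
     public:
      std::size_t size() const { return cache_.size(); }
      bool empty() const { return size() == 0; }  // calls another const member

     private:
      std::map<std::string, int> cache_;
    };

    bool IsEmpty(const SketchCache &cache) {  // read-only view
      return cache.empty();  // compiles only because empty() is const
    }
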
@@ -187,44 +187,42 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
   ~AsyncAbstract() = default;
   // Wait
   AbstractBasePtr GetResult() {
-    static HealthPointMgr &healthPointMgr = HealthPointMgr::GetInstance();
-    static StaticAnalysisException &exceptionMgr = StaticAnalysisException::Instance();
-    exceptionMgr.CheckException();
+    StaticAnalysisException::Instance().CheckException();
     std::unique_lock<std::mutex> lock(lock_);
     while (true) {
       ++count_;
       // The point should be dropped if it can't run. It will be added when it can run.
-      bool hasDropPoint = false;
-      if (!runable_) {
-        healthPointMgr.DropPoint();
-        hasDropPoint = true;
+      bool hasDecrPoint = false;
+      if (!runnable_) {
+        HealthPointMgr::GetInstance().DecrPoint();
+        hasDecrPoint = true;
       }
-      MS_LOG(DEBUG) << this << " runable: " << runable_ << " result: " << (result_ ? result_.get() : 0);
-      condition_var_.wait(lock, [this] { return runable_; });
-      if (hasDropPoint) {
-        healthPointMgr.AddPoint();
+      MS_LOG(DEBUG) << this << " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
+      condition_var_.wait(lock, [this] { return runnable_; });
+      if (hasDecrPoint) {
+        HealthPointMgr::GetInstance().IncrPoint();
       }
-      MS_LOG(DEBUG) << this << " continue runable: " << runable_ << " result: " << (result_ ? result_.get() : 0);
+      MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
 
-      exceptionMgr.CheckException();
-      runable_ = false;
+      StaticAnalysisException::Instance().CheckException();
+      runnable_ = false;
       if (result_ != nullptr) {
         MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0);
         return result_;
       }
       // Push to list
-      healthPointMgr.Add2Schedule(shared_from_this());
+      HealthPointMgr::GetInstance().Add2Schedule(shared_from_this());
       // Notify the next asyncAbastract to run.
-      healthPointMgr.SetNextRunable();
-      MS_LOG(DEBUG) << this << " SetNextRunable "
-                    << " runable: " << runable_ << " result: " << (result_ ? result_.get() : 0)
-                    << " point:" << healthPointMgr.point();
+      HealthPointMgr::GetInstance().SetNextRunnable();
+      MS_LOG(DEBUG) << this << " SetNextRunnable "
+                    << " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0)
+                    << " point: " << HealthPointMgr::GetInstance().point();
     }
   }
 
-  void SetRunable() {
-    MS_LOG(DEBUG) << this << " Runable.";
-    runable_ = true;
+  void SetRunnable() {
+    MS_LOG(DEBUG) << this << " runnable.";
+    runnable_ = true;
     condition_var_.notify_one();
   }
   int count() const { return count_; }
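Note: this rewrite also drops the function-local static references (healthPointMgr, exceptionMgr) in favor of calling GetInstance()/Instance() at each use. Since the getters just return a long-lived singleton, the cached references bought nothing and only pinned extra statics. A compressed sketch of the two styles (Mgr is a stand-in, not the real class):

    class Mgr {
     public:
      static Mgr &GetInstance() {
        static Mgr instance;  // thread-safe one-time init since C++11
        return instance;
      }
      void Tick() {}

     private:
      Mgr() = default;
    };

    void Before() {
      static Mgr &mgr = Mgr::GetInstance();  // extra static reference
      mgr.Tick();
    }

    void After() { Mgr::GetInstance().Tick(); }  // resolve at each call site
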
@@ -235,7 +233,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
     std::lock_guard<std::mutex> lock(lock_);
     return result_;
   }
-  void JoinResult(const AbstractBasePtr &result) {
+  void SetResult(const AbstractBasePtr &result) {
     MS_EXCEPTION_IF_NULL(result);
     std::lock_guard<std::mutex> lock(lock_);
     result_ = result;
@@ -250,7 +248,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
  private:
   std::mutex lock_;
   std::condition_variable condition_var_;
-  bool runable_{false};
+  bool runnable_{false};
   int count_{0};
   AbstractBasePtr result_{nullptr};
 };
@@ -265,7 +263,7 @@ class EvaluatorCacheMgr {
   ~EvaluatorCacheMgr() = default;
 
   void Clear() { eval_result_cache_.clear(); }
-  EvalResultCache &GetCache() { return eval_result_cache_; }
+  const EvalResultCache &GetCache() { return eval_result_cache_; }
   EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
   void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
   size_t GetSize() { return eval_result_cache_.size(); }

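Note: GetCache() now hands out a const reference, so callers such as FindUniqueArgvals (below) can read the evaluator cache but not mutate it through this accessor; writes go through SetValue. A small sketch of the read-only-getter pattern (Registry is illustrative):

    #include <map>
    #include <string>

    class Registry {
     public:
      const std::map<std::string, int> &GetCache() { return cache_; }    // read-only view
      void SetValue(const std::string &key, int v) { cache_[key] = v; }  // the only write path

     private:
      std::map<std::string, int> cache_;
    };
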
@@ -590,7 +590,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
   }
   MS_LOG(DEBUG) << "Joined argvals: " << joined_argvals.size() << ", " << ::mindspore::ToString(joined_argvals);
   EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
-  auto joined_eval_result = origin_eval_cache.get(joined_argvals);
+  const auto joined_eval_result = origin_eval_cache.get(joined_argvals);
   if (joined_eval_result != nullptr) {
     MS_LOG(DEBUG) << "Find unique Choices in original eval cache, so use it: " << joined_eval_result->ToString();
@@ -741,7 +741,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
   DumpEvaluatorCache(evaluator_cache_mgr, argvals);
 
   MS_EXCEPTION_IF_NULL(GetEvalCache(eval));
-  EvalResultCache &choices = GetEvalCache(eval)->GetCache();
+  const EvalResultCache &choices = GetEvalCache(eval)->GetCache();
   if (choices.get(argvals) != nullptr) {
     *result = std::make_pair(argvals, GetEvalCache(eval)->GetValue(argvals)->abstract());
     return kSpecializeSuccess;

@@ -619,16 +619,12 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
     MS_EXCEPTION_IF_NULL(eval);
     return eval->Run(shared_from_this(), args_conf_list, out_conf);
   }
 #if !(defined _WIN32 || defined _WIN64)
-  static bool enable_singleThread = (common::GetEnv("ENV_SINGLE_EVAL") == "1");
-  if (enable_singleThread) {
-    return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
-  } else {
-    return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list);
-  }
+  return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list);
 #else
   return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
 #endif
 }
 
 void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {

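Note: this is the Windows-support half of the commit: the ENV_SINGLE_EVAL escape hatch goes away, non-Windows builds always take the multi-threaded evaluator, and Windows builds fall back to the single-threaded one. The dispatch shape, in isolation (the two callees are stand-ins):

    int EvaluateMultiThread();   // declared elsewhere; stand-ins for the two paths
    int EvaluateSingleThread();

    int Evaluate() {
    #if !(defined _WIN32 || defined _WIN64)
      return EvaluateMultiThread();   // std::async-based path
    #else
      return EvaluateSingleThread();  // Windows fallback
    #endif
    }
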
@@ -816,10 +812,10 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
     auto broadAbstract = result->abstract()->Broaden();
     // Notify the thread of waiting for switch node and the main thread to continue.
     AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadAbstract);
-    async_result_branch->JoinResult(broadAbstract);
-    async_result_main->JoinResult(broadAbstract);
+    async_result_branch->SetResult(broadAbstract);
+    async_result_main->SetResult(broadAbstract);
     // Health Point will be drop when thread exits.
-    HealthPointMgr::GetInstance().DropPoint();
+    HealthPointMgr::GetInstance().DecrPoint();
     MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
                   << " asyncResult address = " << async_result_branch.get()
                   << " value = " << async_result_branch->TryGetResult()->ToString();
@@ -827,7 +823,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
     MS_LOG(WARNING) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
     auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>("Exception"), out_conf->node());
     AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr);
-    async_result_main->JoinResult(abstractErrPtr);
+    async_result_main->SetResult(abstractErrPtr);
     HealthPointMgr::GetInstance().HandleException();
   }
 }

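Note: on failure the worker publishes an AbstractError value and HandleException() then releases every waiter, so no thread stays blocked on a result that will never arrive. The standard-library analogue for carrying an error between threads is std::exception_ptr (a sketch, not patch code):

    #include <exception>
    #include <stdexcept>

    std::exception_ptr g_error;  // shared error slot (assume external synchronization)

    void Worker() {
      try {
        throw std::runtime_error("eval failed");
      } catch (...) {
        g_error = std::current_exception();  // stash the error for another thread
      }
    }

    void MainThread() {
      if (g_error) std::rethrow_exception(g_error);  // resurface it here
    }
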
@@ -837,18 +833,14 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
                                                                    const ConfigPtrList &args_conf_list) {
   MS_EXCEPTION_IF_NULL(out_conf);
   MS_EXCEPTION_IF_NULL(out_conf->node());
-  static HealthPointMgr &healthPointMgr = HealthPointMgr::GetInstance();
-  static AnalysisResultCacheMgr &resultCacheMgr = AnalysisResultCacheMgr::GetInstance();
 
   // Release GIL for C++
   py::gil_scoped_release infer_gil_release;
 
   // Wait for the last switch node to finish.
   MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString();
-  auto eval_result = resultCacheMgr.GetSwitchValue(out_conf);
+  auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
   if (eval_result == nullptr) {
     MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
-    resultCacheMgr.InitSwitchValue(out_conf);
+    AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
   } else {
     return std::make_shared<EvalResult>(eval_result, nullptr);
   }
@@ -866,21 +858,21 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
     // Control the order to run.
     AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
     // Add point to the async thread.
-    healthPointMgr.AddPoint();
+    HealthPointMgr::GetInstance().IncrPoint();
     MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
     auto future = std::async(std::launch::async, ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf,
                              threadId, branchAsyncResult, asyncResult_main, asyncRunOrder);
     // Wait for async threads to finish.
-    resultCacheMgr.PushToWait(std::move(future));
+    AnalysisResultCacheMgr::GetInstance().PushToWait(std::move(future));
     // Push to list of running loop
-    asyncRunOrder->JoinResult(std::make_shared<AbstractScalar>(1));
-    healthPointMgr.Add2Schedule(asyncRunOrder);  // Activate order
+    asyncRunOrder->SetResult(std::make_shared<AbstractScalar>(1));
+    HealthPointMgr::GetInstance().Add2Schedule(asyncRunOrder);  // Activate order
     branchAsyncResults.emplace_back(std::move(branchAsyncResult));
   }
 
   MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
                 << " or " << evaluators[1]->ToString() << "...";
-  healthPointMgr.Add2Schedule(asyncResult_main);  // Third order
+  HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main);  // Third order
   auto firstResult = asyncResult_main->GetResult();
   MS_EXCEPTION_IF_NULL(firstResult);
   MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
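Note: each branch evaluator runs through std::async, the returned futures are queued via PushToWait so they can be joined later, and the main thread then blocks only for the first available result. A bare-bones model of that launch pattern:

    #include <future>

    int EvalBranch(int branch_id) { return branch_id * 10; }  // stand-in work

    int FirstThenBoth() {
      // Launch both branches, as the loop above does for each evaluator.
      auto fut0 = std::async(std::launch::async, EvalBranch, 0);
      auto fut1 = std::async(std::launch::async, EvalBranch, 1);
      int first = fut0.get();     // block for one result, like asyncResult_main
      return first + fut1.get();  // the rest are awaited (or polled) later
    }
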
@@ -891,21 +883,21 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
   if (NeedWaitForBranches(firstResult)) {
     for (size_t i = 0; i < len; ++i) {
       MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
-      healthPointMgr.Add2Schedule(branchAsyncResults[i]);
+      HealthPointMgr::GetInstance().Add2Schedule(branchAsyncResults[i]);
       auto result = branchAsyncResults[i]->GetResult();
       MS_EXCEPTION_IF_NULL(result);
       out_specs.push_back(result);
     }
   } else {
     // Give one more chance to wait for the result of the branches.
-    healthPointMgr.Add2Schedule(asyncResult_main);
+    HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main);
     (void)asyncResult_main->GetResult();
     for (size_t i = 0; i < len; ++i) {
       // Not wait to get the result of branch.
       auto result = branchAsyncResults[i]->TryGetResult();
       if (result) {
         MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
-                      << " result =" << result->ToString();
+                      << " result: " << result->ToString();
         out_specs.push_back(result);
       }
     }
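Note: the else-branch deliberately uses TryGetResult() rather than GetResult(): after giving the main slot one more chance, it only collects branch results that are already available instead of blocking again. A sketch of the blocking/non-blocking pair (ResultSlot is illustrative, not the real AsyncAbstract):

    #include <condition_variable>
    #include <memory>
    #include <mutex>

    class ResultSlot {
     public:
      std::shared_ptr<int> GetResult() {  // blocking accessor
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return result_ != nullptr; });
        return result_;
      }
      std::shared_ptr<int> TryGetResult() {  // non-blocking peek
        std::lock_guard<std::mutex> lock(mutex_);
        return result_;  // may be nullptr; caller must check
      }

     private:
      std::mutex mutex_;
      std::condition_variable cv_;
      std::shared_ptr<int> result_;
    };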