!20025 codex & support windows

Merge pull request !20025 from lanzhineng/r1.3
This commit is contained in:
i-robot 2021-07-12 13:15:54 +00:00 committed by Gitee
commit 7703c631b3
4 changed files with 62 additions and 72 deletions

View File

@ -26,7 +26,7 @@ namespace abstract {
HealthPointMgr HealthPointMgr::instance_;
void HealthPointMgr::Clear() {
MS_LOG(DEBUG) << " Point = " << point_;
MS_LOG(DEBUG) << " Point: " << point_;
point_ = 1;
}
@ -43,12 +43,12 @@ void HealthPointMgr::HandleException() {
// Free all the locks. Let all the threads continue to run.
std::lock_guard<std::recursive_mutex> lock(lock_);
for (auto &item : asyncAbstractList_) {
item->SetRunable();
item->SetRunnable();
}
asyncAbstractList_.clear();
}
void HealthPointMgr::SetNextRunable() {
void HealthPointMgr::SetNextRunnable() {
std::lock_guard<std::recursive_mutex> lock(lock_);
if (asyncAbstractList_.empty()) {
MS_LOG(DEBUG) << "The Health List is empty. ";
@ -67,7 +67,7 @@ void HealthPointMgr::SetNextRunable() {
MS_LOG(DEBUG) << asyncAbstractList_.front().get() << " The Health Point is " << point_
<< " Called times : " << asyncAbstractList_.front()->count();
asyncAbstractList_.front()->SetRunable();
asyncAbstractList_.front()->SetRunnable();
asyncAbstractList_.pop_front();
}
@ -144,13 +144,13 @@ AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &c
void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
MS_EXCEPTION_IF_NULL(conf);
if (arg == nullptr) {
MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr";
MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr.";
}
std::lock_guard<std::mutex> lock(lock_);
AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
if (async_eval_result == nullptr) {
async_eval_result = std::make_shared<AsyncAbstract>();
async_eval_result->JoinResult(arg);
async_eval_result->SetResult(arg);
switch_cache_.set(conf, async_eval_result);
} else {
auto ab1 = async_eval_result->TryGetResult();
@ -160,12 +160,12 @@ void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const
absList.push_back(ab1);
// Join the results of the two branches.
auto joined_result = AnalysisEngine::ProcessEvalResults(absList, conf->node());
async_eval_result->JoinResult(joined_result->abstract());
async_eval_result->SetResult(joined_result->abstract());
if (!(*joined_result == *ab1)) {
PushTodo(conf);
}
} else {
async_eval_result->JoinResult(arg);
async_eval_result->SetResult(arg);
}
}
}
@ -176,11 +176,11 @@ void AnalysisResultCacheMgr::Todo() {
AnfNodeConfigPtr conf = todo_.front();
todo_.pop_front();
if (GetValue(conf) == nullptr) {
MS_LOG(WARNING) << conf->node()->ToString() << " not in globleCache";
MS_LOG(WARNING) << conf->node()->ToString() << " not in globle cache.";
continue;
}
if (TryGetSwitchValue(conf) == nullptr) {
MS_LOG(WARNING) << conf->node()->ToString() << " not in switchCache";
MS_LOG(WARNING) << conf->node()->ToString() << " not in switch cache.";
continue;
}
if (!(*GetValue(conf)->abstract() == *TryGetSwitchValue(conf))) {

View File

@ -43,25 +43,25 @@ class HealthPointMgr {
HealthPointMgr &operator=(const HealthPointMgr &) = delete;
static HealthPointMgr &GetInstance() { return instance_; }
void Clear();
void SetNextRunable();
void SetNextRunnable();
void HandleException();
void CheckPoint() {
MS_LOG(DEBUG) << "The Health Point is " << point_;
if (point_ == 0) {
SetNextRunable();
SetNextRunnable();
} else if (point_ < 0) {
MS_LOG(WARNING) << "There is something wrong. point = " << point_;
MS_LOG(WARNING) << "There is something wrong. point: " << point_;
}
}
void DropPoint() {
void DecrPoint() {
std::lock_guard<std::recursive_mutex> lock(lock_);
--point_;
CheckPoint();
}
void AddPoint() {
void IncrPoint() {
std::lock_guard<std::recursive_mutex> lock(lock_);
++point_;
}
@ -83,8 +83,8 @@ class HealthPointMgr {
class HealthPointScopedDrop {
public:
HealthPointScopedDrop() { HealthPointMgr::GetInstance().DropPoint(); }
~HealthPointScopedDrop() { HealthPointMgr::GetInstance().AddPoint(); }
HealthPointScopedDrop() { HealthPointMgr::GetInstance().DecrPoint(); }
~HealthPointScopedDrop() { HealthPointMgr::GetInstance().IncrPoint(); }
};
template <typename KeyType, typename ValueType, typename CacheType>
@ -119,7 +119,7 @@ class MultiThreadCache {
std::string dump() {
std::ostringstream buf;
for (auto &item : cache_) {
buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
}
return buf.str();
}
@ -144,7 +144,7 @@ class NormalCache {
using iterator = typename CacheType::iterator;
using const_iterator = typename CacheType::const_iterator;
ValueType get(const KeyType &key) {
ValueType get(const KeyType &key) const {
auto it = cache_.find(key);
if (it != cache_.end()) {
return it->second;
@ -156,14 +156,14 @@ class NormalCache {
void clear() { cache_.clear(); }
size_t size() { return cache_.size(); }
size_t size() const { return cache_.size(); }
bool empty() { return size() == 0; }
bool empty() const { return size() == 0; }
std::string dump() {
std::string dump() const {
std::ostringstream buf;
for (auto &item : cache_) {
buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
}
return buf.str();
}
@ -187,44 +187,42 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
~AsyncAbstract() = default;
// Wait
AbstractBasePtr GetResult() {
static HealthPointMgr &healthPointMgr = HealthPointMgr::GetInstance();
static StaticAnalysisException &exceptionMgr = StaticAnalysisException::Instance();
exceptionMgr.CheckException();
StaticAnalysisException::Instance().CheckException();
std::unique_lock<std::mutex> lock(lock_);
while (true) {
++count_;
// The point should be dropped if it can't run. It will be added when it can run.
bool hasDropPoint = false;
if (!runable_) {
healthPointMgr.DropPoint();
hasDropPoint = true;
bool hasDecrPoint = false;
if (!runnable_) {
HealthPointMgr::GetInstance().DecrPoint();
hasDecrPoint = true;
}
MS_LOG(DEBUG) << this << " runable: " << runable_ << " result: " << (result_ ? result_.get() : 0);
condition_var_.wait(lock, [this] { return runable_; });
if (hasDropPoint) {
healthPointMgr.AddPoint();
MS_LOG(DEBUG) << this << " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
condition_var_.wait(lock, [this] { return runnable_; });
if (hasDecrPoint) {
HealthPointMgr::GetInstance().IncrPoint();
}
MS_LOG(DEBUG) << this << " continue runable: " << runable_ << " result: " << (result_ ? result_.get() : 0);
MS_LOG(DEBUG) << this << " continue runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0);
exceptionMgr.CheckException();
runable_ = false;
StaticAnalysisException::Instance().CheckException();
runnable_ = false;
if (result_ != nullptr) {
MS_LOG(DEBUG) << this << " Return result: " << (result_ ? result_.get() : 0);
return result_;
}
// Push to list
healthPointMgr.Add2Schedule(shared_from_this());
HealthPointMgr::GetInstance().Add2Schedule(shared_from_this());
// Notify the next AsyncAbstract to run.
healthPointMgr.SetNextRunable();
MS_LOG(DEBUG) << this << " SetNextRunable "
<< " runable: " << runable_ << " result: " << (result_ ? result_.get() : 0)
<< " point:" << healthPointMgr.point();
HealthPointMgr::GetInstance().SetNextRunnable();
MS_LOG(DEBUG) << this << " SetNextRunnable "
<< " runnable: " << runnable_ << " result: " << (result_ ? result_.get() : 0)
<< " point: " << HealthPointMgr::GetInstance().point();
}
}
void SetRunable() {
MS_LOG(DEBUG) << this << " Runable.";
runable_ = true;
void SetRunnable() {
MS_LOG(DEBUG) << this << " runnable.";
runnable_ = true;
condition_var_.notify_one();
}
int count() const { return count_; }
@ -235,7 +233,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
std::lock_guard<std::mutex> lock(lock_);
return result_;
}
void JoinResult(const AbstractBasePtr &result) {
void SetResult(const AbstractBasePtr &result) {
MS_EXCEPTION_IF_NULL(result);
std::lock_guard<std::mutex> lock(lock_);
result_ = result;
@ -250,7 +248,7 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
private:
std::mutex lock_;
std::condition_variable condition_var_;
bool runable_{false};
bool runnable_{false};
int count_{0};
AbstractBasePtr result_{nullptr};
};
@ -265,7 +263,7 @@ class EvaluatorCacheMgr {
~EvaluatorCacheMgr() = default;
void Clear() { eval_result_cache_.clear(); }
EvalResultCache &GetCache() { return eval_result_cache_; }
const EvalResultCache &GetCache() { return eval_result_cache_; }
EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
size_t GetSize() { return eval_result_cache_.size(); }

View File

@ -590,7 +590,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
}
MS_LOG(DEBUG) << "Joined argvals: " << joined_argvals.size() << ", " << ::mindspore::ToString(joined_argvals);
EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
auto joined_eval_result = origin_eval_cache.get(joined_argvals);
const auto joined_eval_result = origin_eval_cache.get(joined_argvals);
if (joined_eval_result != nullptr) {
MS_LOG(DEBUG) << "Find unique Choices in original eval cache, so use it: " << joined_eval_result->ToString();
@ -741,7 +741,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
DumpEvaluatorCache(evaluator_cache_mgr, argvals);
MS_EXCEPTION_IF_NULL(GetEvalCache(eval));
EvalResultCache &choices = GetEvalCache(eval)->GetCache();
const EvalResultCache &choices = GetEvalCache(eval)->GetCache();
if (choices.get(argvals) != nullptr) {
*result = std::make_pair(argvals, GetEvalCache(eval)->GetValue(argvals)->abstract());
return kSpecializeSuccess;

View File

@ -619,16 +619,12 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
MS_EXCEPTION_IF_NULL(eval);
return eval->Run(shared_from_this(), args_conf_list, out_conf);
}
#if !(defined _WIN32 || defined _WIN64)
static bool enable_singleThread = (common::GetEnv("ENV_SINGLE_EVAL") == "1");
if (enable_singleThread) {
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
} else {
return ExecuteMultipleEvaluatorsMultiThread(evaluators, out_conf, args_conf_list);
}
#else
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
#endif
}
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
@ -816,10 +812,10 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
auto broadAbstract = result->abstract()->Broaden();
// Notify the thread waiting for the switch node, and the main thread, to continue.
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, broadAbstract);
async_result_branch->JoinResult(broadAbstract);
async_result_main->JoinResult(broadAbstract);
async_result_branch->SetResult(broadAbstract);
async_result_main->SetResult(broadAbstract);
// The health point is decremented when the thread exits.
HealthPointMgr::GetInstance().DropPoint();
HealthPointMgr::GetInstance().DecrPoint();
MS_LOG(DEBUG) << GetInferThread() << "async :" << eval->ToString()
<< " asyncResult address = " << async_result_branch.get()
<< " value = " << async_result_branch->TryGetResult()->ToString();
@ -827,7 +823,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
MS_LOG(WARNING) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
auto abstractErrPtr = std::make_shared<AbstractError>(std::make_shared<StringImm>("Exception"), out_conf->node());
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, abstractErrPtr);
async_result_main->JoinResult(abstractErrPtr);
async_result_main->SetResult(abstractErrPtr);
HealthPointMgr::GetInstance().HandleException();
}
}
@ -837,18 +833,14 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
const ConfigPtrList &args_conf_list) {
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
static HealthPointMgr &healthPointMgr = HealthPointMgr::GetInstance();
static AnalysisResultCacheMgr &resultCacheMgr = AnalysisResultCacheMgr::GetInstance();
// Release GIL for C++
py::gil_scoped_release infer_gil_release;
// Wait for the last switch node to finish.
MS_LOG(DEBUG) << GetInferThread() << "async : entry switch " << out_conf->ToString();
auto eval_result = resultCacheMgr.GetSwitchValue(out_conf);
auto eval_result = AnalysisResultCacheMgr::GetInstance().GetSwitchValue(out_conf);
if (eval_result == nullptr) {
MS_LOG(INFO) << GetInferThread() << "async : Init switch " << out_conf->node()->ToString();
resultCacheMgr.InitSwitchValue(out_conf);
AnalysisResultCacheMgr::GetInstance().InitSwitchValue(out_conf);
} else {
return std::make_shared<EvalResult>(eval_result, nullptr);
}
@ -866,21 +858,21 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
// Control the order to run.
AsyncAbstractPtr asyncRunOrder = std::make_shared<AsyncAbstract>();
// Add point to the async thread.
healthPointMgr.AddPoint();
HealthPointMgr::GetInstance().IncrPoint();
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluator->ToString();
auto future = std::async(std::launch::async, ExecEvaluator, evaluator, shared_from_this(), args_conf_list, out_conf,
threadId, branchAsyncResult, asyncResult_main, asyncRunOrder);
// Wait for async threads to finish.
resultCacheMgr.PushToWait(std::move(future));
AnalysisResultCacheMgr::GetInstance().PushToWait(std::move(future));
// Push to list of running loop
asyncRunOrder->JoinResult(std::make_shared<AbstractScalar>(1));
healthPointMgr.Add2Schedule(asyncRunOrder); // Activate order
asyncRunOrder->SetResult(std::make_shared<AbstractScalar>(1));
HealthPointMgr::GetInstance().Add2Schedule(asyncRunOrder); // Activate order
branchAsyncResults.emplace_back(std::move(branchAsyncResult));
}
MS_LOG(DEBUG) << GetInferThread() << "async : wait for one of async to finish. " << evaluators[0]->ToString()
<< " or " << evaluators[1]->ToString() << "...";
healthPointMgr.Add2Schedule(asyncResult_main); // Third order
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main); // Third order
auto firstResult = asyncResult_main->GetResult();
MS_EXCEPTION_IF_NULL(firstResult);
MS_LOG(DEBUG) << GetInferThread() << "async main thread result of " << out_conf->node()->ToString() << " = "
@ -891,21 +883,21 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
if (NeedWaitForBranches(firstResult)) {
for (size_t i = 0; i < len; ++i) {
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
healthPointMgr.Add2Schedule(branchAsyncResults[i]);
HealthPointMgr::GetInstance().Add2Schedule(branchAsyncResults[i]);
auto result = branchAsyncResults[i]->GetResult();
MS_EXCEPTION_IF_NULL(result);
out_specs.push_back(result);
}
} else {
// Give one more chance to wait for the result of the branches.
healthPointMgr.Add2Schedule(asyncResult_main);
HealthPointMgr::GetInstance().Add2Schedule(asyncResult_main);
(void)asyncResult_main->GetResult();
for (size_t i = 0; i < len; ++i) {
// Do not wait for the branch result; use it only if it is already available.
auto result = branchAsyncResults[i]->TryGetResult();
if (result) {
MS_LOG(DEBUG) << GetInferThread() << "async get " << evaluators[i]->ToString()
<< " result =" << result->ToString();
<< " result: " << result->ToString();
out_specs.push_back(result);
}
}