forked from mindspore-Ecosystem/mindspore
infer optimize using unlock cache
This commit is contained in:
parent
c986916d48
commit
a1aac76105
|
@ -86,12 +86,55 @@ class MultiThreadCache {
|
||||||
CacheType cache_;
|
CacheType cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename KeyType, typename ValueType, typename CacheType>
|
||||||
|
class NormalCache {
|
||||||
|
public:
|
||||||
|
using iterator = typename CacheType::iterator;
|
||||||
|
using const_iterator = typename CacheType::const_iterator;
|
||||||
|
|
||||||
|
ValueType get(const KeyType &key) {
|
||||||
|
auto it = cache_.find(key);
|
||||||
|
if (it != cache_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set(const KeyType &key, const ValueType &data) { cache_[key] = data; }
|
||||||
|
|
||||||
|
void clear() { cache_.clear(); }
|
||||||
|
|
||||||
|
size_t size() { return cache_.size(); }
|
||||||
|
|
||||||
|
bool empty() { return size() == 0; }
|
||||||
|
|
||||||
|
std::string dump() {
|
||||||
|
std::ostringstream buf;
|
||||||
|
for (auto &item : cache_) {
|
||||||
|
buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
|
||||||
|
}
|
||||||
|
return buf.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator begin() { return cache_.begin(); }
|
||||||
|
iterator end() { return cache_.end(); }
|
||||||
|
|
||||||
|
const_iterator begin() const { return cache_.cbegin(); }
|
||||||
|
const_iterator end() const { return cache_.cend(); }
|
||||||
|
|
||||||
|
const_iterator cbegin() const { return cache_.cbegin(); }
|
||||||
|
const_iterator cend() const { return cache_.cend(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
CacheType cache_;
|
||||||
|
};
|
||||||
|
|
||||||
class AsyncEvalResult;
|
class AsyncEvalResult;
|
||||||
using AsyncEvalResultPtr = std::shared_ptr<AsyncEvalResult>;
|
using AsyncEvalResultPtr = std::shared_ptr<AsyncEvalResult>;
|
||||||
|
|
||||||
using EvaluatorCacheMap =
|
using EvaluatorCacheMap =
|
||||||
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||||
using EvalResultCache = MultiThreadCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
|
using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
|
||||||
|
|
||||||
class AsyncEvalResult {
|
class AsyncEvalResult {
|
||||||
public:
|
public:
|
||||||
|
@ -197,7 +240,7 @@ class AnalysisResultCacheMgr {
|
||||||
|
|
||||||
using AnalysisConfigResultMap =
|
using AnalysisConfigResultMap =
|
||||||
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||||
using AnalysisConfigResultCache = MultiThreadCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
|
using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
|
||||||
|
|
||||||
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
|
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
|
||||||
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
|
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
|
||||||
|
|
|
@ -580,6 +580,8 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
||||||
|
static std::mutex fg_lock;
|
||||||
|
std::lock_guard<std::mutex> infer_lock(fg_lock);
|
||||||
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
|
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
|
||||||
if (fg_eval == nullptr) {
|
if (fg_eval == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
|
Loading…
Reference in New Issue