Optimize inference by using an unlocked (lock-free access) cache for single-threaded paths

This commit is contained in:
lanzhineng 2021-06-25 20:36:28 +08:00
parent c986916d48
commit a1aac76105
2 changed files with 47 additions and 2 deletions

View File

@ -86,12 +86,55 @@ class MultiThreadCache {
CacheType cache_; CacheType cache_;
}; };
// A simple, non-thread-safe (unlocked) key/value cache wrapper.
// Unlike MultiThreadCache it takes no lock on access, so it must only be
// used from a single thread or under external synchronization.
//   KeyType   - lookup key type.
//   ValueType - cached value type; expected to be default-constructible
//               (a shared_ptr in current usage, so a miss yields nullptr).
//   CacheType - underlying associative container (e.g. std::unordered_map).
template <typename KeyType, typename ValueType, typename CacheType>
class NormalCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  // Returns the cached value for `key`, or a default-constructed ValueType
  // (nullptr for smart-pointer values) when the key is absent.
  ValueType get(const KeyType &key) const {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    return ValueType();
  }

  // Inserts or overwrites the entry for `key`.
  void set(const KeyType &key, const ValueType &data) { cache_[key] = data; }

  // Removes all entries.
  void clear() { cache_.clear(); }

  // Number of cached entries.
  size_t size() const { return cache_.size(); }

  // True when the cache holds no entries.
  bool empty() const { return cache_.empty(); }

  // Debug helper: renders every {key:value} pair via their ToString()
  // methods, one per line. Only instantiable when both the key and the
  // value are pointer-like types exposing ToString().
  std::string dump() const {
    std::ostringstream buf;
    for (const auto &item : cache_) {
      buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
    }
    return buf.str();
  }

  // Iteration over the underlying container.
  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }
  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }
  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  CacheType cache_;
};
class AsyncEvalResult; class AsyncEvalResult;
using AsyncEvalResultPtr = std::shared_ptr<AsyncEvalResult>; using AsyncEvalResultPtr = std::shared_ptr<AsyncEvalResult>;
using EvaluatorCacheMap = using EvaluatorCacheMap =
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvalResultCache = MultiThreadCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>; using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
class AsyncEvalResult { class AsyncEvalResult {
public: public:
@ -197,7 +240,7 @@ class AnalysisResultCacheMgr {
using AnalysisConfigResultMap = using AnalysisConfigResultMap =
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>; std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
using AnalysisConfigResultCache = MultiThreadCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>; using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); } inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); } inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }

View File

@ -580,6 +580,8 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
} }
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
static std::mutex fg_lock;
std::lock_guard<std::mutex> infer_lock(fg_lock);
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>(); auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
if (fg_eval == nullptr) { if (fg_eval == nullptr) {
return; return;