infer optimize using unlock cache

This commit is contained in:
lanzhineng 2021-06-25 20:36:28 +08:00
parent c986916d48
commit a1aac76105
2 changed files with 47 additions and 2 deletions

View File

@ -86,12 +86,55 @@ class MultiThreadCache {
CacheType cache_;
};
template <typename KeyType, typename ValueType, typename CacheType>
class NormalCache {
public:
using iterator = typename CacheType::iterator;
using const_iterator = typename CacheType::const_iterator;
ValueType get(const KeyType &key) {
auto it = cache_.find(key);
if (it != cache_.end()) {
return it->second;
}
return nullptr;
}
void set(const KeyType &key, const ValueType &data) { cache_[key] = data; }
void clear() { cache_.clear(); }
size_t size() { return cache_.size(); }
bool empty() { return size() == 0; }
std::string dump() {
std::ostringstream buf;
for (auto &item : cache_) {
buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
}
return buf.str();
}
iterator begin() { return cache_.begin(); }
iterator end() { return cache_.end(); }
const_iterator begin() const { return cache_.cbegin(); }
const_iterator end() const { return cache_.cend(); }
const_iterator cbegin() const { return cache_.cbegin(); }
const_iterator cend() const { return cache_.cend(); }
private:
CacheType cache_;
};
class AsyncEvalResult;
using AsyncEvalResultPtr = std::shared_ptr<AsyncEvalResult>;
using EvaluatorCacheMap =
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvalResultCache = MultiThreadCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
class AsyncEvalResult {
public:
@ -197,7 +240,7 @@ class AnalysisResultCacheMgr {
using AnalysisConfigResultMap =
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
using AnalysisConfigResultCache = MultiThreadCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }

View File

@ -580,6 +580,8 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
}
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
static std::mutex fg_lock;
std::lock_guard<std::mutex> infer_lock(fg_lock);
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
if (fg_eval == nullptr) {
return;