forked from mindspore-Ecosystem/mindspore
Optimize inference by using a non-locking (single-threaded) cache
This commit is contained in:
parent
c986916d48
commit
a1aac76105
|
@ -86,12 +86,55 @@ class MultiThreadCache {
|
|||
CacheType cache_;
|
||||
};
|
||||
|
||||
/// \brief A simple, non-locking cache keyed by KeyType.
///
/// Unlike MultiThreadCache, this variant performs no synchronization and is
/// only safe for single-threaded use (the point of this change: avoid lock
/// overhead on the inference path).
///
/// \tparam KeyType    Lookup key type.
/// \tparam ValueType  Cached value type. NOTE(review): get() returns nullptr
///                    on a miss, so ValueType must be pointer-like
///                    (e.g. std::shared_ptr) — confirm with instantiations.
/// \tparam CacheType  Map-like container from KeyType to ValueType
///                    (e.g. std::unordered_map with a custom hasher).
template <typename KeyType, typename ValueType, typename CacheType>
class NormalCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  /// \brief Look up \p key.
  /// \return The cached value, or nullptr when the key is absent.
  // const-qualified: lookup does not mutate the cache.
  ValueType get(const KeyType &key) const {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    return nullptr;
  }

  /// \brief Insert or overwrite the entry for \p key.
  void set(const KeyType &key, const ValueType &data) { cache_[key] = data; }

  /// \brief Remove all entries.
  void clear() { cache_.clear(); }

  /// \brief Number of cached entries.
  size_t size() const { return cache_.size(); }

  /// \brief True when the cache holds no entries.
  // Delegates to the container's own empty() instead of size() == 0.
  bool empty() const { return cache_.empty(); }

  /// \brief Human-readable dump of all entries, one "{key:value}" per line.
  /// Requires both KeyType and ValueType to expose ToString() through
  /// operator-> (instantiated lazily, so other uses compile without it).
  std::string dump() const {
    std::ostringstream buf;
    for (const auto &item : cache_) {
      buf << "{" << item.first->ToString() << ":" << item.second->ToString() << "}" << std::endl;
    }
    return buf.str();
  }

  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }

  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }

  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  CacheType cache_;  // Underlying map; no lock — single-threaded use only.
};
|
||||
|
||||
class AsyncEvalResult;
|
||||
using AsyncEvalResultPtr = std::shared_ptr<AsyncEvalResult>;
|
||||
|
||||
using EvaluatorCacheMap =
|
||||
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||
using EvalResultCache = MultiThreadCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
|
||||
using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
|
||||
|
||||
class AsyncEvalResult {
|
||||
public:
|
||||
|
@ -197,7 +240,7 @@ class AnalysisResultCacheMgr {
|
|||
|
||||
using AnalysisConfigResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigResultCache = MultiThreadCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
|
||||
using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
|
||||
|
||||
inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
|
||||
inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
|
||||
|
|
|
@ -580,6 +580,8 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
|
|||
}
|
||||
|
||||
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
||||
static std::mutex fg_lock;
|
||||
std::lock_guard<std::mutex> infer_lock(fg_lock);
|
||||
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
|
||||
if (fg_eval == nullptr) {
|
||||
return;
|
||||
|
|
Loading…
Reference in New Issue