!49519 Refine memory allocator for cpu hash table

Merge pull request !49519 from zyli2020/persistent_storage
This commit is contained in:
i-robot 2023-02-28 10:57:21 +00:00 committed by Gitee
commit 2b8ad31223
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 22 additions and 18 deletions

View File

@ -37,11 +37,6 @@ CPUHashTable<Key, Value>::~CPUHashTable() {
template <typename Key, typename Value>
bool CPUHashTable<Key, Value>::Initialize() {
const uint32_t device_id = 0;
device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kCPUDevice, device_id});
MS_EXCEPTION_IF_NULL(device_ctx_);
device_ctx_->Initialize();
value_size_ = value_dim_ * sizeof(Value);
return true;
}
@ -81,15 +76,13 @@ bool CPUHashTable<Key, Value>::Find(const Key *keys, size_t key_num, bool, Value
template <typename Key, typename Value>
bool CPUHashTable<Key, Value>::Insert(const Key *keys, size_t key_num, const Value *value, void *) {
std::unique_lock<std::shared_mutex> lock(mutex_);
MS_EXCEPTION_IF_NULL(device_ctx_);
MS_EXCEPTION_IF_NULL(device_ctx_->device_res_manager_);
for (size_t i = 0; i < key_num; ++i) {
const auto &key = keys[i];
// The the key does not exist, a new value buffer should be allocated firstly.
if (values_.find(key) == values_.end()) {
auto value_addr = device_ctx_->device_res_manager_->AllocateMemory(value_size_);
auto value_addr = AllocateMemory(value_size_);
MS_EXCEPTION_IF_NULL(value_addr);
values_[key] = reinterpret_cast<Value *>(value_addr);
}
@ -111,8 +104,6 @@ bool CPUHashTable<Key, Value>::Insert(const Key *keys, size_t key_num, const Val
template <typename Key, typename Value>
bool CPUHashTable<Key, Value>::Erase(const Key *keys, size_t key_num, void *) {
std::unique_lock<std::shared_mutex> lock(mutex_);
MS_EXCEPTION_IF_NULL(device_ctx_);
MS_EXCEPTION_IF_NULL(device_ctx_->device_res_manager_);
// Erase all the keys in the hash table.
for (size_t i = 0; i < key_num; ++i) {
@ -123,9 +114,10 @@ bool CPUHashTable<Key, Value>::Erase(const Key *keys, size_t key_num, void *) {
// Return the memory of value to the pool.
MS_EXCEPTION_IF_NULL(value);
device_ctx_->device_res_manager_->FreeMemory(value);
FreeMemory(value);
} else {
MS_LOG(ERROR) << "The key: " << key << " does not exist in the hash table.";
return false;
}
}
return true;
@ -214,20 +206,29 @@ bool CPUHashTable<Key, Value>::is_dirty() const {
template <typename Key, typename Value>
bool CPUHashTable<Key, Value>::Clear() {
std::unique_lock<std::shared_mutex> lock(mutex_);
MS_EXCEPTION_IF_NULL(device_ctx_);
MS_EXCEPTION_IF_NULL(device_ctx_->device_res_manager_);
// Return all the memory of values in hash table to the memory pool.
for (auto iter = values_.begin(); iter != values_.end(); iter++) {
auto key = iter->first;
auto value = iter->second;
if (value != nullptr) {
device_ctx_->device_res_manager_->FreeMemory(value);
FreeMemory(value);
}
values_.erase(key);
}
return true;
}
template <typename Key, typename Value>
void *CPUHashTable<Key, Value>::AllocateMemory(size_t size) const {
return CPUMemoryPool::GetInstance().AllocTensorMem(size, false);
}
template <typename Key, typename Value>
void CPUHashTable<Key, Value>::FreeMemory(void *ptr) const {
MS_EXCEPTION_IF_NULL(ptr);
CPUMemoryPool::GetInstance().FreeTensorMem(ptr);
}
template class CPUHashTable<int32_t, float>;
template class CPUHashTable<int64_t, float>;
} // namespace cpu

View File

@ -19,8 +19,8 @@
#include <shared_mutex>
#include <unordered_map>
#include "runtime/hardware/device_context.h"
#include "runtime/device/hash_table.h"
#include "plugin/device/cpu/hal/hardware/cpu_memory_pool.h"
namespace mindspore {
namespace device {
@ -64,6 +64,12 @@ class CPUHashTable : public HashTable<Key, Value> {
bool Clear() override;
private:
// Allocate host memory from dynamic memory pool.
void *AllocateMemory(size_t size) const;
// Free host memory to dynamic memory pool.
void FreeMemory(void *ptr) const;
// The key-value style elements stored in this hash table.
std::unordered_map<Key, Value *> values_;
@ -78,9 +84,6 @@ class CPUHashTable : public HashTable<Key, Value> {
// The flag records whether the elements of the hash table have changed since the last export, true means that there
// has been a change.
bool is_dirty_{true};
// The device context is used to allocate host side memory from the pool for values in the hash table.
DeviceContext *device_ctx_{nullptr};
};
} // namespace cpu
} // namespace device