!49519 Refine memory allocator for cpu hash table
Merge pull request !49519 from zyli2020/persistent_storage
This commit is contained in:
commit
2b8ad31223
|
@ -37,11 +37,6 @@ CPUHashTable<Key, Value>::~CPUHashTable() {
|
|||
|
||||
template <typename Key, typename Value>
|
||||
bool CPUHashTable<Key, Value>::Initialize() {
|
||||
const uint32_t device_id = 0;
|
||||
device_ctx_ = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kCPUDevice, device_id});
|
||||
MS_EXCEPTION_IF_NULL(device_ctx_);
|
||||
device_ctx_->Initialize();
|
||||
|
||||
value_size_ = value_dim_ * sizeof(Value);
|
||||
return true;
|
||||
}
|
||||
|
@ -81,15 +76,13 @@ bool CPUHashTable<Key, Value>::Find(const Key *keys, size_t key_num, bool, Value
|
|||
template <typename Key, typename Value>
|
||||
bool CPUHashTable<Key, Value>::Insert(const Key *keys, size_t key_num, const Value *value, void *) {
|
||||
std::unique_lock<std::shared_mutex> lock(mutex_);
|
||||
MS_EXCEPTION_IF_NULL(device_ctx_);
|
||||
MS_EXCEPTION_IF_NULL(device_ctx_->device_res_manager_);
|
||||
|
||||
for (size_t i = 0; i < key_num; ++i) {
|
||||
const auto &key = keys[i];
|
||||
|
||||
// The the key does not exist, a new value buffer should be allocated firstly.
|
||||
if (values_.find(key) == values_.end()) {
|
||||
auto value_addr = device_ctx_->device_res_manager_->AllocateMemory(value_size_);
|
||||
auto value_addr = AllocateMemory(value_size_);
|
||||
MS_EXCEPTION_IF_NULL(value_addr);
|
||||
values_[key] = reinterpret_cast<Value *>(value_addr);
|
||||
}
|
||||
|
@ -111,8 +104,6 @@ bool CPUHashTable<Key, Value>::Insert(const Key *keys, size_t key_num, const Val
|
|||
template <typename Key, typename Value>
|
||||
bool CPUHashTable<Key, Value>::Erase(const Key *keys, size_t key_num, void *) {
|
||||
std::unique_lock<std::shared_mutex> lock(mutex_);
|
||||
MS_EXCEPTION_IF_NULL(device_ctx_);
|
||||
MS_EXCEPTION_IF_NULL(device_ctx_->device_res_manager_);
|
||||
|
||||
// Erase all the keys in the hash table.
|
||||
for (size_t i = 0; i < key_num; ++i) {
|
||||
|
@ -123,9 +114,10 @@ bool CPUHashTable<Key, Value>::Erase(const Key *keys, size_t key_num, void *) {
|
|||
|
||||
// Return the memory of value to the pool.
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
device_ctx_->device_res_manager_->FreeMemory(value);
|
||||
FreeMemory(value);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The key: " << key << " does not exist in the hash table.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
@ -214,20 +206,29 @@ bool CPUHashTable<Key, Value>::is_dirty() const {
|
|||
template <typename Key, typename Value>
|
||||
bool CPUHashTable<Key, Value>::Clear() {
|
||||
std::unique_lock<std::shared_mutex> lock(mutex_);
|
||||
MS_EXCEPTION_IF_NULL(device_ctx_);
|
||||
MS_EXCEPTION_IF_NULL(device_ctx_->device_res_manager_);
|
||||
// Return all the memory of values in hash table to the memory pool.
|
||||
for (auto iter = values_.begin(); iter != values_.end(); iter++) {
|
||||
auto key = iter->first;
|
||||
auto value = iter->second;
|
||||
if (value != nullptr) {
|
||||
device_ctx_->device_res_manager_->FreeMemory(value);
|
||||
FreeMemory(value);
|
||||
}
|
||||
values_.erase(key);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Key, typename Value>
|
||||
void *CPUHashTable<Key, Value>::AllocateMemory(size_t size) const {
|
||||
return CPUMemoryPool::GetInstance().AllocTensorMem(size, false);
|
||||
}
|
||||
|
||||
template <typename Key, typename Value>
|
||||
void CPUHashTable<Key, Value>::FreeMemory(void *ptr) const {
|
||||
MS_EXCEPTION_IF_NULL(ptr);
|
||||
CPUMemoryPool::GetInstance().FreeTensorMem(ptr);
|
||||
}
|
||||
|
||||
template class CPUHashTable<int32_t, float>;
|
||||
template class CPUHashTable<int64_t, float>;
|
||||
} // namespace cpu
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
#include <shared_mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/device/hash_table.h"
|
||||
#include "plugin/device/cpu/hal/hardware/cpu_memory_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -64,6 +64,12 @@ class CPUHashTable : public HashTable<Key, Value> {
|
|||
bool Clear() override;
|
||||
|
||||
private:
|
||||
// Allocate host memory from dynamic memory pool.
|
||||
void *AllocateMemory(size_t size) const;
|
||||
|
||||
// Free host memory to dynamic memory pool.
|
||||
void FreeMemory(void *ptr) const;
|
||||
|
||||
// The key-value style elements stored in this hash table.
|
||||
std::unordered_map<Key, Value *> values_;
|
||||
|
||||
|
@ -78,9 +84,6 @@ class CPUHashTable : public HashTable<Key, Value> {
|
|||
// The flag records whether the elements of the hash table have changed since the last export, true means that there
|
||||
// has been a change.
|
||||
bool is_dirty_{true};
|
||||
|
||||
// The device context is used to allocate host side memory from the pool for values in the hash table.
|
||||
DeviceContext *device_ctx_{nullptr};
|
||||
};
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
|
|
Loading…
Reference in New Issue