From 798cca42e5cb24c3cb2584e675bb24dc98d10884 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Wed, 30 Dec 2020 20:56:51 +0800 Subject: [PATCH] ps cache code review --- .../ps/ps_cache/ascend/ascend_ps_cache.cc | 14 ++++++++------ .../ps/ps_cache/ascend/ascend_ps_cache.h | 4 ++-- .../ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h | 4 ++-- mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h | 4 ++-- .../ccsrc/ps/ps_cache/ps_cache_manager.cc | 19 +++++++++++-------- 5 files changed, 25 insertions(+), 20 deletions(-) diff --git a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc index f91d8cb4074..0eb8c02cf0a 100644 --- a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc +++ b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc @@ -206,7 +206,7 @@ bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { } bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, - size_t hash_table_size, size_t embedding_size, size_t swap_out_size) { + size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) { MS_ERROR_IF_NULL(hash_table_addr); MS_ERROR_IF_NULL(swap_out_value_addr); MS_ERROR_IF_NULL(swap_out_index_addr); @@ -217,7 +217,7 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr std::vector> output_shape; std::vector input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}; std::vector output_type = {TypeId::kNumberTypeFloat32}; - input_shape.push_back({hash_table_size, embedding_size}); + input_shape.push_back({cache_vocab_size, embedding_size}); input_shape.push_back({swap_out_size}); input_shape.push_back({1}); output_shape.push_back({swap_out_size, embedding_size}); @@ -229,7 +229,8 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr AddressPtrList kernel_outputs = { std::make_shared
(swap_out_value_addr, swap_out_size * embedding_size * sizeof(float))}; AddressPtrList kernel_workspaces; - kernel_inputs.push_back(std::make_shared
(hash_table_addr, hash_table_size * embedding_size * sizeof(float))); + kernel_inputs.push_back( + std::make_shared
(hash_table_addr, cache_vocab_size * embedding_size * sizeof(float))); kernel_inputs.push_back(std::make_shared
(swap_out_index_addr, swap_out_size * sizeof(int))); kernel_inputs.push_back(std::make_shared
(offset_addr_, sizeof(int))); auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); @@ -241,7 +242,7 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr } bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, - size_t hash_table_size, size_t embedding_size, size_t swap_in_size) { + size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) { MS_ERROR_IF_NULL(hash_table_addr); MS_ERROR_IF_NULL(swap_in_value_addr); MS_ERROR_IF_NULL(swap_in_index_addr); @@ -253,7 +254,7 @@ bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, std::vector input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32}; std::vector output_type = {TypeId::kNumberTypeInt32}; - input_shape.push_back({hash_table_size, embedding_size}); + input_shape.push_back({cache_vocab_size, embedding_size}); input_shape.push_back({swap_in_size}); input_shape.push_back({swap_in_size, embedding_size}); input_shape.push_back({1}); @@ -265,7 +266,8 @@ bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, AddressPtrList kernel_inputs; AddressPtrList kernel_outputs; AddressPtrList kernel_workspaces; - kernel_inputs.push_back(std::make_shared
(hash_table_addr, hash_table_size * embedding_size * sizeof(float))); + kernel_inputs.push_back( + std::make_shared
(hash_table_addr, cache_vocab_size * embedding_size * sizeof(float))); kernel_inputs.push_back(std::make_shared
(swap_in_index_addr, swap_in_size * sizeof(int))); kernel_inputs.push_back(std::make_shared
(swap_in_value_addr, swap_in_size * embedding_size * sizeof(float))); kernel_inputs.push_back(std::make_shared
(cache_vocab_size_addr_, sizeof(int))); diff --git a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h index db4bec840b9..4dc07bf3bba 100644 --- a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h +++ b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h @@ -57,9 +57,9 @@ class AscendPsCache : public PsCacheBasic { bool SynchronizeStream() override; bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; - bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, + bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) override; - bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, + bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) override; private: diff --git a/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h index a0bfbd951f2..45678d32f55 100644 --- a/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h +++ b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h @@ -35,9 +35,9 @@ class GPUPsCache : public PsCacheBasic { bool SynchronizeStream() override; bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; - bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, + bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) override; - bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, + bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) override; private: diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h index 33713bb1085..b7a98b2914a 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h @@ -41,9 +41,9 @@ class PsCacheBasic { virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, - size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0; + size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) = 0; virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, - size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0; + size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) = 0; protected: void *stream_; diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index 050065b211c..85d070daf36 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -660,7 +660,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { } auto embedding_size = hash_info.embedding_size; auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); - auto hash_table_size = hash_info.device_address.size; + auto cache_vocab_size = hash_info.cache_vocab_size; auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); auto swap_out_data = std::make_unique(swap_indices_size * embedding_size); RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, @@ -673,7 +673,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { swap_indices_size * sizeof(int))); RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, - hash_table_size, embedding_size, swap_indices_size)); + cache_vocab_size, embedding_size, swap_indices_size)); RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); return true; } @@ -689,7 +689,7 @@ bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { return true; } auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); - auto hash_table_size = hash_info.device_address.size; + auto cache_vocab_size = hash_info.cache_vocab_size; auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); auto embedding_size = hash_info.embedding_size; auto swap_out_data = std::make_unique(swap_indices_size * embedding_size); @@ -698,7 +698,7 @@ bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { swap_indices_size * sizeof(int))); RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut( hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, - hash_table_size, embedding_size, swap_indices_size)); + cache_vocab_size, embedding_size, swap_indices_size)); RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost( swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_, swap_indices_size * embedding_size * sizeof(float))); @@ -770,14 +770,14 @@ bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray return true; } auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); - auto hash_table_size = hash_info.device_address.size; + auto cache_vocab_size = hash_info.cache_vocab_size; auto embedding_size = hash_info.embedding_size; swap_out_data->resize(swap_out_index_size * embedding_size); RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int))); RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut( hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, - hash_table_size, embedding_size, swap_out_index_size)); + cache_vocab_size, embedding_size, swap_out_index_size)); RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost( swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_, swap_out_index_size * embedding_size * sizeof(float))); @@ -796,7 +796,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons return true; } auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); - auto hash_table_size = hash_info.device_address.size; + auto cache_vocab_size = hash_info.cache_vocab_size; auto embedding_size = hash_info.embedding_size; // Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device). ::ps::SArray lengths{swap_in_ids_size}; @@ -817,7 +817,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons swap_in_index, swap_in_ids_size * sizeof(int))); RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, - hash_table_size, embedding_size, swap_in_ids_size)); + cache_vocab_size, embedding_size, swap_in_ids_size)); return true; } @@ -846,6 +846,9 @@ void PsCacheManager::SyncEmbeddingTable() { if (finish_embedding_table_sync_) { return; } + if (!initialized_ps_cache_) { + return; + } if (!SyncHostEmbeddingTable()) { MS_LOG(ERROR) << "SyncHostEmbeddingTable failed."; }