From 600f95553a897155147e762be4453bab216b690b Mon Sep 17 00:00:00 2001
From: gaoyong10
Date: Fri, 8 Jan 2021 10:19:27 +0800
Subject: [PATCH] ps cache parse parallel

---
Note: this revision differs from commit 600f9555: CheckIDInDevice now reports worker
failures and ParseData checks its result; the index hashes below are the original ones.

 .../ccsrc/ps/ps_cache/ps_cache_manager.cc     | 97 ++++++++++++++++++-
 .../ccsrc/ps/ps_cache/ps_cache_manager.h      |  5 +-
 2 files changed, 97 insertions(+), 5 deletions(-)

diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
index 9d040295c54..415f51c0291 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
@@ -394,11 +394,100 @@ bool PsCacheManager::ProcessData() {
   return true;
 }
 
+// Worker run by CheckIDInDevice on one contiguous slice of the batch. For every id found in the device hash map
+// it stores the mapped index in hash_index[i] and sets in_device[i]; entries of ids that are not found are left
+// untouched. *hash_hit_count is incremented for each found index whose recorded step differs from data_step_,
+// and that index is then stamped with data_step_. All pointers are guarded with MS_ERROR_IF_NULL (presumably
+// returning false on null - confirm against the macro); otherwise the function returns true.
+// NOTE(review): slices run concurrently on the shared device_hash_map, so an id repeated in two slices calls
+// hash_step()/set_hash_step() for the same index from two threads; confirm the hash map tolerates this.
+bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
+                                         bool *in_device, size_t *hash_hit_count) {
+  MS_ERROR_IF_NULL(batch_ids);
+  MS_ERROR_IF_NULL(hash_index);
+  MS_ERROR_IF_NULL(in_device);
+  MS_ERROR_IF_NULL(hash_hit_count);
+  MS_ERROR_IF_NULL(embedding_device_cache_);
+  auto &device_hash_map = embedding_device_cache_->device_hash_map_;
+  MS_ERROR_IF_NULL(device_hash_map);
+  const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
+
+  for (size_t i = 0; i < batch_ids_len; ++i) {
+    auto iter = hash_id_to_index.find(batch_ids[i]);
+    if (iter != hash_id_to_index.end()) {
+      hash_index[i] = iter->second;
+      if (device_hash_map->hash_step(iter->second) != data_step_) {
+        ++(*hash_hit_count);
+        device_hash_map->set_hash_step(iter->second, data_step_);
+      }
+      in_device[i] = true;
+    }
+  }
+  return true;
+}
+
+// Splits the batch into at most kMaxThreadNum contiguous slices (one thread per kMinIdsPerThread ids, plus
+// one), runs CheckIDInDeviceTask on each slice in its own thread, joins them and adds the per-thread hit counts
+// to statistics_info_.hash_hit_count_. Returns false if any worker returned false.
+bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
+                                     bool *in_device) {
+  MS_ERROR_IF_NULL(batch_ids);
+  MS_ERROR_IF_NULL(hash_index);
+  MS_ERROR_IF_NULL(in_device);
+
+  size_t thread_num = batch_ids_len / kMinIdsPerThread + 1;
+  thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
+  std::thread threads[kMaxThreadNum];
+  size_t hash_hit_count[kMaxThreadNum] = {0};
+  bool task_succeeded[kMaxThreadNum] = {false};
+  size_t i = 0;
+  size_t task_offset = 0;
+
+  for (; i < thread_num; ++i) {
+    if (task_offset >= batch_ids_len) {
+      break;
+    }
+    // Spread the remainder over the first (batch_ids_len % thread_num) slices so the sizes differ by at most one.
+    size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0);
+    // std::thread discards the callable's return value, so each worker records its result in its own slot.
+    threads[i] = std::thread([this, batch_ids, hash_index, in_device, task_offset, task_proc_lens,
+                              hit_count = hash_hit_count + i, succeeded = task_succeeded + i]() {
+      *succeeded = CheckIDInDeviceTask(batch_ids + task_offset, task_proc_lens, hash_index + task_offset,
+                                       in_device + task_offset, hit_count);
+    });
+    task_offset += task_proc_lens;
+  }
+  if (task_offset != batch_ids_len) {
+    MS_LOG(WARNING) << "Ps cache check id in device inadequate, total:" << batch_ids_len << " checked:" << task_offset;
+  }
+
+  for (size_t j = 0; j < i; j++) {
+    threads[j].join();
+  }
+  bool all_succeeded = true;
+  for (size_t j = 0; j < i; j++) {
+    statistics_info_.hash_hit_count_ += hash_hit_count[j];
+    all_succeeded = all_succeeded && task_succeeded[j];
+  }
+  return all_succeeded;
+}
+
 bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
   MS_ERROR_IF_NULL(batch_ids);
   MS_ERROR_IF_NULL(hash_index);
   statistics_info_.batch_id_count_ = batch_ids_len;
+  // One flag per id; CheckIDInDevice sets it for ids already cached on the device so the loop below skips them.
+  std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]);
+  if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
+    MS_LOG(EXCEPTION) << "Data in device memset failed.";
+  }
+  if (!CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get())) {
+    return false;
+  }
   for (size_t i = 0; i < batch_ids_len; i++) {
+    if (in_device[i]) {
+      continue;
+    }
     bool need_swap_host_to_device = true;
     bool need_swap_device_to_host = true;
     auto id = batch_ids[i];
@@ -585,10 +674,10 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
   std::thread threads[kMaxThreadNum];
   size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num;
-  size_t i;
+  size_t i = 0;
   size_t task_offset = 0;
   MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens;
-  for (i = 0; i < thread_num; i++) {
+  for (; i < thread_num; i++) {
     if (task_offset >= indices_lens) {
       break;
     }
@@ -613,7 +702,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
   thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
   std::thread threads[kMaxThreadNum];
   size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num;
-  size_t i;
+  size_t i = 0;
   size_t task_offset = 0;
 
   auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
@@ -632,7 +721,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
     }
   };
 
-  for (i = 0; i < thread_num; i++) {
+  for (; i < thread_num; i++) {
     if (task_offset >= insert_indices_size) {
       break;
     }
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
index 54dfa217f93..851f2b6c572 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
@@ -40,6 +40,7 @@ namespace mindspore {
 namespace ps {
 constexpr size_t kHostCacheScaleFactor = 10;
 constexpr size_t kMaxThreadNum = 16;
+constexpr size_t kMinIdsPerThread = 10000;
 using mindspore::kernel::Address;
 
 struct HashTableInfo {
@@ -169,7 +170,9 @@ class PsCacheManager {
   void DumpStatisticsInfo(size_t each_print_step = 1000);
   bool SyncHostEmbeddingTable();
   bool SyncDeviceEmbeddingTable();
-
+  bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
+                           size_t *hash_hit_count);
+  bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device);
   bool initialized_ps_cache_{false};
   std::string channel_name_;
   std::mutex channel_mutex_;