forked from mindspore-Ecosystem/mindspore
!11293 PS Cache ParseData Parallel
From: @gaoyong10 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
16d19f2d26
|
@ -394,11 +394,79 @@ bool PsCacheManager::ProcessData() {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
|
||||
bool *in_device, size_t *hash_hit_count) {
|
||||
MS_ERROR_IF_NULL(batch_ids);
|
||||
MS_ERROR_IF_NULL(hash_index);
|
||||
MS_ERROR_IF_NULL(in_device);
|
||||
MS_ERROR_IF_NULL(hash_hit_count);
|
||||
MS_ERROR_IF_NULL(embedding_device_cache_);
|
||||
auto &device_hash_map = embedding_device_cache_->device_hash_map_;
|
||||
MS_ERROR_IF_NULL(device_hash_map);
|
||||
const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
|
||||
|
||||
for (size_t i = 0; i < batch_ids_len; ++i) {
|
||||
auto iter = hash_id_to_index.find(batch_ids[i]);
|
||||
if (iter != hash_id_to_index.end()) {
|
||||
hash_index[i] = iter->second;
|
||||
if (device_hash_map->hash_step(iter->second) != data_step_) {
|
||||
++(*hash_hit_count);
|
||||
device_hash_map->set_hash_step(iter->second, data_step_);
|
||||
}
|
||||
in_device[i] = true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
|
||||
bool *in_device) {
|
||||
MS_ERROR_IF_NULL(batch_ids);
|
||||
MS_ERROR_IF_NULL(hash_index);
|
||||
MS_ERROR_IF_NULL(in_device);
|
||||
|
||||
size_t thread_num = batch_ids_len / kMinIdsPerThread + 1;
|
||||
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
||||
std::thread threads[kMaxThreadNum];
|
||||
size_t hash_hit_count[kMaxThreadNum] = {0};
|
||||
size_t i = 0;
|
||||
size_t task_offset = 0;
|
||||
|
||||
for (; i < thread_num; ++i) {
|
||||
if (task_offset >= batch_ids_len) {
|
||||
break;
|
||||
}
|
||||
size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0);
|
||||
threads[i] = std::thread(&PsCacheManager::CheckIDInDeviceTask, this, batch_ids + task_offset, task_proc_lens,
|
||||
hash_index + task_offset, in_device + task_offset, hash_hit_count + i);
|
||||
task_offset += task_proc_lens;
|
||||
}
|
||||
if (task_offset != batch_ids_len) {
|
||||
MS_LOG(WARNING) << "Ps cache check id in device inadequate, total:" << batch_ids_len << " checked:" << task_offset;
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < i; j++) {
|
||||
threads[j].join();
|
||||
}
|
||||
for (size_t j = 0; j < i; j++) {
|
||||
statistics_info_.hash_hit_count_ += hash_hit_count[j];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
|
||||
MS_ERROR_IF_NULL(batch_ids);
|
||||
MS_ERROR_IF_NULL(hash_index);
|
||||
statistics_info_.batch_id_count_ = batch_ids_len;
|
||||
std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]);
|
||||
if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
|
||||
MS_LOG(EXCEPTION) << "Data in device memset failed.";
|
||||
}
|
||||
CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get());
|
||||
for (size_t i = 0; i < batch_ids_len; i++) {
|
||||
if (in_device[i]) {
|
||||
continue;
|
||||
}
|
||||
bool need_swap_host_to_device = true;
|
||||
bool need_swap_device_to_host = true;
|
||||
auto id = batch_ids[i];
|
||||
|
@ -585,10 +653,10 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
|
|||
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
||||
std::thread threads[kMaxThreadNum];
|
||||
size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num;
|
||||
size_t i;
|
||||
size_t i = 0;
|
||||
size_t task_offset = 0;
|
||||
MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens;
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
for (; i < thread_num; i++) {
|
||||
if (task_offset >= indices_lens) {
|
||||
break;
|
||||
}
|
||||
|
@ -613,7 +681,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
|
|||
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
||||
std::thread threads[kMaxThreadNum];
|
||||
size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num;
|
||||
size_t i;
|
||||
size_t i = 0;
|
||||
size_t task_offset = 0;
|
||||
|
||||
auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
|
||||
|
@ -632,7 +700,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
|
|||
}
|
||||
};
|
||||
|
||||
for (i = 0; i < thread_num; i++) {
|
||||
for (; i < thread_num; i++) {
|
||||
if (task_offset >= insert_indices_size) {
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -40,6 +40,7 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
constexpr size_t kHostCacheScaleFactor = 10;
|
||||
constexpr size_t kMaxThreadNum = 16;
|
||||
constexpr size_t kMinIdsPerThread = 10000;
|
||||
using mindspore::kernel::Address;
|
||||
|
||||
struct HashTableInfo {
|
||||
|
@ -169,7 +170,9 @@ class PsCacheManager {
|
|||
void DumpStatisticsInfo(size_t each_print_step = 1000);
|
||||
bool SyncHostEmbeddingTable();
|
||||
bool SyncDeviceEmbeddingTable();
|
||||
|
||||
bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
|
||||
size_t *hash_hit_count);
|
||||
bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device);
|
||||
bool initialized_ps_cache_{false};
|
||||
std::string channel_name_;
|
||||
std::mutex channel_mutex_;
|
||||
|
|
Loading…
Reference in New Issue