!11293 PS Cache ParseData Parallel

From: @gaoyong10
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-18 19:51:11 +08:00 committed by Gitee
commit 16d19f2d26
2 changed files with 76 additions and 5 deletions

View File

@ -394,11 +394,79 @@ bool PsCacheManager::ProcessData() {
return true;
}
bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
bool *in_device, size_t *hash_hit_count) {
MS_ERROR_IF_NULL(batch_ids);
MS_ERROR_IF_NULL(hash_index);
MS_ERROR_IF_NULL(in_device);
MS_ERROR_IF_NULL(hash_hit_count);
MS_ERROR_IF_NULL(embedding_device_cache_);
auto &device_hash_map = embedding_device_cache_->device_hash_map_;
MS_ERROR_IF_NULL(device_hash_map);
const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
for (size_t i = 0; i < batch_ids_len; ++i) {
auto iter = hash_id_to_index.find(batch_ids[i]);
if (iter != hash_id_to_index.end()) {
hash_index[i] = iter->second;
if (device_hash_map->hash_step(iter->second) != data_step_) {
++(*hash_hit_count);
device_hash_map->set_hash_step(iter->second, data_step_);
}
in_device[i] = true;
}
}
return true;
}
bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
bool *in_device) {
MS_ERROR_IF_NULL(batch_ids);
MS_ERROR_IF_NULL(hash_index);
MS_ERROR_IF_NULL(in_device);
size_t thread_num = batch_ids_len / kMinIdsPerThread + 1;
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
std::thread threads[kMaxThreadNum];
size_t hash_hit_count[kMaxThreadNum] = {0};
size_t i = 0;
size_t task_offset = 0;
for (; i < thread_num; ++i) {
if (task_offset >= batch_ids_len) {
break;
}
size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0);
threads[i] = std::thread(&PsCacheManager::CheckIDInDeviceTask, this, batch_ids + task_offset, task_proc_lens,
hash_index + task_offset, in_device + task_offset, hash_hit_count + i);
task_offset += task_proc_lens;
}
if (task_offset != batch_ids_len) {
MS_LOG(WARNING) << "Ps cache check id in device inadequate, total:" << batch_ids_len << " checked:" << task_offset;
}
for (size_t j = 0; j < i; j++) {
threads[j].join();
}
for (size_t j = 0; j < i; j++) {
statistics_info_.hash_hit_count_ += hash_hit_count[j];
}
return true;
}
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
MS_ERROR_IF_NULL(batch_ids);
MS_ERROR_IF_NULL(hash_index);
statistics_info_.batch_id_count_ = batch_ids_len;
std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]);
if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
MS_LOG(EXCEPTION) << "Data in device memset failed.";
}
CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get());
for (size_t i = 0; i < batch_ids_len; i++) {
if (in_device[i]) {
continue;
}
bool need_swap_host_to_device = true;
bool need_swap_device_to_host = true;
auto id = batch_ids[i];
@ -585,10 +653,10 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
std::thread threads[kMaxThreadNum];
size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num;
size_t i;
size_t i = 0;
size_t task_offset = 0;
MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens;
for (i = 0; i < thread_num; i++) {
for (; i < thread_num; i++) {
if (task_offset >= indices_lens) {
break;
}
@ -613,7 +681,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
std::thread threads[kMaxThreadNum];
size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num;
size_t i;
size_t i = 0;
size_t task_offset = 0;
auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
@ -632,7 +700,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
}
};
for (i = 0; i < thread_num; i++) {
for (; i < thread_num; i++) {
if (task_offset >= insert_indices_size) {
break;
}

View File

@ -40,6 +40,7 @@ namespace mindspore {
namespace ps {
constexpr size_t kHostCacheScaleFactor = 10;
constexpr size_t kMaxThreadNum = 16;
constexpr size_t kMinIdsPerThread = 10000;
using mindspore::kernel::Address;
struct HashTableInfo {
@ -169,7 +170,9 @@ class PsCacheManager {
void DumpStatisticsInfo(size_t each_print_step = 1000);
bool SyncHostEmbeddingTable();
bool SyncDeviceEmbeddingTable();
bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
size_t *hash_hit_count);
bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device);
bool initialized_ps_cache_{false};
std::string channel_name_;
std::mutex channel_mutex_;