ps cache performance refine

From: @zyli2020
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
This commit is contained in:
mindspore-ci-bot 2021-01-07 11:58:53 +08:00 committed by Gitee
commit 16a7265c91
2 changed files with 12 additions and 13 deletions
mindspore/ccsrc/ps/ps_cache

View File

@ -47,10 +47,6 @@ class EmbeddingHashMap {
virtual ~EmbeddingHashMap() = default;
int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step,
const size_t graph_running_step, size_t *swap_out_size);
std::unordered_map<int, int>::const_iterator id_iter(const int id) const { return hash_id_to_index_.find(id); }
bool IsIdExist(const std::unordered_map<int, int>::const_iterator iter) const {
return iter != hash_id_to_index_.end();
}
size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; }
void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); }
const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; }

View File

@ -437,12 +437,13 @@ bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host,
MS_ERROR_IF_NULL(need_swap_host_to_device);
MS_ERROR_IF_NULL(hash_index);
MS_ERROR_IF_NULL(embedding_device_cache_);
auto device_hash_map = embedding_device_cache_->device_hash_map_;
auto &device_hash_map = embedding_device_cache_->device_hash_map_;
MS_ERROR_IF_NULL(device_hash_map);
int index = INVALID_INDEX_VALUE;
auto iter = device_hash_map->id_iter(id);
if (device_hash_map->IsIdExist(iter)) {
const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
const auto &iter = hash_id_to_index.find(id);
if (iter != hash_id_to_index.end()) {
*need_swap_device_to_host = false;
*need_swap_host_to_device = false;
index = iter->second;
@ -482,11 +483,12 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
MS_ERROR_IF_NULL(embedding_host_cache_);
int *host_to_device_index = embedding_host_cache_->host_to_device_index.get();
MS_ERROR_IF_NULL(host_to_device_index);
auto host_hash_map = embedding_host_cache_->host_hash_map_;
auto &host_hash_map = embedding_host_cache_->host_hash_map_;
MS_ERROR_IF_NULL(host_hash_map);
auto iter = host_hash_map->id_iter(id);
if (host_hash_map->IsIdExist(iter)) {
const auto &hash_id_to_index = host_hash_map->hash_id_to_index();
const auto &iter = hash_id_to_index.find(id);
if (iter != hash_id_to_index.end()) {
auto index = iter->second;
if (host_hash_map->hash_step(index) != data_step_) {
host_hash_map->set_hash_step(index, data_step_);
@ -522,11 +524,12 @@ bool PsCacheManager::ParseHostDataDeviceToHost() {
MS_ERROR_IF_NULL(device_to_host_ids);
MS_ERROR_IF_NULL(device_to_host_index);
auto host_hash_map = embedding_host_cache_->host_hash_map_;
auto &host_hash_map = embedding_host_cache_->host_hash_map_;
MS_ERROR_IF_NULL(host_hash_map);
int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1];
auto iter = host_hash_map->id_iter(swap_device_to_host_id);
if (host_hash_map->IsIdExist(iter)) {
const auto &hash_id_to_index = host_hash_map->hash_id_to_index();
const auto &iter = hash_id_to_index.find(swap_device_to_host_id);
if (iter != hash_id_to_index.end()) {
auto index = iter->second;
if (host_hash_map->hash_step(index) != data_step_) {
host_hash_map->set_hash_step(index, data_step_);