forked from mindspore-Ecosystem/mindspore
!10991 ps cache performance refine
From: @zyli2020 Reviewed-by: @cristoval,@limingqi107 Signed-off-by: @limingqi107
This commit is contained in:
commit
16a7265c91
mindspore/ccsrc/ps/ps_cache
|
@ -47,10 +47,6 @@ class EmbeddingHashMap {
|
|||
virtual ~EmbeddingHashMap() = default;
|
||||
int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step,
|
||||
const size_t graph_running_step, size_t *swap_out_size);
|
||||
std::unordered_map<int, int>::const_iterator id_iter(const int id) const { return hash_id_to_index_.find(id); }
|
||||
bool IsIdExist(const std::unordered_map<int, int>::const_iterator iter) const {
|
||||
return iter != hash_id_to_index_.end();
|
||||
}
|
||||
size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; }
|
||||
void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); }
|
||||
const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; }
|
||||
|
|
|
@ -437,12 +437,13 @@ bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host,
|
|||
MS_ERROR_IF_NULL(need_swap_host_to_device);
|
||||
MS_ERROR_IF_NULL(hash_index);
|
||||
MS_ERROR_IF_NULL(embedding_device_cache_);
|
||||
auto device_hash_map = embedding_device_cache_->device_hash_map_;
|
||||
auto &device_hash_map = embedding_device_cache_->device_hash_map_;
|
||||
MS_ERROR_IF_NULL(device_hash_map);
|
||||
|
||||
int index = INVALID_INDEX_VALUE;
|
||||
auto iter = device_hash_map->id_iter(id);
|
||||
if (device_hash_map->IsIdExist(iter)) {
|
||||
const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
|
||||
const auto &iter = hash_id_to_index.find(id);
|
||||
if (iter != hash_id_to_index.end()) {
|
||||
*need_swap_device_to_host = false;
|
||||
*need_swap_host_to_device = false;
|
||||
index = iter->second;
|
||||
|
@ -482,11 +483,12 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
|
|||
MS_ERROR_IF_NULL(embedding_host_cache_);
|
||||
int *host_to_device_index = embedding_host_cache_->host_to_device_index.get();
|
||||
MS_ERROR_IF_NULL(host_to_device_index);
|
||||
auto host_hash_map = embedding_host_cache_->host_hash_map_;
|
||||
auto &host_hash_map = embedding_host_cache_->host_hash_map_;
|
||||
MS_ERROR_IF_NULL(host_hash_map);
|
||||
|
||||
auto iter = host_hash_map->id_iter(id);
|
||||
if (host_hash_map->IsIdExist(iter)) {
|
||||
const auto &hash_id_to_index = host_hash_map->hash_id_to_index();
|
||||
const auto &iter = hash_id_to_index.find(id);
|
||||
if (iter != hash_id_to_index.end()) {
|
||||
auto index = iter->second;
|
||||
if (host_hash_map->hash_step(index) != data_step_) {
|
||||
host_hash_map->set_hash_step(index, data_step_);
|
||||
|
@ -522,11 +524,12 @@ bool PsCacheManager::ParseHostDataDeviceToHost() {
|
|||
MS_ERROR_IF_NULL(device_to_host_ids);
|
||||
MS_ERROR_IF_NULL(device_to_host_index);
|
||||
|
||||
auto host_hash_map = embedding_host_cache_->host_hash_map_;
|
||||
auto &host_hash_map = embedding_host_cache_->host_hash_map_;
|
||||
MS_ERROR_IF_NULL(host_hash_map);
|
||||
int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1];
|
||||
auto iter = host_hash_map->id_iter(swap_device_to_host_id);
|
||||
if (host_hash_map->IsIdExist(iter)) {
|
||||
const auto &hash_id_to_index = host_hash_map->hash_id_to_index();
|
||||
const auto &iter = hash_id_to_index.find(swap_device_to_host_id);
|
||||
if (iter != hash_id_to_index.end()) {
|
||||
auto index = iter->second;
|
||||
if (host_hash_map->hash_step(index) != data_step_) {
|
||||
host_hash_map->set_hash_step(index, data_step_);
|
||||
|
|
Loading…
Reference in New Issue