diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index c995b92aee5..fea180b7b43 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -1085,10 +1085,10 @@ void ClearResAtexit() { session::ClearPythonParasMap(); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) { - ps::worker.Finalize(); if (ps::PsDataPrefetch::GetInstance().cache_enable()) { ps::ps_cache_instance.Finalize(); } + ps::worker.Finalize(); } #endif ad::g_k_prims.clear(); diff --git a/mindspore/ccsrc/ps/parameter_server.h b/mindspore/ccsrc/ps/parameter_server.h index 3b39c56bdf0..7d7fb8c8c59 100644 --- a/mindspore/ccsrc/ps/parameter_server.h +++ b/mindspore/ccsrc/ps/parameter_server.h @@ -552,7 +552,6 @@ template void ParameterServer::Finalize() { running_ = false; apply_grads_cv_.notify_one(); - SyncEmbeddingTables(); } template @@ -774,7 +773,7 @@ void ParameterServer::GetEmbeddingTableParamPtr() { for (auto cnode : cnodes) { MS_EXCEPTION_IF_NULL(cnode); std::string cnode_name = AnfAlgo::GetCNodeName(cnode); - if (cnode_name == kEmbeddingLookupOpName) { + if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName) { auto embedding_table = AnfAlgo::GetInputNode(cnode, 0); MS_EXCEPTION_IF_NULL(embedding_table); MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count; @@ -832,6 +831,7 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { Init(func_graph); PSContext::instance()->SetPSRankId(rank_id_); thread_->join(); + SyncEmbeddingTables(); MS_LOG(INFO) << "PServer finished updating models, starts finalizing..."; ::ps::Finalize(0, true); MS_LOG(INFO) << "PServer finalized successfully."; diff --git a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc index 00c020ef958..4f5c298b908 100755 --- a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc +++ b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc @@ -30,21 +30,21 @@ int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out if (loop++ == hash_capacity_) { return INVALID_INDEX_VALUE; } - if (hash_map_unit_[hash_index].IsEmpty()) { + if (hash_map_elements_[hash_index].IsEmpty()) { hash_count_++; (void)hash_id_to_index_.emplace(id, hash_index); - hash_map_unit_[hash_index].set_id(id); - hash_map_unit_[hash_index].set_step(data_step); + hash_map_elements_[hash_index].set_id(id); + hash_map_elements_[hash_index].set_step(data_step); return hash_index; - } else if (need_swap && hash_map_unit_[hash_index].IsExpired(graph_running_step)) { + } else if (need_swap && hash_map_elements_[hash_index].IsExpired(graph_running_step)) { // Need swap out from the hash table. swap_out_index[*swap_out_size] = hash_index; - swap_out_ids[*swap_out_size] = hash_map_unit_[hash_index].id_; + swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_; (*swap_out_size)++; - (void)hash_id_to_index_.erase(hash_map_unit_[hash_index].id_); + (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_); (void)hash_id_to_index_.emplace(id, hash_index); - hash_map_unit_[hash_index].set_id(id); - hash_map_unit_[hash_index].set_step(data_step); + hash_map_elements_[hash_index].set_id(id); + hash_map_elements_[hash_index].set_step(data_step); return hash_index; } hash_index = (hash_index + 1) % hash_capacity_; @@ -58,9 +58,10 @@ void EmbeddingHashMap::DumpHashMap() { MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second; } MS_LOG(INFO) << "Dump hash_map_unit: "; - for (size_t i = 0; i < hash_map_unit_.size(); i++) { - if (!hash_map_unit_[i].IsEmpty()) { - MS_LOG(INFO) << " index: " << i << " id: " << hash_map_unit_[i].id_ << " step: " << hash_map_unit_[i].step_; + for (size_t i = 0; i < hash_map_elements_.size(); i++) { + if (!hash_map_elements_[i].IsEmpty()) { + MS_LOG(INFO) << " index: " << i << " id: " << hash_map_elements_[i].id_ + << " step: " << hash_map_elements_[i].step_; } } MS_LOG(INFO) << "Dump hash map info end."; diff --git a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h index 8df58c09279..5950d3c937c 100644 --- a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h +++ b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h @@ -30,8 +30,8 @@ static const size_t INVALID_STEP_VALUE = 0; static const int INVALID_INDEX_VALUE = -1; struct HashMapElement { - int id_; - size_t step_; + int id_{INVALID_INDEX_VALUE}; + size_t step_{INVALID_STEP_VALUE}; bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } void set_id(int id) { id_ = id; } @@ -42,7 +42,7 @@ struct HashMapElement { class EmbeddingHashMap { public: EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) { - hash_map_unit_.resize(hash_capacity); + hash_map_elements_.resize(hash_capacity); } virtual ~EmbeddingHashMap() = default; int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, @@ -51,8 +51,10 @@ class EmbeddingHashMap { bool IsIdExist(const std::unordered_map::const_iterator iter) const { return iter != hash_id_to_index_.end(); } - size_t hash_step(const int hash_index) const { return hash_map_unit_[hash_index].step_; } - void set_hash_step(const int hash_index, const size_t step) { hash_map_unit_[hash_index].set_step(step); } + size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; } + void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); } + const std::unordered_map &hash_id_to_index() const { return hash_id_to_index_; } + size_t hash_capacity() const { return hash_capacity_; } void DumpHashMap(); private: @@ -60,7 +62,7 @@ class EmbeddingHashMap { bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); } size_t hash_count_; size_t hash_capacity_; - std::vector hash_map_unit_; + std::vector hash_map_elements_; std::unordered_map hash_id_to_index_; }; } // namespace ps diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index 3b360bfb53c..d55c56dab95 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -226,9 +226,9 @@ void PsCacheManager::AllocMemForHashTable() { device_address.addr = addr; auto &host_address = item.second.host_address; - auto host_address_ptr = new int[host_cache_vocab_size_ * embedding_size]; + auto host_address_ptr = new float[host_cache_vocab_size_ * embedding_size]; MS_EXCEPTION_IF_NULL(host_address_ptr); - host_address = std::shared_ptr(host_address_ptr, std::default_delete()); + host_address = std::shared_ptr(host_address_ptr, std::default_delete()); MS_EXCEPTION_IF_NULL(host_address); max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size; @@ -330,6 +330,14 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { } void PsCacheManager::Finalize() { + if (running_) { + if (!SyncHostEmbeddingTable()) { + MS_LOG(ERROR) << "SyncHostEmbeddingTable failed."; + } + if (!SyncDeviceEmbeddingTable()) { + MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed."; + } + } running_ = false; PsDataPrefetch::GetInstance().NotifyFinalize(); insert_init_info_.notify_all(); @@ -838,6 +846,99 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray &swap_out_da return true; } +bool PsCacheManager::SyncHostEmbeddingTable() { + MS_ERROR_IF_NULL(embedding_host_cache_); + const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index(); + size_t swap_indices_lens = hash_id_to_index.size(); + if (swap_indices_lens == 0) { + return true; + } + std::unique_ptr host_to_server_ids_ptr = std::make_unique(swap_indices_lens); + MS_ERROR_IF_NULL(host_to_server_ids_ptr); + std::unique_ptr host_to_server_indices_ptr = std::make_unique(swap_indices_lens); + MS_ERROR_IF_NULL(host_to_server_indices_ptr); + size_t idx = 0; + for (const auto &item : hash_id_to_index) { + host_to_server_ids_ptr[idx] = item.first; + host_to_server_indices_ptr[idx++] = item.second; + } + for (const auto &item : hash_tables_) { + const auto &hash_info = item.second; + if (hash_info.param_init_info_.param_type_ != kWeight) { + continue; + } + auto key = worker.GetParamKey(item.first); + ::ps::SArray lookup_ids(swap_indices_lens, 0); + ::ps::SArray swap_out_data; + auto embedding_size = hash_info.embedding_size; + swap_out_data.resize(swap_indices_lens * embedding_size); + auto host_hash_table_addr = hash_info.host_address.get(); + MS_ERROR_IF_NULL(host_hash_table_addr); + RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, host_hash_table_addr, + host_to_server_indices_ptr.get(), swap_out_data.data())); + + auto copy_len = swap_indices_lens * sizeof(int); + auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids_ptr.get(), copy_len); + if (ret != EOK) { + MS_LOG(ERROR) << "Lookup id memcpy failed."; + return false; + } + worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); + } + return true; +} + +bool PsCacheManager::SyncDeviceEmbeddingTable() { + MS_ERROR_IF_NULL(embedding_device_cache_); + const auto &device_hash_map = embedding_device_cache_->device_hash_map_; + const auto &hash_id_to_index = device_hash_map->hash_id_to_index(); + size_t swap_indices_lens = hash_id_to_index.size(); + if (swap_indices_lens == 0) { + return true; + } + std::unique_ptr device_to_server_ids_ptr = std::make_unique(swap_indices_lens); + MS_ERROR_IF_NULL(device_to_server_ids_ptr); + std::unique_ptr device_to_server_indices_ptr = std::make_unique(swap_indices_lens); + MS_ERROR_IF_NULL(device_to_server_indices_ptr); + size_t idx = 0; + for (const auto &item : hash_id_to_index) { + device_to_server_ids_ptr[idx] = item.first; + device_to_server_indices_ptr[idx++] = item.second; + } + for (const auto &item : hash_tables_) { + const auto &hash_info = item.second; + if (hash_info.param_init_info_.param_type_ != kWeight) { + continue; + } + auto key = worker.GetParamKey(item.first); + ::ps::SArray lookup_ids(swap_indices_lens, 0); + ::ps::SArray swap_out_data; + auto embedding_size = hash_info.embedding_size; + swap_out_data.resize(swap_indices_lens * embedding_size); + std::unique_ptr device_hash_table_addr_tmp = + std::make_unique(device_hash_map->hash_capacity() * embedding_size); + MS_ERROR_IF_NULL(device_hash_table_addr_tmp); + + auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); + MS_ERROR_IF_NULL(hash_table_addr); + auto hash_table_size = hash_info.device_address.size; + RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(device_hash_table_addr_tmp.get(), + hash_table_addr, hash_table_size)); + RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); + RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(), + device_to_server_indices_ptr.get(), swap_out_data.data())); + + auto copy_len = swap_indices_lens * sizeof(int); + auto ret = memcpy_s(lookup_ids.data(), copy_len, device_to_server_ids_ptr.get(), copy_len); + if (ret != EOK) { + MS_LOG(ERROR) << "Lookup id memcpy failed."; + return false; + } + worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); + } + return true; +} + void PsCacheManager::DumpHashTables(bool dump_device_tables) const { for (const auto &item : hash_tables_) { const auto ¶m_name = item.first; diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h index bbf1db65192..5e1732d08d7 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h @@ -48,7 +48,7 @@ struct HashTableInfo { size_t embedding_size{0}; size_t vocab_size{0}; Address device_address{nullptr, 0}; - std::shared_ptr host_address{nullptr}; + std::shared_ptr host_address{nullptr}; ParamInitInfo param_init_info_; }; @@ -166,6 +166,8 @@ class PsCacheManager { bool CheckFinishInsertInitInfo() const; void AddEmbeddingTable() const; void DumpStatisticsInfo(size_t each_print_step = 1000); + bool SyncHostEmbeddingTable(); + bool SyncDeviceEmbeddingTable(); bool initialized_ps_cache_{false}; std::string channel_name_; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 11f8b5d578e..78960e2b42f 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -205,6 +205,7 @@ constexpr auto kPushOpName = "Push"; constexpr auto kPullOpName = "Pull"; constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup"; constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; +constexpr auto kGatherV2OpName = "GatherV2"; constexpr auto kPaddingOpName = "Padding"; constexpr auto kAvgPoolOpName = "AvgPool"; constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index 2c2724468f0..90945c15748 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -292,7 +292,8 @@ class EmbeddingLookup(Cell): "in 'full_batch' and 'table_row_slice' parallel strategy.") self.vocab_cache_size = self.vocab_cache_size * device_num self.cache_enable = True - self.vocab_size = self.vocab_cache_size + if _is_role_worker(): + self.vocab_size = self.vocab_cache_size def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size): """PS embeddingLookup cache enable set.""" diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py index 21b211bb046..0bad048a6ea 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py @@ -24,6 +24,7 @@ from mindspore.context import ParallelMode from mindspore.communication.management import get_rank, get_group_size, init from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple from mindspore.common import set_seed +from mindspore.parallel._ps_context import _is_role_worker from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel from src.callbacks import LossCallBack, EvalCallBack @@ -117,11 +118,14 @@ def train_and_eval(config): eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) callback = LossCallBack(config=config) - if cache_enable: - ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, - keep_checkpoint_max=5, integrated_save=False) + if _is_role_worker(): + if cache_enable: + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, + keep_checkpoint_max=1, integrated_save=False) + else: + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) else: - ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) + ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1) ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', config=ckptconfig) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py index e195b1360b9..a5a868e6a16 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py @@ -20,6 +20,7 @@ import sys from mindspore import Model, context from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.common import set_seed +from mindspore.parallel._ps_context import _is_role_worker from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel from src.callbacks import LossCallBack, EvalCallBack @@ -99,7 +100,14 @@ def train_and_eval(config): eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) callback = LossCallBack(config=config) - ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) + if _is_role_worker(): + if cache_enable: + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size() * epochs, + keep_checkpoint_max=1) + else: + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) + else: + ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=1) ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig) callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]