diff --git a/mindspore/ccsrc/distributed/embedding_cache/embedding_cache_utils.cc b/mindspore/ccsrc/distributed/embedding_cache/embedding_cache_utils.cc new file mode 100644 index 00000000000..4a4777abae9 --- /dev/null +++ b/mindspore/ccsrc/distributed/embedding_cache/embedding_cache_utils.cc @@ -0,0 +1,197 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "distributed/embedding_cache/embedding_cache_utils.h" +#include +#include "utils/log_adapter.h" +#include "utils/ms_utils.h" +#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__)) +#include "distributed/cluster/cluster_context.h" +#endif +#include "ps/ps_context.h" + +namespace mindspore { +namespace distributed { +EmbeddingCacheTableManager &EmbeddingCacheTableManager::GetInstance() { + static EmbeddingCacheTableManager instance{}; + return instance; +} + +void EmbeddingCacheTableManager::Initialize() { GetEmbeddingTableSliceBound(); } + +void EmbeddingCacheTableManager::Finalize() { + hash_tables_.clear(); + + embedding_device_cache_ = nullptr; + embedding_host_cache_ = nullptr; +} + +void EmbeddingCacheTableManager::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, + size_t embedding_size, size_t vocab_size, int32_t param_key) { + if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) { + MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero."; + } + hash_tables_[param_name].cache_vocab_size = cache_vocab_size; + hash_tables_[param_name].host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor; + hash_tables_[param_name].embedding_size = embedding_size; + hash_tables_[param_name].vocab_size = vocab_size; + hash_tables_[param_name].param_key_ = param_key; + + if (vocab_size_ == 0) { + vocab_size_ = vocab_size; + } + if (device_cache_size_ == 0) { + device_cache_size_ = cache_vocab_size; + } + if (host_cache_size_ == 0) { + host_cache_size_ = cache_vocab_size * kHostCacheScaleFactor; + } +} + +void EmbeddingCacheTableManager::ReInsertHashTableSize(const std::string &new_param_name, + const std::string &cur_param_name, size_t cache_vocab_size, + size_t embedding_size) { + if (cache_vocab_size == 0 || embedding_size == 0) { + MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero."; + } + if (new_param_name.empty() || cur_param_name.empty()) { + MS_LOG(EXCEPTION) << "Parameter name can not be empty."; + } + if (new_param_name == cur_param_name) { + return; + } + auto iter = hash_tables_.find(cur_param_name); + if (iter != hash_tables_.end()) { + hash_tables_.emplace(new_param_name, iter->second); + hash_tables_.erase(iter); + } else { + hash_tables_[new_param_name].cache_vocab_size = cache_vocab_size; + hash_tables_[new_param_name].embedding_size = embedding_size; + } +} + +void EmbeddingCacheTableManager::CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key, + const std::string &src_param_name, int32_t src_param_key) { + if (dest_param_name == src_param_name) { + MS_LOG(INFO) << "The dest_param_name is same as src_param_name"; + return; + } + auto iter = hash_tables_.find(src_param_name); + if (iter == hash_tables_.end()) { + MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed."; + } + hash_tables_.emplace(dest_param_name, iter->second); + hash_tables_[src_param_name].param_key_ = src_param_key; + hash_tables_[dest_param_name].param_key_ = dest_param_key; +} + +const Address &EmbeddingCacheTableManager::QueryHashTableAddr(const std::string ¶m_name) const { + auto iter = hash_tables_.find(param_name); + if (iter == hash_tables_.end()) { + MS_LOG(EXCEPTION) << "Can not find device address of " << param_name; + } + return iter->second.device_address; +} + +size_t EmbeddingCacheTableManager::QueryHashTableSize(const std::string ¶m_name) const { + auto iter = hash_tables_.find(param_name); + if (iter == hash_tables_.end()) { + MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name; + } + return iter->second.cache_vocab_size; +} + +void EmbeddingCacheTableManager::AllocMemForEmbeddingCacheTable(const device::DeviceContext *device_context) { + MS_EXCEPTION_IF_NULL(device_context); + + size_t max_embedding_size = 0; + for (auto &item : hash_tables_) { + size_t embedding_size = item.second.embedding_size; + auto &device_address = item.second.device_address; + device_address.size = device_cache_size_ * embedding_size * sizeof(float); + auto addr = device_context->AllocateMemory(device_address.size); + MS_EXCEPTION_IF_NULL(addr); + device_address.addr = addr; + + auto &host_address = item.second.host_address; + auto host_hash_table_addr = std::make_unique(host_cache_size_ * embedding_size); + MS_EXCEPTION_IF_NULL(host_hash_table_addr); + host_address = std::shared_ptr(host_hash_table_addr.release(), std::default_delete()); + MS_EXCEPTION_IF_NULL(host_address); + + max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size; + } + + embedding_device_cache_ = std::make_shared(batch_ids_num_, device_cache_size_); + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + embedding_host_cache_ = std::make_shared(batch_ids_num_, host_cache_size_); + MS_EXCEPTION_IF_NULL(embedding_host_cache_); + + embedding_device_cache_->hash_swap_index_addr_ = + reinterpret_cast(device_context->AllocateMemory(batch_ids_num_ * sizeof(int))); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_index_addr_); + embedding_device_cache_->hash_swap_value_addr_ = + reinterpret_cast(device_context->AllocateMemory(max_embedding_size * batch_ids_num_ * sizeof(float))); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_); +} + +void EmbeddingCacheTableManager::GetEmbeddingTableSliceBound() { + auto worker_num = ps::PSContext::instance()->worker_num(); + if (worker_num == 0) { + return; + } + + uint32_t rank_id = 0; +#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__)) + auto node = distributed::cluster::ClusterContext::instance()->node(); + MS_EXCEPTION_IF_NULL(node); + rank_id = node->rank_id(); +#endif + + auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num)); + local_embedding_slice_bounds_.first = local_shard_size * UintToInt(rank_id); + local_embedding_slice_bounds_.second = + std::min(local_embedding_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_)); + local_device_cache_bounds_.first = SizeToInt(device_cache_size_) * rank_id; + local_device_cache_bounds_.second = local_device_cache_bounds_.first + SizeToInt(device_cache_size_); + MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id + << ", id begin:" << local_embedding_slice_bounds_.first + << ", id end:" << local_embedding_slice_bounds_.second + << ", cache indices begin: " << local_device_cache_bounds_.first + << ", cache indices end: " << local_device_cache_bounds_.second; +} + +int EmbeddingCacheTableManager::cache_indices_lower_bound() const { return local_device_cache_bounds_.first; } + +void EmbeddingCacheTableManager::DumpHashTables() const { + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + for (const auto &item : hash_tables_) { + const auto ¶m_name = item.first; + size_t cache_vocab_size = item.second.cache_vocab_size; + size_t host_cache_vocab_size = item.second.host_cache_vocab_size; + size_t embedding_size = item.second.embedding_size; + size_t vocab_size = item.second.vocab_size; + int32_t param_key = item.second.param_key_; + MS_LOG(INFO) << "Hash table info:" + << " param_key:" << param_key << ", embedding table name:" << param_name + << ", vocab size:" << vocab_size << ", embedding size:" << embedding_size + << ", device cache size:" << cache_vocab_size << ", host cache size:" << host_cache_vocab_size + << ", device cache address:" << reinterpret_cast(item.second.device_address.addr) + << ", host cache address:" << reinterpret_cast(item.second.host_address.get()); + } +} +} // namespace distributed +} // namespace mindspore diff --git a/mindspore/ccsrc/distributed/embedding_cache/embedding_cache_utils.h b/mindspore/ccsrc/distributed/embedding_cache/embedding_cache_utils.h new file mode 100644 index 00000000000..12041522a0e --- /dev/null +++ b/mindspore/ccsrc/distributed/embedding_cache/embedding_cache_utils.h @@ -0,0 +1,203 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_ +#define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_ + +#include +#include +#include +#include +#include "kernel/kernel.h" +#include "distributed/embedding_cache/embedding_hash_map.h" +#include "runtime/hardware/device_context.h" + +namespace mindspore { +namespace runtime { +class EmbeddingCachePrefetchActor; +} // namespace runtime + +namespace distributed { +// The local host cache size defaults to 10 times the device cache size. +static constexpr size_t kHostCacheScaleFactor = 10; +// The maximum number of concurrent threads for data prefetching. +static constexpr size_t kMaxThreadNum = 16; +// Maximum number of feature ids processed per thread. +static constexpr size_t kMaxIdsPerThread = 10000; + +using mindspore::kernel::Address; + +// The hash tables records information such as the dimension, memory address, and cache size of the embedding table +// with the embedding cache enabled. +struct HashTableInfo { + size_t cache_vocab_size{0}; + size_t host_cache_vocab_size{0}; + size_t embedding_size{0}; + size_t vocab_size{0}; + Address device_address{nullptr, 0}; + std::shared_ptr host_address{nullptr}; + int32_t param_key_{-1}; +}; + +// Record the hash mapping relationship of all embedding tables with cache enabled on the device side, and the +// ids information that needs to be exchanged with the local host cache. Note that the following information of +// all embedding cache tables on the device side is same: hash mapping, and feature ids of feature vectors that need +// to be swapped with the local host cache. +struct EmbeddingDeviceCache { + EmbeddingDeviceCache(size_t batch_ids_num, size_t cache_vocab_size) + : hash_swap_index_addr_(nullptr), hash_swap_value_addr_(nullptr) { + device_to_host_index = std::make_unique(batch_ids_num); + device_to_host_ids = std::make_unique(batch_ids_num); + host_to_device_index = std::make_unique(batch_ids_num); + host_to_device_ids = std::make_unique(batch_ids_num); + device_hash_map_ = std::make_shared(0, cache_vocab_size); + } + + std::unique_ptr device_to_host_index; + std::unique_ptr device_to_host_ids; + std::unique_ptr host_to_device_index; + std::unique_ptr host_to_device_ids; + int *hash_swap_index_addr_; + float *hash_swap_value_addr_; + std::shared_ptr device_hash_map_; +}; + +// Record the hash mapping relationship of all embedding tables with cache enabled on the local host side, and the +// information that needs to be exchanged with the remote cache and device cache. Note that the following information of +// all embedding cache tables on the local host side is same: hash mapping, and feature ids of feature vectors that need +// to be swapped with the remote cache and device cache. +struct EmbeddingHostCache { + EmbeddingHostCache(size_t batch_ids_num, size_t host_cache_vocab_size) { + host_to_server_index = std::make_unique(batch_ids_num); + host_to_server_ids = std::make_unique(batch_ids_num); + server_to_host_index = std::make_unique(batch_ids_num); + server_to_host_ids = std::make_unique(batch_ids_num); + host_to_device_index = std::make_unique(batch_ids_num); + device_to_host_index = std::make_unique(batch_ids_num); + host_hash_map_ = std::make_shared(0, host_cache_vocab_size); + } + + std::unique_ptr host_to_server_index; + std::unique_ptr host_to_server_ids; + std::unique_ptr server_to_host_index; + std::unique_ptr server_to_host_ids; + std::unique_ptr host_to_device_index; + std::unique_ptr device_to_host_index; + std::shared_ptr host_hash_map_; +}; + +struct EmbeddingCacheStatisticsInfo { + size_t batch_id_count_{0}; + size_t batch_id_unique_count_{0}; + size_t device_to_host_size_{0}; + size_t host_to_device_size_{0}; + size_t host_to_server_size_{0}; + size_t server_to_host_size_{0}; + size_t hash_hit_count_{0}; + size_t mem_cache_swap_out_size_{0}; + size_t mem_cache_swap_in_size_{0}; + size_t mem_cache_hit_count_{0}; +}; + +// The EmbeddingCacheTableManager class is used to save all Parameter information for enabling cache, such as device +// cache size, host cache size, etc., and can allocate memory for the embedding cache table. +class EmbeddingCacheTableManager { + public: + static EmbeddingCacheTableManager &GetInstance(); + + // Initialize the EmbeddingCacheTableManager. + void Initialize(); + // Finalize the EmbeddingCacheTableManager and release all resource. + void Finalize(); + + // Insert and save dimension information of the embedding cache table. + void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, + size_t vocab_size, int32_t param_key); + + // Parameter will modify the name. After modification, you need to re-insert all the dimension information that saves + // the parameter. + void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, + size_t cache_vocab_size, size_t embedding_size); + + // Clone a hash table, such as the optimizer's state parameters are generally cloned from weight. + void CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key, const std::string &src_param_name, + int32_t src_param_key); + + // Alloc device memory for all embedding cache table. + void AllocMemForEmbeddingCacheTable(const device::DeviceContext *device_context); + + // Qeury device address of a embedding cache table. + const Address &QueryHashTableAddr(const std::string ¶m_name) const; + + // Qeury device cache size of a embedding cache table. + size_t QueryHashTableSize(const std::string ¶m_name) const; + + // Check whether a parameter is cache enabled embedding table. + bool IsEmbeddingCacheTable(const std::string ¶m_name) { return hash_tables_.count(param_name) != 0; } + + // Set ids number of a batchsize. + void set_batch_ids_num(size_t batch_ids_num) { batch_ids_num_ = batch_ids_num; } + + // Get the offset of the id range corresponding to the embedding cache table slice on each worker in a multi-worker + // automatic parallel scenario. + int cache_indices_lower_bound() const; + + void DumpHashTables() const; + + private: + EmbeddingCacheTableManager() = default; + ~EmbeddingCacheTableManager() = default; + DISABLE_COPY_AND_ASSIGN(EmbeddingCacheTableManager); + + // Get embedding table slice bound info on each worker in a multi-worker automatic parallel scenario. + void GetEmbeddingTableSliceBound(); + + // The hash tables records information such as the dimension, memory address, and cache size of the embedding table + // with the embedding cache enabled. + std::map hash_tables_; + + // Record the hash mapping relationship of all embedding tables with cache enabled on the device side, and the + // ids information that needs to be exchanged with the local host cache. + std::shared_ptr embedding_device_cache_; + + // Record the hash mapping relationship of all embedding tables with cache enabled on the local host side, and the + // information that needs to be exchanged with the remote cache and device cache. + std::shared_ptr embedding_host_cache_; + + // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range + // corresponding to the embedding table slice of the process. + std::pair local_embedding_slice_bounds_; + + // Model parallelism is used between multiple workers, and local_device_cache_bounds_ records the local device cache + // range corresponding to the embedding table slice of the process. + std::pair local_device_cache_bounds_; + + // Full Embedding table row num, not less than the total number of feature ids. + size_t vocab_size_{0}; + // Embedding cache size(row number of embedding cache) of device cache. + size_t device_cache_size_{0}; + // Embedding cache size(row number of embedding cache) of local host cache. + size_t host_cache_size_{0}; + // Total ids number of a batchsize. + size_t batch_ids_num_{0}; + + friend class mindspore::runtime::EmbeddingCachePrefetchActor; +}; +} // namespace distributed +static distributed::EmbeddingCacheTableManager &embedding_cache_table_manager = + distributed::EmbeddingCacheTableManager::GetInstance(); +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_ diff --git a/mindspore/ccsrc/distributed/embedding_cache/embedding_hash_map.cc b/mindspore/ccsrc/distributed/embedding_cache/embedding_hash_map.cc new file mode 100755 index 00000000000..735271f12a6 --- /dev/null +++ b/mindspore/ccsrc/distributed/embedding_cache/embedding_hash_map.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "distributed/embedding_cache/embedding_hash_map.h" + +namespace mindspore { +namespace distributed { +int EmbeddingHashMap::ParseData(const int id, int *const swap_out_index, int *const swap_out_ids, + const size_t data_step, const size_t graph_running_step, size_t *const swap_out_size, + bool *const need_wait_graph) { + MS_EXCEPTION_IF_NULL(swap_out_index); + MS_EXCEPTION_IF_NULL(swap_out_ids); + MS_EXCEPTION_IF_NULL(swap_out_size); + bool need_swap = false; + auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph); + if (hash_index == INVALID_INDEX_VALUE) { + return hash_index; + } + + if (!need_swap) { + hash_count_++; + (void)hash_id_to_index_.emplace(id, hash_index); + hash_map_elements_[hash_index].set_id(id); + hash_map_elements_[hash_index].set_step(data_step); + return hash_index; + } + + swap_out_index[*swap_out_size] = hash_index; + swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_; + (*swap_out_size)++; + (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_); + (void)hash_id_to_index_.emplace(id, hash_index); + hash_map_elements_[hash_index].set_id(id); + hash_map_elements_[hash_index].set_step(data_step); + return hash_index; +} + +int EmbeddingHashMap::FindInsertionPos(const size_t, const size_t graph_running_step, bool *const need_swap, + bool *const need_wait_graph) { + MS_EXCEPTION_IF_NULL(need_swap); + MS_EXCEPTION_IF_NULL(need_wait_graph); + int hash_index = INVALID_INDEX_VALUE; + while (!expired_element_full_) { + if (hash_map_elements_[current_pos_].IsEmpty()) { + hash_index = current_pos_; + } else if (hash_map_elements_[current_pos_].IsExpired(graph_running_step)) { + hash_index = current_pos_; + *need_swap = true; + } else if (hash_map_elements_[current_pos_].StepEqual(graph_running_step)) { + graph_running_index_[graph_running_index_num_++] = current_pos_; + } + current_pos_ = (current_pos_ + 1) % hash_capacity_; + if (hash_index != INVALID_INDEX_VALUE) { + return hash_index; + } + if (current_pos_ == current_batch_start_pos_) { + expired_element_full_ = true; + MS_LOG(INFO) << "Running step:" << graph_running_step << "(num:" << graph_running_index_num_ + << ") will be used, index swap will wait until the graph completed."; + } + } + + if (graph_running_index_pos_ != graph_running_index_num_) { + *need_swap = true; + *need_wait_graph = true; + return graph_running_index_[graph_running_index_pos_++]; + } + return INVALID_INDEX_VALUE; +} + +void EmbeddingHashMap::DumpHashMap() { + MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_; + MS_LOG(INFO) << "Dump hash_id_to_index: "; + for (auto iter = hash_id_to_index_.begin(); iter != hash_id_to_index_.end(); ++iter) { + MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second; + } + MS_LOG(INFO) << "Dump hash_map_unit: "; + for (size_t i = 0; i < hash_map_elements_.size(); i++) { + if (!hash_map_elements_[i].IsEmpty()) { + MS_LOG(INFO) << " index: " << i << " id: " << hash_map_elements_[i].id_ + << " step: " << hash_map_elements_[i].step_; + } + } + MS_LOG(INFO) << "Dump hash map info end."; +} + +void EmbeddingHashMap::Reset() { + current_batch_start_pos_ = current_pos_; + graph_running_index_num_ = 0; + graph_running_index_pos_ = 0; + expired_element_full_ = false; +} +} // namespace distributed +} // namespace mindspore diff --git a/mindspore/ccsrc/distributed/embedding_cache/embedding_hash_map.h b/mindspore/ccsrc/distributed/embedding_cache/embedding_hash_map.h new file mode 100644 index 00000000000..c019f2c4c14 --- /dev/null +++ b/mindspore/ccsrc/distributed/embedding_cache/embedding_hash_map.h @@ -0,0 +1,129 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_ +#define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_ + +#include +#include +#include +#include +#include "utils/hash_map.h" +#include "utils/convert_utils_base.h" + +namespace mindspore { +namespace distributed { +// Define the value of an invalid step. +static constexpr size_t INVALID_STEP_VALUE = 0; +// Define the value of an invalid index. +static constexpr int INVALID_INDEX_VALUE = -1; + +struct HashMapElement { + int id_{INVALID_INDEX_VALUE}; + // The current global step of cache prefetching operation. + size_t step_{INVALID_STEP_VALUE}; + + bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } + bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } + bool StepEqual(size_t step) const { return step_ == step; } + void set_id(int id) { id_ = id; } + void set_step(size_t step) { step_ = step; } +}; + +// EmbeddingHashMap is used to manage the id -> index mapping of the embedding cache table on the host +// side. The cache content can be stored on the device or host side. +class EmbeddingHashMap { + public: + EmbeddingHashMap(size_t hash_count, size_t hash_capacity) + : hash_count_(hash_count), + hash_capacity_(hash_capacity), + current_pos_(0), + current_batch_start_pos_(0), + graph_running_index_num_(0), + graph_running_index_pos_(0), + expired_element_full_(false) { + hash_map_elements_.resize(hash_capacity); + // In multi-device mode, embedding table are distributed on different devices by id interval, + // and ids outside the range of local device will use the front and back positions of the table, + // the positions are reserved for this. + hash_map_elements_.front().set_step(SIZE_MAX); + hash_map_elements_.back().set_step(SIZE_MAX); + graph_running_index_ = std::make_unique(hash_capacity); + } + + ~EmbeddingHashMap() = default; + + // Find the insertion position (index) in the hash map for an id. + // If the hash map capacity is insufficient, return the information of ids and indices that need to be swapped. + int ParseData(const int id, int *const swap_out_index, int *const swap_out_ids, const size_t data_step, + const size_t graph_running_step, size_t *const swap_out_size, bool *const need_wait_graph); + + // Get the global step of a element in hash map. + size_t hash_step(const int hash_index) const { return hash_map_elements_[IntToSize(hash_index)].step_; } + // Set the global step of a element in hash map. + void set_hash_step(const int hash_index, const size_t step) { + hash_map_elements_[IntToSize(hash_index)].set_step(step); + } + + // Get the id -> index mapping. + const mindspore::HashMap &hash_id_to_index() const { return hash_id_to_index_; } + + // Get capacity of hash map. + size_t hash_capacity() const { return hash_capacity_; } + + // Reset the hash map. + void Reset(); + + void DumpHashMap(); + + private: + // Find the insertion position (index) in the hash map for an id. + int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *const need_swap, + bool *const need_wait_graph); + + // Statistics on the usage of hash map capacity. + size_t hash_count_; + + // The hash map capacity. + size_t hash_capacity_; + + // Record all elements in this hash map. + std::vector hash_map_elements_; + + // The id -> index mapping. + mindspore::HashMap hash_id_to_index_; + + // The cursor that records the current slot. + size_t current_pos_; + // The cursor that records the start position of current_pos_. + size_t current_batch_start_pos_; + + // The number of ids which need to wait for the calculation graph to finish executing the current step and need be + // swapped out. + size_t graph_running_index_num_; + // The index in array 'graph_running_index_', and the value on this index is the hash index for new id, + // but need to wait for the calculation graph to finish executing the current step and swap out the expired data. + size_t graph_running_index_pos_; + // Record the index information of the feature id that needs to be swapped out after the calculation graph finishes + // executing the current step. + std::unique_ptr graph_running_index_; + + // The flag indicates hash map is full. + bool expired_element_full_; +}; +} // namespace distributed +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_