Move embedding cache data structure to distributed directory

lizhenyu 2022-06-05 12:52:54 +08:00
parent 7fa2215252
commit 0517feeab9
4 changed files with 636 additions and 0 deletions


@@ -0,0 +1,197 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "distributed/embedding_cache/embedding_cache_utils.h"
#include <algorithm>
#include <cmath>
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
#include "distributed/cluster/cluster_context.h"
#endif
#include "ps/ps_context.h"
namespace mindspore {
namespace distributed {
EmbeddingCacheTableManager &EmbeddingCacheTableManager::GetInstance() {
static EmbeddingCacheTableManager instance{};
return instance;
}
void EmbeddingCacheTableManager::Initialize() { GetEmbeddingTableSliceBound(); }
void EmbeddingCacheTableManager::Finalize() {
hash_tables_.clear();
embedding_device_cache_ = nullptr;
embedding_host_cache_ = nullptr;
}
void EmbeddingCacheTableManager::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size,
size_t embedding_size, size_t vocab_size, int32_t param_key) {
if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) {
MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
}
hash_tables_[param_name].cache_vocab_size = cache_vocab_size;
hash_tables_[param_name].host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor;
hash_tables_[param_name].embedding_size = embedding_size;
hash_tables_[param_name].vocab_size = vocab_size;
hash_tables_[param_name].param_key_ = param_key;
if (vocab_size_ == 0) {
vocab_size_ = vocab_size;
}
if (device_cache_size_ == 0) {
device_cache_size_ = cache_vocab_size;
}
if (host_cache_size_ == 0) {
host_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
}
}
void EmbeddingCacheTableManager::ReInsertHashTableSize(const std::string &new_param_name,
const std::string &cur_param_name, size_t cache_vocab_size,
size_t embedding_size) {
if (cache_vocab_size == 0 || embedding_size == 0) {
MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
}
if (new_param_name.empty() || cur_param_name.empty()) {
MS_LOG(EXCEPTION) << "Parameter name can not be empty.";
}
if (new_param_name == cur_param_name) {
return;
}
auto iter = hash_tables_.find(cur_param_name);
if (iter != hash_tables_.end()) {
hash_tables_.emplace(new_param_name, iter->second);
hash_tables_.erase(iter);
} else {
hash_tables_[new_param_name].cache_vocab_size = cache_vocab_size;
hash_tables_[new_param_name].embedding_size = embedding_size;
}
}
void EmbeddingCacheTableManager::CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key,
const std::string &src_param_name, int32_t src_param_key) {
if (dest_param_name == src_param_name) {
MS_LOG(INFO) << "The dest_param_name is same as src_param_name";
return;
}
auto iter = hash_tables_.find(src_param_name);
if (iter == hash_tables_.end()) {
MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed.";
}
hash_tables_.emplace(dest_param_name, iter->second);
hash_tables_[src_param_name].param_key_ = src_param_key;
hash_tables_[dest_param_name].param_key_ = dest_param_key;
}
const Address &EmbeddingCacheTableManager::QueryHashTableAddr(const std::string &param_name) const {
auto iter = hash_tables_.find(param_name);
if (iter == hash_tables_.end()) {
MS_LOG(EXCEPTION) << "Can not find device address of " << param_name;
}
return iter->second.device_address;
}
size_t EmbeddingCacheTableManager::QueryHashTableSize(const std::string &param_name) const {
auto iter = hash_tables_.find(param_name);
if (iter == hash_tables_.end()) {
MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name;
}
return iter->second.cache_vocab_size;
}
void EmbeddingCacheTableManager::AllocMemForEmbeddingCacheTable(const device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_context);
size_t max_embedding_size = 0;
for (auto &item : hash_tables_) {
size_t embedding_size = item.second.embedding_size;
auto &device_address = item.second.device_address;
device_address.size = device_cache_size_ * embedding_size * sizeof(float);
auto addr = device_context->AllocateMemory(device_address.size);
MS_EXCEPTION_IF_NULL(addr);
device_address.addr = addr;
auto &host_address = item.second.host_address;
auto host_hash_table_addr = std::make_unique<float[]>(host_cache_size_ * embedding_size);
MS_EXCEPTION_IF_NULL(host_hash_table_addr);
host_address = std::shared_ptr<float>(host_hash_table_addr.release(), std::default_delete<float[]>());
MS_EXCEPTION_IF_NULL(host_address);
max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size;
}
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_ids_num_, device_cache_size_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_ids_num_, host_cache_size_);
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
embedding_device_cache_->hash_swap_index_addr_ =
reinterpret_cast<int *>(device_context->AllocateMemory(batch_ids_num_ * sizeof(int)));
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_index_addr_);
embedding_device_cache_->hash_swap_value_addr_ =
reinterpret_cast<float *>(device_context->AllocateMemory(max_embedding_size * batch_ids_num_ * sizeof(float)));
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
}
void EmbeddingCacheTableManager::GetEmbeddingTableSliceBound() {
auto worker_num = ps::PSContext::instance()->worker_num();
if (worker_num == 0) {
return;
}
uint32_t rank_id = 0;
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
auto node = distributed::cluster::ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
rank_id = node->rank_id();
#endif
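// Each worker owns a contiguous id slice of size ceil(vocab_size_ / worker_num). As a hypothetical example,
// with vocab_size_ = 10000, worker_num = 4 and rank_id = 1, local_shard_size is 2500, the id slice is
// [2500, 5000), and the device cache indices owned by this worker are [device_cache_size_ * 1, device_cache_size_ * 2).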
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
local_embedding_slice_bounds_.first = local_shard_size * UintToInt(rank_id);
local_embedding_slice_bounds_.second =
std::min(local_embedding_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
local_device_cache_bounds_.first = SizeToInt(device_cache_size_) * UintToInt(rank_id);
local_device_cache_bounds_.second = local_device_cache_bounds_.first + SizeToInt(device_cache_size_);
MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id
<< ", id begin:" << local_embedding_slice_bounds_.first
<< ", id end:" << local_embedding_slice_bounds_.second
<< ", cache indices begin: " << local_device_cache_bounds_.first
<< ", cache indices end: " << local_device_cache_bounds_.second;
}
int EmbeddingCacheTableManager::cache_indices_lower_bound() const { return local_device_cache_bounds_.first; }
void EmbeddingCacheTableManager::DumpHashTables() const {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t cache_vocab_size = item.second.cache_vocab_size;
size_t host_cache_vocab_size = item.second.host_cache_vocab_size;
size_t embedding_size = item.second.embedding_size;
size_t vocab_size = item.second.vocab_size;
int32_t param_key = item.second.param_key_;
MS_LOG(INFO) << "Hash table info:"
<< " param_key:" << param_key << ", embedding table name:" << param_name
<< ", vocab size:" << vocab_size << ", embedding size:" << embedding_size
<< ", device cache size:" << cache_vocab_size << ", host cache size:" << host_cache_vocab_size
<< ", device cache address:" << reinterpret_cast<void *>(item.second.device_address.addr)
<< ", host cache address:" << reinterpret_cast<void *>(item.second.host_address.get());
}
}
} // namespace distributed
} // namespace mindspore
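
For reference, a minimal usage sketch of the manager added above (the parameter name, sizes, and device_context below are hypothetical placeholders, not part of this commit):

// Hypothetical usage sketch of EmbeddingCacheTableManager; names and sizes are illustrative only.
#include "distributed/embedding_cache/embedding_cache_utils.h"

void RegisterEmbeddingCacheExample(const mindspore::device::DeviceContext *device_context) {
  auto &manager = mindspore::distributed::EmbeddingCacheTableManager::GetInstance();
  // Register one cache-enabled embedding table: 4096-row device cache, embedding dim 8,
  // full vocabulary of 100000 rows, parameter key 0.
  manager.InsertHashTableSize("embedding_table.weight", 4096, 8, 100000, 0);
  // Number of feature ids contained in one batch.
  manager.set_batch_ids_num(1024);
  // Compute the id slice and device cache index range owned by this worker.
  manager.Initialize();
  // Allocate the device and host buffers for every registered table.
  manager.AllocMemForEmbeddingCacheTable(device_context);
  manager.DumpHashTables();
  // ... training and cache prefetching run here ...
  manager.Finalize();
}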


@@ -0,0 +1,203 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CACHE_UTILS_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CACHE_UTILS_H_
#include <map>
#include <string>
#include <memory>
#include <utility>
#include "kernel/kernel.h"
#include "distributed/embedding_cache/embedding_hash_map.h"
#include "runtime/hardware/device_context.h"
namespace mindspore {
namespace runtime {
class EmbeddingCachePrefetchActor;
} // namespace runtime
namespace distributed {
// The local host cache size defaults to 10 times the device cache size.
static constexpr size_t kHostCacheScaleFactor = 10;
// The maximum number of concurrent threads for data prefetching.
static constexpr size_t kMaxThreadNum = 16;
// Maximum number of feature ids processed per thread.
static constexpr size_t kMaxIdsPerThread = 10000;
using mindspore::kernel::Address;
// The hash table records information such as the dimension, memory address, and cache size of an embedding table
// with the embedding cache enabled.
struct HashTableInfo {
size_t cache_vocab_size{0};
size_t host_cache_vocab_size{0};
size_t embedding_size{0};
size_t vocab_size{0};
Address device_address{nullptr, 0};
std::shared_ptr<float> host_address{nullptr};
int32_t param_key_{-1};
};
// Records the hash mapping of all cache-enabled embedding tables on the device side, and the id information that
// needs to be exchanged with the local host cache. Note that the following information is shared by all embedding
// cache tables on the device side: the hash mapping, and the feature ids of the feature vectors that need to be
// swapped with the local host cache.
struct EmbeddingDeviceCache {
EmbeddingDeviceCache(size_t batch_ids_num, size_t cache_vocab_size)
: hash_swap_index_addr_(nullptr), hash_swap_value_addr_(nullptr) {
device_to_host_index = std::make_unique<int[]>(batch_ids_num);
device_to_host_ids = std::make_unique<int[]>(batch_ids_num);
host_to_device_index = std::make_unique<int[]>(batch_ids_num);
host_to_device_ids = std::make_unique<int[]>(batch_ids_num);
device_hash_map_ = std::make_shared<EmbeddingHashMap>(0, cache_vocab_size);
}
std::unique_ptr<int[]> device_to_host_index;
std::unique_ptr<int[]> device_to_host_ids;
std::unique_ptr<int[]> host_to_device_index;
std::unique_ptr<int[]> host_to_device_ids;
int *hash_swap_index_addr_;
float *hash_swap_value_addr_;
std::shared_ptr<EmbeddingHashMap> device_hash_map_;
};
// Records the hash mapping of all cache-enabled embedding tables on the local host side, and the information that
// needs to be exchanged with the remote cache and device cache. Note that the following information is shared by all
// embedding cache tables on the local host side: the hash mapping, and the feature ids of the feature vectors that
// need to be swapped with the remote cache and device cache.
struct EmbeddingHostCache {
EmbeddingHostCache(size_t batch_ids_num, size_t host_cache_vocab_size) {
host_to_server_index = std::make_unique<int[]>(batch_ids_num);
host_to_server_ids = std::make_unique<int[]>(batch_ids_num);
server_to_host_index = std::make_unique<int[]>(batch_ids_num);
server_to_host_ids = std::make_unique<int[]>(batch_ids_num);
host_to_device_index = std::make_unique<int[]>(batch_ids_num);
device_to_host_index = std::make_unique<int[]>(batch_ids_num);
host_hash_map_ = std::make_shared<EmbeddingHashMap>(0, host_cache_vocab_size);
}
std::unique_ptr<int[]> host_to_server_index;
std::unique_ptr<int[]> host_to_server_ids;
std::unique_ptr<int[]> server_to_host_index;
std::unique_ptr<int[]> server_to_host_ids;
std::unique_ptr<int[]> host_to_device_index;
std::unique_ptr<int[]> device_to_host_index;
std::shared_ptr<EmbeddingHashMap> host_hash_map_;
};
struct EmbeddingCacheStatisticsInfo {
size_t batch_id_count_{0};
size_t batch_id_unique_count_{0};
size_t device_to_host_size_{0};
size_t host_to_device_size_{0};
size_t host_to_server_size_{0};
size_t server_to_host_size_{0};
size_t hash_hit_count_{0};
size_t mem_cache_swap_out_size_{0};
size_t mem_cache_swap_in_size_{0};
size_t mem_cache_hit_count_{0};
};
// The EmbeddingCacheTableManager class saves the information of all cache-enabled parameters, such as the device
// cache size and host cache size, and can allocate memory for the embedding cache tables.
class EmbeddingCacheTableManager {
public:
static EmbeddingCacheTableManager &GetInstance();
// Initialize the EmbeddingCacheTableManager.
void Initialize();
// Finalize the EmbeddingCacheTableManager and release all resources.
void Finalize();
// Insert and save dimension information of the embedding cache table.
void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
size_t vocab_size, int32_t param_key);
// A parameter's name may be modified. After renaming, the saved dimension information of the parameter needs to be
// re-inserted under the new name.
void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
size_t cache_vocab_size, size_t embedding_size);
// Clone a hash table; for example, the optimizer's state parameters are generally cloned from the weight.
void CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key, const std::string &src_param_name,
int32_t src_param_key);
// Allocate device memory for all embedding cache tables.
void AllocMemForEmbeddingCacheTable(const device::DeviceContext *device_context);
// Query the device address of an embedding cache table.
const Address &QueryHashTableAddr(const std::string &param_name) const;
// Query the device cache size of an embedding cache table.
size_t QueryHashTableSize(const std::string &param_name) const;
// Check whether a parameter is a cache-enabled embedding table.
bool IsEmbeddingCacheTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
// Set the number of ids in a batch.
void set_batch_ids_num(size_t batch_ids_num) { batch_ids_num_ = batch_ids_num; }
// Get the offset of the id range corresponding to the embedding cache table slice on each worker in a multi-worker
// automatic parallel scenario.
int cache_indices_lower_bound() const;
void DumpHashTables() const;
private:
EmbeddingCacheTableManager() = default;
~EmbeddingCacheTableManager() = default;
DISABLE_COPY_AND_ASSIGN(EmbeddingCacheTableManager);
// Get embedding table slice bound info on each worker in a multi-worker automatic parallel scenario.
void GetEmbeddingTableSliceBound();
// The hash table records information such as the dimension, memory address, and cache size of each embedding table
// with the embedding cache enabled.
std::map<std::string, HashTableInfo> hash_tables_;
// Records the hash mapping of all cache-enabled embedding tables on the device side, and the id information that
// needs to be exchanged with the local host cache.
std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_;
// Records the hash mapping of all cache-enabled embedding tables on the local host side, and the information that
// needs to be exchanged with the remote cache and device cache.
std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;
// Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature id
// range corresponding to the embedding table slice owned by this process.
std::pair<int, int> local_embedding_slice_bounds_;
// Model parallelism is used between multiple workers, and local_device_cache_bounds_ records the device cache index
// range corresponding to the embedding table slice owned by this process.
std::pair<int, int> local_device_cache_bounds_;
// Row number of the full embedding table, not less than the total number of feature ids.
size_t vocab_size_{0};
// Embedding cache size (row number of the embedding cache) of the device cache.
size_t device_cache_size_{0};
// Embedding cache size (row number of the embedding cache) of the local host cache.
size_t host_cache_size_{0};
// Total number of ids in a batch.
size_t batch_ids_num_{0};
friend class mindspore::runtime::EmbeddingCachePrefetchActor;
};
} // namespace distributed
static distributed::EmbeddingCacheTableManager &embedding_cache_table_manager =
distributed::EmbeddingCacheTableManager::GetInstance();
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CACHE_UTILS_H_


@@ -0,0 +1,107 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "distributed/embedding_cache/embedding_hash_map.h"
namespace mindspore {
namespace distributed {
int EmbeddingHashMap::ParseData(const int id, int *const swap_out_index, int *const swap_out_ids,
const size_t data_step, const size_t graph_running_step, size_t *const swap_out_size,
bool *const need_wait_graph) {
MS_EXCEPTION_IF_NULL(swap_out_index);
MS_EXCEPTION_IF_NULL(swap_out_ids);
MS_EXCEPTION_IF_NULL(swap_out_size);
bool need_swap = false;
auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph);
if (hash_index == INVALID_INDEX_VALUE) {
return hash_index;
}
if (!need_swap) {
hash_count_++;
(void)hash_id_to_index_.emplace(id, hash_index);
hash_map_elements_[hash_index].set_id(id);
hash_map_elements_[hash_index].set_step(data_step);
return hash_index;
}
swap_out_index[*swap_out_size] = hash_index;
swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_;
(*swap_out_size)++;
(void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_);
(void)hash_id_to_index_.emplace(id, hash_index);
hash_map_elements_[hash_index].set_id(id);
hash_map_elements_[hash_index].set_step(data_step);
return hash_index;
}
int EmbeddingHashMap::FindInsertionPos(const size_t, const size_t graph_running_step, bool *const need_swap,
bool *const need_wait_graph) {
MS_EXCEPTION_IF_NULL(need_swap);
MS_EXCEPTION_IF_NULL(need_wait_graph);
int hash_index = INVALID_INDEX_VALUE;
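// Scan the ring buffer starting at current_pos_: an empty slot is used directly; an expired slot (last touched
// before graph_running_step) is reused after its old id is swapped out; slots whose step equals graph_running_step
// are still referenced by the running graph and are only recorded as reuse candidates for after the graph step
// completes.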
while (!expired_element_full_) {
if (hash_map_elements_[current_pos_].IsEmpty()) {
hash_index = current_pos_;
} else if (hash_map_elements_[current_pos_].IsExpired(graph_running_step)) {
hash_index = current_pos_;
*need_swap = true;
} else if (hash_map_elements_[current_pos_].StepEqual(graph_running_step)) {
graph_running_index_[graph_running_index_num_++] = current_pos_;
}
current_pos_ = (current_pos_ + 1) % hash_capacity_;
if (hash_index != INVALID_INDEX_VALUE) {
return hash_index;
}
if (current_pos_ == current_batch_start_pos_) {
expired_element_full_ = true;
MS_LOG(INFO) << "Running step:" << graph_running_step << "(num:" << graph_running_index_num_
<< ") will be used, index swap will wait until the graph completed.";
}
}
if (graph_running_index_pos_ != graph_running_index_num_) {
*need_swap = true;
*need_wait_graph = true;
return graph_running_index_[graph_running_index_pos_++];
}
return INVALID_INDEX_VALUE;
}
void EmbeddingHashMap::DumpHashMap() {
MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_;
MS_LOG(INFO) << "Dump hash_id_to_index: ";
for (auto iter = hash_id_to_index_.begin(); iter != hash_id_to_index_.end(); ++iter) {
MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second;
}
MS_LOG(INFO) << "Dump hash_map_unit: ";
for (size_t i = 0; i < hash_map_elements_.size(); i++) {
if (!hash_map_elements_[i].IsEmpty()) {
MS_LOG(INFO) << " index: " << i << " id: " << hash_map_elements_[i].id_
<< " step: " << hash_map_elements_[i].step_;
}
}
MS_LOG(INFO) << "Dump hash map info end.";
}
void EmbeddingHashMap::Reset() {
current_batch_start_pos_ = current_pos_;
graph_running_index_num_ = 0;
graph_running_index_pos_ = 0;
expired_element_full_ = false;
}
} // namespace distributed
} // namespace mindspore
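
For orientation, a minimal driver sketch for EmbeddingHashMap::ParseData (the ids, step values, and function name below are hypothetical; buffer sizing mirrors the batch_ids_num pattern used by EmbeddingDeviceCache):

// Hypothetical driver for EmbeddingHashMap::ParseData; ids and step values are illustrative only.
#include <vector>
#include "distributed/embedding_cache/embedding_hash_map.h"

void ParseBatchExample(mindspore::distributed::EmbeddingHashMap *hash_map, const std::vector<int> &batch_ids,
                       size_t data_step, size_t graph_running_step) {
  std::vector<int> swap_out_index(batch_ids.size());
  std::vector<int> swap_out_ids(batch_ids.size());
  size_t swap_out_size = 0;
  bool need_wait_graph = false;
  for (int id : batch_ids) {
    int index = hash_map->ParseData(id, swap_out_index.data(), swap_out_ids.data(), data_step, graph_running_step,
                                    &swap_out_size, &need_wait_graph);
    if (index == mindspore::distributed::INVALID_INDEX_VALUE) {
      // The map only holds ids still referenced by the running graph; the caller has to wait for the graph step
      // to advance and retry this id.
      break;
    }
    // 'index' is the cache slot assigned to 'id'. The slots recorded in swap_out_index[0..swap_out_size) hold
    // old ids (swap_out_ids) whose rows must be written back before they are overwritten.
  }
}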


@@ -0,0 +1,129 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_
#include <math.h>
#include <cstdint>
#include <utility>
#include <memory>
#include <vector>
#include "utils/hash_map.h"
#include "utils/convert_utils_base.h"
namespace mindspore {
namespace distributed {
// Define the value of an invalid step.
static constexpr size_t INVALID_STEP_VALUE = 0;
// Define the value of an invalid index.
static constexpr int INVALID_INDEX_VALUE = -1;
struct HashMapElement {
int id_{INVALID_INDEX_VALUE};
// The global step of the cache prefetching operation at which this element was last updated.
size_t step_{INVALID_STEP_VALUE};
bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; }
bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; }
bool StepEqual(size_t step) const { return step_ == step; }
void set_id(int id) { id_ = id; }
void set_step(size_t step) { step_ = step; }
};
// EmbeddingHashMap is used to manage the id -> index mapping of the embedding cache table on the host
// side. The cache content can be stored on the device or host side.
class EmbeddingHashMap {
public:
EmbeddingHashMap(size_t hash_count, size_t hash_capacity)
: hash_count_(hash_count),
hash_capacity_(hash_capacity),
current_pos_(0),
current_batch_start_pos_(0),
graph_running_index_num_(0),
graph_running_index_pos_(0),
expired_element_full_(false) {
hash_map_elements_.resize(hash_capacity);
// In multi-device mode, embedding tables are distributed across devices by id interval, and ids outside the
// range of the local device use the front and back positions of the table, so these two positions are
// reserved here.
hash_map_elements_.front().set_step(SIZE_MAX);
hash_map_elements_.back().set_step(SIZE_MAX);
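// A step of SIZE_MAX keeps these two slots permanently occupied: they are never reported as empty
// (step_ != INVALID_STEP_VALUE) and never expire (graph_running_step cannot exceed SIZE_MAX).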
graph_running_index_ = std::make_unique<int[]>(hash_capacity);
}
~EmbeddingHashMap() = default;
// Find the insertion position (index) in the hash map for an id.
// If the hash map capacity is insufficient, return the information of ids and indices that need to be swapped.
int ParseData(const int id, int *const swap_out_index, int *const swap_out_ids, const size_t data_step,
const size_t graph_running_step, size_t *const swap_out_size, bool *const need_wait_graph);
// Get the global step of an element in the hash map.
size_t hash_step(const int hash_index) const { return hash_map_elements_[IntToSize(hash_index)].step_; }
// Set the global step of an element in the hash map.
void set_hash_step(const int hash_index, const size_t step) {
hash_map_elements_[IntToSize(hash_index)].set_step(step);
}
// Get the id -> index mapping.
const mindspore::HashMap<int, int> &hash_id_to_index() const { return hash_id_to_index_; }
// Get the capacity of the hash map.
size_t hash_capacity() const { return hash_capacity_; }
// Reset the hash map.
void Reset();
void DumpHashMap();
private:
// Find the insertion position (index) in the hash map for an id.
int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *const need_swap,
bool *const need_wait_graph);
// Statistics on the usage of hash map capacity.
size_t hash_count_;
// The hash map capacity.
size_t hash_capacity_;
// Record all elements in this hash map.
std::vector<HashMapElement> hash_map_elements_;
// The id -> index mapping.
mindspore::HashMap<int, int> hash_id_to_index_;
// The cursor that records the current slot.
size_t current_pos_;
// The cursor that records the value of current_pos_ at the start of the current batch.
size_t current_batch_start_pos_;
// The number of ids that need to wait for the computation graph to finish executing the current step before they
// can be swapped out.
size_t graph_running_index_num_;
// The current position in the array 'graph_running_index_'; the value at this position is the hash index to reuse
// for a new id, but the reuse has to wait for the computation graph to finish executing the current step and swap
// out the expired data.
size_t graph_running_index_pos_;
// Records the indices of the feature ids that need to be swapped out after the computation graph finishes
// executing the current step.
std::unique_ptr<int[]> graph_running_index_;
// The flag that indicates the hash map is full (no empty or expired slot is available).
bool expired_element_full_;
};
} // namespace distributed
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_