forked from mindspore-Ecosystem/mindspore
!35432 Move embedding cache data structure to distributed directory
Merge pull request !35432 from zyli2020/embedding_cache_unify_runtime
Commit: e1f7ed5df7
@@ -0,0 +1,197 @@ distributed/embedding_cache/embedding_cache_utils.cc
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "distributed/embedding_cache/embedding_cache_utils.h"
#include <algorithm>
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
#include "distributed/cluster/cluster_context.h"
#endif
#include "ps/ps_context.h"

namespace mindspore {
namespace distributed {
EmbeddingCacheTableManager &EmbeddingCacheTableManager::GetInstance() {
  static EmbeddingCacheTableManager instance{};
  return instance;
}

void EmbeddingCacheTableManager::Initialize() { GetEmbeddingTableSliceBound(); }

void EmbeddingCacheTableManager::Finalize() {
  hash_tables_.clear();

  embedding_device_cache_ = nullptr;
  embedding_host_cache_ = nullptr;
}

void EmbeddingCacheTableManager::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size,
                                                     size_t embedding_size, size_t vocab_size, int32_t param_key) {
  if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) {
    MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
  }
  hash_tables_[param_name].cache_vocab_size = cache_vocab_size;
  hash_tables_[param_name].host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor;
  hash_tables_[param_name].embedding_size = embedding_size;
  hash_tables_[param_name].vocab_size = vocab_size;
  hash_tables_[param_name].param_key_ = param_key;

  if (vocab_size_ == 0) {
    vocab_size_ = vocab_size;
  }
  if (device_cache_size_ == 0) {
    device_cache_size_ = cache_vocab_size;
  }
  if (host_cache_size_ == 0) {
    host_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
  }
}

void EmbeddingCacheTableManager::ReInsertHashTableSize(const std::string &new_param_name,
                                                       const std::string &cur_param_name, size_t cache_vocab_size,
                                                       size_t embedding_size) {
  if (cache_vocab_size == 0 || embedding_size == 0) {
    MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
  }
  if (new_param_name.empty() || cur_param_name.empty()) {
    MS_LOG(EXCEPTION) << "Parameter name can not be empty.";
  }
  if (new_param_name == cur_param_name) {
    return;
  }
  auto iter = hash_tables_.find(cur_param_name);
  if (iter != hash_tables_.end()) {
    hash_tables_.emplace(new_param_name, iter->second);
    hash_tables_.erase(iter);
  } else {
    hash_tables_[new_param_name].cache_vocab_size = cache_vocab_size;
    hash_tables_[new_param_name].embedding_size = embedding_size;
  }
}

void EmbeddingCacheTableManager::CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key,
                                                const std::string &src_param_name, int32_t src_param_key) {
  if (dest_param_name == src_param_name) {
    MS_LOG(INFO) << "The dest_param_name is same as src_param_name";
    return;
  }
  auto iter = hash_tables_.find(src_param_name);
  if (iter == hash_tables_.end()) {
    MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed.";
  }
  hash_tables_.emplace(dest_param_name, iter->second);
  hash_tables_[src_param_name].param_key_ = src_param_key;
  hash_tables_[dest_param_name].param_key_ = dest_param_key;
}

const Address &EmbeddingCacheTableManager::QueryHashTableAddr(const std::string &param_name) const {
  auto iter = hash_tables_.find(param_name);
  if (iter == hash_tables_.end()) {
    MS_LOG(EXCEPTION) << "Can not find device address of " << param_name;
  }
  return iter->second.device_address;
}

size_t EmbeddingCacheTableManager::QueryHashTableSize(const std::string &param_name) const {
  auto iter = hash_tables_.find(param_name);
  if (iter == hash_tables_.end()) {
    MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name;
  }
  return iter->second.cache_vocab_size;
}

void EmbeddingCacheTableManager::AllocMemForEmbeddingCacheTable(const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_context);

  size_t max_embedding_size = 0;
  for (auto &item : hash_tables_) {
    size_t embedding_size = item.second.embedding_size;
    auto &device_address = item.second.device_address;
    device_address.size = device_cache_size_ * embedding_size * sizeof(float);
    auto addr = device_context->AllocateMemory(device_address.size);
    MS_EXCEPTION_IF_NULL(addr);
    device_address.addr = addr;

    auto &host_address = item.second.host_address;
    auto host_hash_table_addr = std::make_unique<float[]>(host_cache_size_ * embedding_size);
    MS_EXCEPTION_IF_NULL(host_hash_table_addr);
    host_address = std::shared_ptr<float>(host_hash_table_addr.release(), std::default_delete<float[]>());
    MS_EXCEPTION_IF_NULL(host_address);

    max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size;
  }

  embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_ids_num_, device_cache_size_);
  MS_EXCEPTION_IF_NULL(embedding_device_cache_);
  embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_ids_num_, host_cache_size_);
  MS_EXCEPTION_IF_NULL(embedding_host_cache_);

  embedding_device_cache_->hash_swap_index_addr_ =
    reinterpret_cast<int *>(device_context->AllocateMemory(batch_ids_num_ * sizeof(int)));
  MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_index_addr_);
  embedding_device_cache_->hash_swap_value_addr_ =
    reinterpret_cast<float *>(device_context->AllocateMemory(max_embedding_size * batch_ids_num_ * sizeof(float)));
  MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
}

void EmbeddingCacheTableManager::GetEmbeddingTableSliceBound() {
  auto worker_num = ps::PSContext::instance()->worker_num();
  if (worker_num == 0) {
    return;
  }

  uint32_t rank_id = 0;
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
  auto node = distributed::cluster::ClusterContext::instance()->node();
  MS_EXCEPTION_IF_NULL(node);
  rank_id = node->rank_id();
#endif

  auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
  local_embedding_slice_bounds_.first = local_shard_size * UintToInt(rank_id);
  local_embedding_slice_bounds_.second =
    std::min(local_embedding_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
  local_device_cache_bounds_.first = SizeToInt(device_cache_size_) * rank_id;
  local_device_cache_bounds_.second = local_device_cache_bounds_.first + SizeToInt(device_cache_size_);
  MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id
               << ", id begin:" << local_embedding_slice_bounds_.first
               << ", id end:" << local_embedding_slice_bounds_.second
               << ", cache indices begin: " << local_device_cache_bounds_.first
               << ", cache indices end: " << local_device_cache_bounds_.second;
}

int EmbeddingCacheTableManager::cache_indices_lower_bound() const { return local_device_cache_bounds_.first; }

void EmbeddingCacheTableManager::DumpHashTables() const {
  MS_EXCEPTION_IF_NULL(embedding_device_cache_);
  for (const auto &item : hash_tables_) {
    const auto &param_name = item.first;
    size_t cache_vocab_size = item.second.cache_vocab_size;
    size_t host_cache_vocab_size = item.second.host_cache_vocab_size;
    size_t embedding_size = item.second.embedding_size;
    size_t vocab_size = item.second.vocab_size;
    int32_t param_key = item.second.param_key_;
    MS_LOG(INFO) << "Hash table info:"
                 << " param_key:" << param_key << ", embedding table name:" << param_name
                 << ", vocab size:" << vocab_size << ", embedding size:" << embedding_size
                 << ", device cache size:" << cache_vocab_size << ", host cache size:" << host_cache_vocab_size
                 << ", device cache address:" << reinterpret_cast<void *>(item.second.device_address.addr)
                 << ", host cache address:" << reinterpret_cast<void *>(item.second.host_address.get());
  }
}
}  // namespace distributed
}  // namespace mindspore
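Reviewer note: for readers new to this component, the manager defined above is intended to be driven roughly in the following order. This is an illustrative sketch, not code from this change; the table name, the sizes, and the device_context variable are placeholder assumptions.

// Illustrative usage sketch (not part of this patch).
auto &manager = mindspore::distributed::EmbeddingCacheTableManager::GetInstance();
// Register one cached table: 8000 rows cached on device, 64 floats per row,
// 1000000 rows in the full table, parameter key 0 (all values made up).
manager.InsertHashTableSize("embedding_table.weight", 8000, 64, 1000000, 0);
manager.set_batch_ids_num(16000);                        // number of ids contained in one batch
manager.Initialize();                                    // computes the per-worker slice bounds
manager.AllocMemForEmbeddingCacheTable(device_context);  // device_context: a valid device::DeviceContext *
manager.DumpHashTables();                                // safe only after allocation sets embedding_device_cache_
manager.Finalize();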
@@ -0,0 +1,203 @@ distributed/embedding_cache/embedding_cache_utils.h
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_

#include <map>
#include <string>
#include <memory>
#include <utility>
#include "kernel/kernel.h"
#include "distributed/embedding_cache/embedding_hash_map.h"
#include "runtime/hardware/device_context.h"

namespace mindspore {
namespace runtime {
class EmbeddingCachePrefetchActor;
}  // namespace runtime

namespace distributed {
// The local host cache size defaults to 10 times the device cache size.
static constexpr size_t kHostCacheScaleFactor = 10;
// The maximum number of concurrent threads for data prefetching.
static constexpr size_t kMaxThreadNum = 16;
// The maximum number of feature ids processed per thread.
static constexpr size_t kMaxIdsPerThread = 10000;

using mindspore::kernel::Address;

// Records information such as the dimension, memory address, and cache size of an embedding table
// with the embedding cache enabled.
struct HashTableInfo {
  size_t cache_vocab_size{0};
  size_t host_cache_vocab_size{0};
  size_t embedding_size{0};
  size_t vocab_size{0};
  Address device_address{nullptr, 0};
  std::shared_ptr<float> host_address{nullptr};
  int32_t param_key_{-1};
};

// Records the hash mapping relationship of all embedding tables with cache enabled on the device side, and the
// ids that need to be exchanged with the local host cache. Note that the following information is the same for
// all embedding cache tables on the device side: the hash mapping, and the feature ids of the feature vectors
// that need to be swapped with the local host cache.
struct EmbeddingDeviceCache {
  EmbeddingDeviceCache(size_t batch_ids_num, size_t cache_vocab_size)
      : hash_swap_index_addr_(nullptr), hash_swap_value_addr_(nullptr) {
    device_to_host_index = std::make_unique<int[]>(batch_ids_num);
    device_to_host_ids = std::make_unique<int[]>(batch_ids_num);
    host_to_device_index = std::make_unique<int[]>(batch_ids_num);
    host_to_device_ids = std::make_unique<int[]>(batch_ids_num);
    device_hash_map_ = std::make_shared<EmbeddingHashMap>(0, cache_vocab_size);
  }

  std::unique_ptr<int[]> device_to_host_index;
  std::unique_ptr<int[]> device_to_host_ids;
  std::unique_ptr<int[]> host_to_device_index;
  std::unique_ptr<int[]> host_to_device_ids;
  int *hash_swap_index_addr_;
  float *hash_swap_value_addr_;
  std::shared_ptr<EmbeddingHashMap> device_hash_map_;
};

// Records the hash mapping relationship of all embedding tables with cache enabled on the local host side, and the
// information that needs to be exchanged with the remote cache and device cache. Note that the following information
// is the same for all embedding cache tables on the local host side: the hash mapping, and the feature ids of the
// feature vectors that need to be swapped with the remote cache and device cache.
struct EmbeddingHostCache {
  EmbeddingHostCache(size_t batch_ids_num, size_t host_cache_vocab_size) {
    host_to_server_index = std::make_unique<int[]>(batch_ids_num);
    host_to_server_ids = std::make_unique<int[]>(batch_ids_num);
    server_to_host_index = std::make_unique<int[]>(batch_ids_num);
    server_to_host_ids = std::make_unique<int[]>(batch_ids_num);
    host_to_device_index = std::make_unique<int[]>(batch_ids_num);
    device_to_host_index = std::make_unique<int[]>(batch_ids_num);
    host_hash_map_ = std::make_shared<EmbeddingHashMap>(0, host_cache_vocab_size);
  }

  std::unique_ptr<int[]> host_to_server_index;
  std::unique_ptr<int[]> host_to_server_ids;
  std::unique_ptr<int[]> server_to_host_index;
  std::unique_ptr<int[]> server_to_host_ids;
  std::unique_ptr<int[]> host_to_device_index;
  std::unique_ptr<int[]> device_to_host_index;
  std::shared_ptr<EmbeddingHashMap> host_hash_map_;
};

struct EmbeddingCacheStatisticsInfo {
  size_t batch_id_count_{0};
  size_t batch_id_unique_count_{0};
  size_t device_to_host_size_{0};
  size_t host_to_device_size_{0};
  size_t host_to_server_size_{0};
  size_t server_to_host_size_{0};
  size_t hash_hit_count_{0};
  size_t mem_cache_swap_out_size_{0};
  size_t mem_cache_swap_in_size_{0};
  size_t mem_cache_hit_count_{0};
};

// The EmbeddingCacheTableManager class saves all Parameter information needed to enable the embedding cache, such as
// the device cache size and host cache size, and can allocate memory for the embedding cache tables.
class EmbeddingCacheTableManager {
 public:
  static EmbeddingCacheTableManager &GetInstance();

  // Initialize the EmbeddingCacheTableManager.
  void Initialize();
  // Finalize the EmbeddingCacheTableManager and release all resources.
  void Finalize();

  // Insert and save the dimension information of an embedding cache table.
  void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
                           size_t vocab_size, int32_t param_key);

  // A Parameter may be renamed. After renaming, all the dimension information saved for the parameter needs to be
  // re-inserted under the new name.
  void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
                             size_t cache_vocab_size, size_t embedding_size);

  // Clone a hash table; for example, the optimizer's state parameters are generally cloned from the weight.
  void CloneHashTable(const std::string &dest_param_name, int32_t dest_param_key, const std::string &src_param_name,
                      int32_t src_param_key);

  // Allocate device memory for all embedding cache tables.
  void AllocMemForEmbeddingCacheTable(const device::DeviceContext *device_context);

  // Query the device address of an embedding cache table.
  const Address &QueryHashTableAddr(const std::string &param_name) const;

  // Query the device cache size of an embedding cache table.
  size_t QueryHashTableSize(const std::string &param_name) const;

  // Check whether a parameter is a cache-enabled embedding table.
  bool IsEmbeddingCacheTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }

  // Set the number of ids in one batch.
  void set_batch_ids_num(size_t batch_ids_num) { batch_ids_num_ = batch_ids_num; }

  // Get the offset of the id range corresponding to the embedding cache table slice on each worker in a multi-worker
  // automatic parallel scenario.
  int cache_indices_lower_bound() const;

  void DumpHashTables() const;

 private:
  EmbeddingCacheTableManager() = default;
  ~EmbeddingCacheTableManager() = default;
  DISABLE_COPY_AND_ASSIGN(EmbeddingCacheTableManager);

  // Get the embedding table slice bound info on each worker in a multi-worker automatic parallel scenario.
  void GetEmbeddingTableSliceBound();

  // The hash tables record information such as the dimension, memory address, and cache size of each embedding table
  // with the embedding cache enabled.
  std::map<std::string, HashTableInfo> hash_tables_;

  // Records the hash mapping relationship of all embedding tables with cache enabled on the device side, and the
  // ids that need to be exchanged with the local host cache.
  std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_;

  // Records the hash mapping relationship of all embedding tables with cache enabled on the local host side, and the
  // information that needs to be exchanged with the remote cache and device cache.
  std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;

  // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range
  // corresponding to the embedding table slice of this process.
  std::pair<int, int> local_embedding_slice_bounds_;

  // Model parallelism is used between multiple workers, and local_device_cache_bounds_ records the local device cache
  // range corresponding to the embedding table slice of this process.
  std::pair<int, int> local_device_cache_bounds_;

  // Row number of the full embedding table, not less than the total number of feature ids.
  size_t vocab_size_{0};
  // Embedding cache size (row number of the embedding cache) of the device cache.
  size_t device_cache_size_{0};
  // Embedding cache size (row number of the embedding cache) of the local host cache.
  size_t host_cache_size_{0};
  // Total number of ids in one batch.
  size_t batch_ids_num_{0};

  friend class mindspore::runtime::EmbeddingCachePrefetchActor;
};
}  // namespace distributed
static distributed::EmbeddingCacheTableManager &embedding_cache_table_manager =
  distributed::EmbeddingCacheTableManager::GetInstance();
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_CHCHE_UTILS_H_
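Reviewer note: as a quick aid when checking the size fields in HashTableInfo against the allocation in AllocMemForEmbeddingCacheTable, the byte counts implied by the sizes work out as below. The values are made up for illustration; only the formulas mirror the code.

// Illustrative size arithmetic (assumed example values, not from this patch).
size_t cache_vocab_size = 8000;                                               // rows cached on device
size_t embedding_size = 64;                                                   // floats per row
size_t host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor;      // 80000 rows on the local host
size_t device_bytes = cache_vocab_size * embedding_size * sizeof(float);      // 2 048 000 bytes on device
size_t host_bytes = host_cache_vocab_size * embedding_size * sizeof(float);   // 20 480 000 bytes on host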
@@ -0,0 +1,107 @@ distributed/embedding_cache/embedding_hash_map.cc
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "distributed/embedding_cache/embedding_hash_map.h"

namespace mindspore {
namespace distributed {
int EmbeddingHashMap::ParseData(const int id, int *const swap_out_index, int *const swap_out_ids,
                                const size_t data_step, const size_t graph_running_step,
                                size_t *const swap_out_size, bool *const need_wait_graph) {
  MS_EXCEPTION_IF_NULL(swap_out_index);
  MS_EXCEPTION_IF_NULL(swap_out_ids);
  MS_EXCEPTION_IF_NULL(swap_out_size);
  bool need_swap = false;
  auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph);
  if (hash_index == INVALID_INDEX_VALUE) {
    return hash_index;
  }

  if (!need_swap) {
    hash_count_++;
    (void)hash_id_to_index_.emplace(id, hash_index);
    hash_map_elements_[hash_index].set_id(id);
    hash_map_elements_[hash_index].set_step(data_step);
    return hash_index;
  }

  swap_out_index[*swap_out_size] = hash_index;
  swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_;
  (*swap_out_size)++;
  (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_);
  (void)hash_id_to_index_.emplace(id, hash_index);
  hash_map_elements_[hash_index].set_id(id);
  hash_map_elements_[hash_index].set_step(data_step);
  return hash_index;
}

int EmbeddingHashMap::FindInsertionPos(const size_t, const size_t graph_running_step, bool *const need_swap,
                                       bool *const need_wait_graph) {
  MS_EXCEPTION_IF_NULL(need_swap);
  MS_EXCEPTION_IF_NULL(need_wait_graph);
  int hash_index = INVALID_INDEX_VALUE;
  while (!expired_element_full_) {
    if (hash_map_elements_[current_pos_].IsEmpty()) {
      hash_index = current_pos_;
    } else if (hash_map_elements_[current_pos_].IsExpired(graph_running_step)) {
      hash_index = current_pos_;
      *need_swap = true;
    } else if (hash_map_elements_[current_pos_].StepEqual(graph_running_step)) {
      graph_running_index_[graph_running_index_num_++] = current_pos_;
    }
    current_pos_ = (current_pos_ + 1) % hash_capacity_;
    if (hash_index != INVALID_INDEX_VALUE) {
      return hash_index;
    }
    if (current_pos_ == current_batch_start_pos_) {
      expired_element_full_ = true;
      MS_LOG(INFO) << "Running step:" << graph_running_step << "(num:" << graph_running_index_num_
                   << ") will be used, index swap will wait until the graph completed.";
    }
  }

  if (graph_running_index_pos_ != graph_running_index_num_) {
    *need_swap = true;
    *need_wait_graph = true;
    return graph_running_index_[graph_running_index_pos_++];
  }
  return INVALID_INDEX_VALUE;
}

void EmbeddingHashMap::DumpHashMap() {
  MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_;
  MS_LOG(INFO) << "Dump hash_id_to_index: ";
  for (auto iter = hash_id_to_index_.begin(); iter != hash_id_to_index_.end(); ++iter) {
    MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second;
  }
  MS_LOG(INFO) << "Dump hash_map_unit: ";
  for (size_t i = 0; i < hash_map_elements_.size(); i++) {
    if (!hash_map_elements_[i].IsEmpty()) {
      MS_LOG(INFO) << " index: " << i << " id: " << hash_map_elements_[i].id_
                   << " step: " << hash_map_elements_[i].step_;
    }
  }
  MS_LOG(INFO) << "Dump hash map info end.";
}

void EmbeddingHashMap::Reset() {
  current_batch_start_pos_ = current_pos_;
  graph_running_index_num_ = 0;
  graph_running_index_pos_ = 0;
  expired_element_full_ = false;
}
}  // namespace distributed
}  // namespace mindspore
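Reviewer note: the sketch below shows how a caller such as the prefetch logic would typically consume ParseData's outputs. It is illustrative only; batch_ids, batch_ids_num, data_step, graph_running_step, and hash_map are assumed to exist in the caller.

// Illustrative caller sketch (not part of this patch).
std::vector<int> swap_out_index(batch_ids_num);
std::vector<int> swap_out_ids(batch_ids_num);
size_t swap_out_size = 0;
bool need_wait_graph = false;
for (int id : batch_ids) {
  int index = hash_map->ParseData(id, swap_out_index.data(), swap_out_ids.data(), data_step, graph_running_step,
                                  &swap_out_size, &need_wait_graph);
  if (index == mindspore::distributed::INVALID_INDEX_VALUE) {
    // Every slot is still in use by the running graph: retry after that graph step finishes.
    break;
  }
}
// swap_out_ids[0 .. swap_out_size) now lists the old ids whose rows must be written back to the next cache level
// before their slots (swap_out_index[0 .. swap_out_size)) are reused for the new ids.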
@@ -0,0 +1,129 @@ distributed/embedding_cache/embedding_hash_map.h
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_
#define MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_

#include <math.h>
#include <utility>
#include <memory>
#include <vector>
#include "utils/hash_map.h"
#include "utils/convert_utils_base.h"

namespace mindspore {
namespace distributed {
// Define the value of an invalid step.
static constexpr size_t INVALID_STEP_VALUE = 0;
// Define the value of an invalid index.
static constexpr int INVALID_INDEX_VALUE = -1;

struct HashMapElement {
  int id_{INVALID_INDEX_VALUE};
  // The current global step of the cache prefetching operation.
  size_t step_{INVALID_STEP_VALUE};

  bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; }
  bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; }
  bool StepEqual(size_t step) const { return step_ == step; }
  void set_id(int id) { id_ = id; }
  void set_step(size_t step) { step_ = step; }
};

// EmbeddingHashMap is used to manage the id -> index mapping of the embedding cache table on the host
// side. The cache content can be stored on the device or host side.
class EmbeddingHashMap {
 public:
  EmbeddingHashMap(size_t hash_count, size_t hash_capacity)
      : hash_count_(hash_count),
        hash_capacity_(hash_capacity),
        current_pos_(0),
        current_batch_start_pos_(0),
        graph_running_index_num_(0),
        graph_running_index_pos_(0),
        expired_element_full_(false) {
    hash_map_elements_.resize(hash_capacity);
    // In multi-device mode, embedding tables are distributed across devices by id interval, and ids outside the
    // range of the local device use the front and back positions of the table; these two positions are reserved
    // for that purpose.
    hash_map_elements_.front().set_step(SIZE_MAX);
    hash_map_elements_.back().set_step(SIZE_MAX);
    graph_running_index_ = std::make_unique<int[]>(hash_capacity);
  }

  ~EmbeddingHashMap() = default;

  // Find the insertion position (index) in the hash map for an id.
  // If the hash map capacity is insufficient, also return the ids and indices that need to be swapped out.
  int ParseData(const int id, int *const swap_out_index, int *const swap_out_ids, const size_t data_step,
                const size_t graph_running_step, size_t *const swap_out_size, bool *const need_wait_graph);

  // Get the global step of an element in the hash map.
  size_t hash_step(const int hash_index) const { return hash_map_elements_[IntToSize(hash_index)].step_; }
  // Set the global step of an element in the hash map.
  void set_hash_step(const int hash_index, const size_t step) {
    hash_map_elements_[IntToSize(hash_index)].set_step(step);
  }

  // Get the id -> index mapping.
  const mindspore::HashMap<int, int> &hash_id_to_index() const { return hash_id_to_index_; }

  // Get the capacity of the hash map.
  size_t hash_capacity() const { return hash_capacity_; }

  // Reset the hash map.
  void Reset();

  void DumpHashMap();

 private:
  // Find the insertion position (index) in the hash map for an id.
  int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *const need_swap,
                       bool *const need_wait_graph);

  // Statistics on the usage of the hash map capacity.
  size_t hash_count_;

  // The hash map capacity.
  size_t hash_capacity_;

  // Record all elements in this hash map.
  std::vector<HashMapElement> hash_map_elements_;

  // The id -> index mapping.
  mindspore::HashMap<int, int> hash_id_to_index_;

  // The cursor that records the current slot.
  size_t current_pos_;
  // The cursor that records the start position of the current batch (the value of current_pos_ when the batch began).
  size_t current_batch_start_pos_;

  // The number of ids that need to wait for the calculation graph to finish executing the current step before they
  // can be swapped out.
  size_t graph_running_index_num_;
  // The current position in the array 'graph_running_index_'; the value at this position is the hash index to use
  // for a new id, but only after the calculation graph finishes executing the current step and the expired data is
  // swapped out.
  size_t graph_running_index_pos_;
  // Records the indices of the feature ids that need to be swapped out after the calculation graph finishes
  // executing the current step.
  std::unique_ptr<int[]> graph_running_index_;

  // The flag that indicates the hash map is full (no empty or expired slot is available).
  bool expired_element_full_;
};
}  // namespace distributed
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_HASH_MAP_H_
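Reviewer note: a short sketch of how the map above is created and cycled per batch, for review convenience. The capacity value is made up, and the shared_ptr usage mirrors EmbeddingDeviceCache/EmbeddingHostCache rather than any code in this file.

// Illustrative construction sketch (not part of this patch).
auto hash_map = std::make_shared<mindspore::distributed::EmbeddingHashMap>(0, 8000);
// The constructor pre-marks the front and back slots with SIZE_MAX so they are never evicted; they are reserved
// for ids that fall outside the local device's id interval.
// One batch of ids is mapped through ParseData(...), then the scan window is restarted for the next batch:
hash_map->Reset();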