!9700 add ps cache manager

From: @limingqi107
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-09 23:47:12 +08:00 committed by Gitee
commit e8a442eeb3
4 changed files with 407 additions and 0 deletions

View File

@ -18,15 +18,18 @@ endif ()
if (NOT ENABLE_D)
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
endif()
if (NOT ENABLE_GPU)
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
endif()
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
add_subdirectory(ps_cache)
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})

View File

@ -0,0 +1,217 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "ps/ps_cache/ps_cache_manager.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
using mindspore::kernel::Address;
namespace mindspore {
namespace ps {
void PsCacheManager::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
size_t vocab_size) {
if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) {
MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
}
hash_tables_[param_name].cache_vocab_size = cache_vocab_size;
hash_tables_[param_name].host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor;
hash_tables_[param_name].embedding_size = embedding_size;
hash_tables_[param_name].vocab_size = vocab_size;
if (vocab_size_ == 0) {
vocab_size_ = vocab_size;
}
if (cache_vocab_size_ == 0) {
cache_vocab_size_ = cache_vocab_size;
}
if (host_cache_vocab_size_ == 0) {
host_cache_vocab_size_ = cache_vocab_size * kHostCacheScaleFactor;
}
}
void PsCacheManager::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
size_t cache_vocab_size, size_t embedding_size) {
if (cache_vocab_size == 0 || embedding_size == 0) {
MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
}
if (new_param_name.empty() || cur_param_name.empty()) {
MS_LOG(EXCEPTION) << "Parameter name can not be empty.";
}
if (new_param_name == cur_param_name) {
return;
}
auto iter = hash_tables_.find(cur_param_name);
if (iter != hash_tables_.end()) {
hash_tables_.emplace(new_param_name, iter->second);
hash_tables_.erase(iter);
} else {
hash_tables_[new_param_name].cache_vocab_size = cache_vocab_size;
hash_tables_[new_param_name].embedding_size = embedding_size;
}
}
void PsCacheManager::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) {
auto iter = hash_tables_.find(param_name);
if (iter == hash_tables_.end()) {
MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
}
auto &hash_table_info = iter->second;
hash_table_info.param_init_info_.param_type_ = kWeight;
hash_table_info.param_init_info_.global_seed_ = global_seed;
hash_table_info.param_init_info_.op_seed_ = op_seed;
}
void PsCacheManager::InsertAccumuInitInfo(const std::string &param_name, float init_val) {
auto iter = hash_tables_.find(param_name);
if (iter == hash_tables_.end()) {
MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
}
auto &hash_table_info = iter->second;
hash_table_info.param_init_info_.param_type_ = kAccumulation;
hash_table_info.param_init_info_.init_val_ = init_val;
}
void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) {
if (dest_param_name == src_param_name) {
MS_LOG(INFO) << "The dest_param_name is same as src_param_name";
return;
}
auto iter = hash_tables_.find(src_param_name);
if (iter == hash_tables_.end()) {
MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed.";
}
hash_tables_.emplace(dest_param_name, iter->second);
}
const Address &PsCacheManager::QueryHashTableAddr(const std::string &param_name) const {
auto iter = hash_tables_.find(param_name);
if (iter == hash_tables_.end()) {
MS_LOG(EXCEPTION) << "Can not find device_address of " << param_name;
}
return iter->second.device_address;
}
void PsCacheManager::Initialize() {
MS_LOG(INFO) << "PS cache initialize.";
if (!worker.running()) {
Util::SetInternalEnvVar();
worker.Run();
}
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_);
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_);
InitParameterServer();
AllocMemForHashTable();
SetLocalIdRank();
initialized_ps_cache_ = true;
}
void PsCacheManager::InitParameterServer() {
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t key = worker.SetParamKey(param_name);
size_t row_count = item.second.vocab_size;
std::vector<size_t> keys{key, key, key, key};
std::vector<float> values{
SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1};
std::vector<int64_t> lens{2, 2, 3};
const auto &hash_table_info = item.second;
const auto &param_init_info = hash_table_info.param_init_info_;
if (param_init_info.param_type_ == kWeight) {
lens.push_back(0);
values.push_back(SizeToFloat(param_init_info.global_seed_));
values.push_back(SizeToFloat(param_init_info.op_seed_));
} else if (param_init_info.param_type_ == kAccumulation) {
lens.push_back(1);
values.push_back(param_init_info.init_val_);
}
// if worker role
worker.AddEmbeddingTable(key, row_count);
worker.InitPSEmbeddingTable(keys, values, lens);
}
}
void PsCacheManager::AllocMemForHashTable() {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
size_t max_embedding_size = 0;
for (auto &item : hash_tables_) {
size_t embedding_size = item.second.embedding_size;
auto &device_address = item.second.device_address;
device_address.size = cache_vocab_size_ * embedding_size * sizeof(float);
auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size);
MS_EXCEPTION_IF_NULL(addr);
device_address.addr = addr;
auto &host_address = item.second.host_address;
auto host_address_ptr = new int[host_cache_vocab_size_ * embedding_size];
MS_EXCEPTION_IF_NULL(host_address_ptr);
host_address = std::shared_ptr<int[]>(host_address_ptr, std::default_delete<int[]>());
MS_EXCEPTION_IF_NULL(host_address);
max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size;
}
embedding_device_cache_->hash_swap_index_addr_ =
reinterpret_cast<int *>(embedding_device_cache_->cache_->MallocMemory(batch_elements_ * sizeof(int)));
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_index_addr_);
embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_);
}
void PsCacheManager::SetLocalIdRank() {
auto worker_num = ::ps::NumWorkers();
auto worker_id = ::ps::MyRank();
auto local_shard_size = FloatToSize(std::ceil(SizeToFloat(vocab_size_) / worker_num));
range_bound_.first = local_shard_size * worker_id;
range_bound_.second = std::min(range_bound_.first + local_shard_size, vocab_size_);
MS_LOG(INFO) << "Worker num:" << worker_num << ", worker id:" << worker_id << ", rank id begin:" << range_bound_.first
<< ", rank id end:" << range_bound_.second;
}
std::string PsCacheManager::channel_name() {
std::lock_guard<std::mutex> locker(channel_mutex_);
return channel_name_;
}
void PsCacheManager::set_channel_name(const std::string channel_name) {
if (channel_name_ == channel_name) {
return;
}
std::lock_guard<std::mutex> locker(channel_mutex_);
channel_name_ = channel_name;
}
void PsCacheManager::IncreaseStep() {
if (data_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t.";
}
data_step_++;
set_current_graph_step();
}
void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
if (graph_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t.";
}
graph_step_++;
set_channel_name(channel_name);
PsDataPrefetch::GetInstance().TryWakeChannel(channel_name);
data_prase_.notify_one();
}
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,186 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_
#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_
#include <map>
#include <string>
#include <vector>
#include <thread>
#include <atomic>
#include <utility>
#include <memory>
#include <condition_variable>
#include "utils/ms_context.h"
#include "backend/kernel_compiler/kernel.h"
#include "utils/shape_utils.h"
#include "ir/tensor.h"
#include "ps/ps.h"
#include "ps/common.h"
#include "ps/worker.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/embedding_hash_map.h"
#include "ps/ps_cache/ps_cache_factory.h"
namespace mindspore {
namespace ps {
constexpr size_t kHostCacheScaleFactor = 10;
constexpr size_t kMaxThreadNum = 16;
using mindspore::kernel::Address;
struct HashTableInfo {
size_t cache_vocab_size{0};
size_t host_cache_vocab_size{0};
size_t embedding_size{0};
size_t vocab_size{0};
Address device_address{nullptr, 0};
std::shared_ptr<int[]> host_address{nullptr};
};
struct EmbeddingDeviceCache {
EmbeddingDeviceCache(size_t batch_elements, size_t cache_vocab_size) {
device_to_host_index = std::make_unique<int[]>(batch_elements);
device_to_host_ids = std::make_unique<int[]>(batch_elements);
host_to_device_index = std::make_unique<int[]>(batch_elements);
host_to_device_ids = std::make_unique<int[]>(batch_elements);
device_hash_map_ = std::make_shared<EmbeddingHashMap>(0, cache_vocab_size);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto devcie_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
cache_ = PsCacheFactory::Get().ps_cache(devcie_target);
}
std::unique_ptr<int[]> device_to_host_index;
std::unique_ptr<int[]> device_to_host_ids;
std::unique_ptr<int[]> host_to_device_index;
std::unique_ptr<int[]> host_to_device_ids;
int *hash_swap_index_addr_;
float *hash_swap_value_addr_;
std::shared_ptr<EmbeddingHashMap> device_hash_map_;
std::shared_ptr<PsCacheBasic> cache_;
};
struct EmbeddingHostCache {
EmbeddingHostCache(size_t batch_elements, size_t host_cache_vocab_size) {
host_to_server_index = std::make_unique<int[]>(batch_elements);
host_to_server_ids = std::make_unique<int[]>(batch_elements);
server_to_host_index = std::make_unique<int[]>(batch_elements);
server_to_host_ids = std::make_unique<int[]>(batch_elements);
host_to_device_index = std::make_unique<int[]>(batch_elements);
device_to_host_index = std::make_unique<int[]>(batch_elements);
host_hash_map_ = std::make_shared<EmbeddingHashMap>(0, host_cache_vocab_size);
}
std::unique_ptr<int[]> host_to_server_index;
std::unique_ptr<int[]> host_to_server_ids;
std::unique_ptr<int[]> server_to_host_index;
std::unique_ptr<int[]> server_to_host_ids;
std::unique_ptr<int[]> host_to_device_index;
std::unique_ptr<int[]> device_to_host_index;
std::shared_ptr<EmbeddingHashMap> host_hash_map_;
};
struct PsCacheStatisticsInfo {
size_t batch_id_unique_count_{0};
size_t device_to_host_size_{0};
size_t host_to_device_size_{0};
size_t host_to_server_size_{0};
size_t server_to_host_size_{0};
size_t hash_hit_count_{0};
size_t mem_cache_swap_out_size_{0};
size_t mem_cache_swap_in_size_{0};
size_t mem_cache_hit_count_{0};
};
class PsCacheManager {
public:
static PsCacheManager &GetInstance() {
static PsCacheManager instance;
return instance;
}
void Initialize();
void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
size_t vocab_size);
void InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed);
void InsertAccumuInitInfo(const std::string &param_name, float init_val);
void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
size_t cache_vocab_size, size_t embedding_size);
void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name);
const Address &QueryHashTableAddr(const std::string &param_name) const;
bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
bool initialized_ps_cache() const { return initialized_ps_cache_; }
void DoProcessData(uint32_t device_id, void *context);
void IncreaseGraphStep(const std::string &channel_name);
void DumpHashTables() const;
private:
PsCacheManager() = default;
~PsCacheManager() = default;
PsCacheManager(const PsCacheManager &) = delete;
PsCacheManager &operator=(const PsCacheManager &) = delete;
void IncreaseStep();
void set_current_graph_step() { graph_running_step_ = graph_step_; }
std::string channel_name();
void set_channel_name(const std::string channel_name);
void InitParameterServer();
void AllocMemForHashTable();
void SetLocalIdRank();
void ProcessDataTask(uint32_t device_id, void *context);
void ProcessData();
void ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
void WaitGraphRun();
int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device);
void ParseHostDataHostToDevice(size_t id);
void ParseHostDataDeviceToHost(size_t id);
void HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
void HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
void HashSwapHostToDevice(const HashTableInfo &hash_info);
void HashSwapDeviceToHost(const HashTableInfo &hash_info);
void HashSwapHostToServer(size_t key, const HashTableInfo &hash_info);
void HashSwapServerToHost(size_t key, const HashTableInfo &hash_info);
void InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data,
float *hash_table_addr);
void LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
const int *indices_addr, float *output_addr);
void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
const int *indices_addr, float *output_addr);
bool initialized_ps_cache_{false};
std::string channel_name_;
std::mutex channel_mutex_;
std::atomic_ulong graph_step_{0};
size_t graph_running_step_{0};
size_t data_step_{0};
std::mutex data_mutex_;
std::condition_variable data_prase_;
std::map<std::string, HashTableInfo> hash_tables_;
std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_;
std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;
size_t vocab_size_{0};
size_t cache_vocab_size_{0};
size_t host_cache_vocab_size_{0};
size_t batch_elements_{0};
PsCacheStatisticsInfo statistics_info_;
std::pair<size_t, size_t> range_bound_;
};
static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_

View File

@ -141,6 +141,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")