support ps cache data process thread exit

This commit is contained in:
limingqi107 2020-12-19 10:55:46 +08:00
parent d99189f681
commit 01e9ca5922
2 changed files with 91 additions and 65 deletions
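
In outline, the change drops the separate terminated_ flag, routes every blocking wait through the shared running_ flag, and has Finalize() notify all condition variables so the data-processing thread can wake up and exit. A minimal sketch of that cooperative-exit pattern follows; the class and member names are simplified stand-ins for illustration, not the actual PsCacheManager interface.

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <string>

// Illustrative sketch only (simplified stand-in names, not the MindSpore classes):
// a single atomic running_ flag guards every wait, and Finalize() wakes all
// waiters so the data-processing thread can observe the flag and return.
class CacheWorker {
 public:
  void ProcessDataTask() {
    running_ = true;
    {
      // Wait for the data channel to be set, but also wake up on shutdown.
      std::unique_lock<std::mutex> locker(mutex_);
      cv_.wait(locker, [this] { return !channel_.empty() || !running_; });
      if (!running_) {
        return;  // Exit was requested before the channel was ever initialized.
      }
    }
    while (running_) {
      if (!ProcessOneBatch()) {
        running_ = false;  // A processing failure also terminates the loop.
      }
    }
  }

  void SetChannel(const std::string &name) {
    std::lock_guard<std::mutex> locker(mutex_);
    channel_ = name;
    cv_.notify_one();
  }

  void Finalize() {
    {
      std::lock_guard<std::mutex> locker(mutex_);
      running_ = false;  // Cleared under the lock so a concurrent wait cannot miss it.
    }
    cv_.notify_all();
  }

 private:
  bool ProcessOneBatch() { return false; }  // Stand-in for the real per-batch work.

  std::atomic_bool running_{false};
  std::mutex mutex_;
  std::condition_variable cv_;
  std::string channel_;
};

The real Finalize() notifies insert_init_info_ and data_prase_ for the same reason: each wait predicate now also checks running_, so shutdown wakes every blocked site instead of leaving the thread stuck.
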

View File

@@ -168,7 +168,7 @@ void PsCacheManager::AddEmbeddingTable() const {
}
void PsCacheManager::InitParameterServer() {
MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_;
MS_LOG(INFO) << "PS embedding cache table init begin:" << finish_insert_init_info_;
std::unique_lock<std::mutex> locker(data_mutex_);
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; });
if (!running_) {
@@ -197,7 +197,20 @@ void PsCacheManager::InitParameterServer() {
finish_init_parameter_server_ = true;
data_prase_.notify_one();
MS_LOG(INFO) << "Embedding table init end.";
MS_LOG(INFO) << "PS embedding cache table init end.";
}
void PsCacheManager::InitDataChannel() {
MS_LOG(INFO) << "PS embedding cache data channel init begin.";
auto channel = channel_name();
if (channel.empty()) {
std::unique_lock<std::mutex> locker(data_mutex_);
data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
if (!running_) {
return;
}
}
MS_LOG(INFO) << "PS embedding cache data channel init end.";
}
void PsCacheManager::AllocMemForHashTable() {
@@ -270,8 +283,8 @@ bool PsCacheManager::IncreaseStep() {
}
void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
if (terminated_) {
MS_LOG(EXCEPTION) << "ps cache data process thread is terminated.";
if (!running_) {
MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
}
if (graph_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
@@ -279,7 +292,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
if (graph_step_ == 0) {
MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_;
std::unique_lock<std::mutex> locker(data_mutex_);
data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; });
data_prase_.wait(locker, [this] { return ((finish_init_parameter_server_ == true) || (running_ == false)); });
if (!running_) {
MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
}
MS_LOG(INFO) << "Graph running waiting embedding table init end.";
}
graph_step_++;
@@ -300,25 +316,21 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) {
}
void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {
embedding_device_cache_->cache_->InitDevice(device_id, context);
MS_LOG(INFO) << "PS embedding cache process data task begin.";
running_ = true;
bool ret = true;
embedding_device_cache_->cache_->InitDevice(device_id, context);
InitParameterServer();
while (ret) {
if (!running_) {
break;
InitDataChannel();
while (running_) {
if (!ProcessData()) {
running_ = false;
}
ret = ProcessData();
}
if (!ret) {
terminated_ = true;
}
MS_LOG(INFO) << "PS embedding cache process data task end.";
}
void PsCacheManager::Finalize() {
if (running_) {
running_ = false;
}
running_ = false;
PsDataPrefetch::GetInstance().NotifyFinalize();
insert_init_info_.notify_all();
data_prase_.notify_all();
@@ -331,14 +343,6 @@ bool PsCacheManager::ProcessData() {
struct timeval start_time, end_time;
const uint64_t kUSecondInSecond = 1000000;
(void)gettimeofday(&start_time, nullptr);
auto channel = channel_name();
if (channel.empty()) {
std::unique_lock<std::mutex> locker(data_mutex_);
data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
if (!running_) {
return false;
}
}
auto data = PsDataPrefetch::GetInstance().data(channel_name_);
if (data == nullptr) {
MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
@@ -361,6 +365,7 @@ bool PsCacheManager::ProcessData() {
}
// Get hash swap in/out index and ids.
RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get()));
DumpStatisticsInfo();
for (const auto &item : hash_tables_) {
auto key = worker.GetParamKey(item.first);
auto hash_info = item.second;
@@ -389,6 +394,7 @@ bool PsCacheManager::ProcessData() {
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
MS_ERROR_IF_NULL(batch_ids);
MS_ERROR_IF_NULL(hash_index);
statistics_info_.batch_id_count_ = batch_ids_len;
for (size_t i = 0; i < batch_ids_len; i++) {
bool need_swap_host_to_device = true;
bool need_swap_device_to_host = true;
@@ -397,10 +403,8 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
hash_index[i] = -1;
continue;
}
auto index = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device);
if (index == INVALID_INDEX_VALUE) {
return false;
}
int index = INVALID_INDEX_VALUE;
RETURN_IF_FALSE(ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device, &index));
hash_index[i] = index;
if (need_swap_host_to_device) {
RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
@@ -409,12 +413,6 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
RETURN_IF_FALSE(ParseHostDataDeviceToHost(id));
}
}
// Each 1000 step prints ps cache hit rate.
if (data_step_ % 1000 == 0) {
statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_;
auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_;
MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%.";
}
return true;
}
@@ -430,14 +428,16 @@ bool PsCacheManager::WaitGraphRun() {
return true;
}
int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) {
int *device_to_host_index = embedding_device_cache_->device_to_host_index.get();
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
int *host_to_device_index = embedding_device_cache_->host_to_device_index.get();
int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get();
bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device,
int *hash_index) {
MS_ERROR_IF_NULL(need_swap_device_to_host);
MS_ERROR_IF_NULL(need_swap_host_to_device);
MS_ERROR_IF_NULL(hash_index);
MS_ERROR_IF_NULL(embedding_device_cache_);
auto device_hash_map = embedding_device_cache_->device_hash_map_;
int index = 0;
MS_ERROR_IF_NULL(device_hash_map);
int index = INVALID_INDEX_VALUE;
auto iter = device_hash_map->id_iter(id);
if (device_hash_map->IsIdExist(iter)) {
*need_swap_device_to_host = false;
@@ -448,13 +448,19 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b
device_hash_map->set_hash_step(index, data_step_);
}
} else {
int *device_to_host_index = embedding_device_cache_->device_to_host_index.get();
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
int *host_to_device_index = embedding_device_cache_->host_to_device_index.get();
int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get();
MS_ERROR_IF_NULL(host_to_device_index);
MS_ERROR_IF_NULL(host_to_device_ids);
auto tmp_device_to_host_size = statistics_info_.device_to_host_size_;
while (true) {
index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_,
&(statistics_info_.device_to_host_size_));
if (index == INVALID_INDEX_VALUE) {
if (!WaitGraphRun()) {
return INVALID_INDEX_VALUE;
return false;
}
continue;
}
@@ -465,23 +471,17 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b
break;
}
}
return index;
*hash_index = index;
return true;
}
bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
int *server_to_host_index = embedding_host_cache_->server_to_host_index.get();
int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
MS_ERROR_IF_NULL(embedding_host_cache_);
int *host_to_device_index = embedding_host_cache_->host_to_device_index.get();
MS_ERROR_IF_NULL(host_to_server_index);
MS_ERROR_IF_NULL(host_to_server_ids);
MS_ERROR_IF_NULL(server_to_host_index);
MS_ERROR_IF_NULL(server_to_host_ids);
MS_ERROR_IF_NULL(host_to_device_index);
auto host_hash_map = embedding_host_cache_->host_hash_map_;
MS_ERROR_IF_NULL(host_hash_map);
auto iter = host_hash_map->id_iter(id);
if (host_hash_map->IsIdExist(iter)) {
auto index = iter->second;
@@ -490,6 +490,12 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
}
host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
} else {
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
int *server_to_host_index = embedding_host_cache_->server_to_host_index.get();
int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
MS_ERROR_IF_NULL(server_to_host_index);
MS_ERROR_IF_NULL(server_to_host_ids);
while (true) {
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
graph_running_step_, &statistics_info_.host_to_server_size_);
@@ -507,13 +513,10 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
}
bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
MS_ERROR_IF_NULL(embedding_device_cache_);
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
MS_ERROR_IF_NULL(device_to_host_ids);
MS_ERROR_IF_NULL(host_to_server_index);
MS_ERROR_IF_NULL(host_to_server_ids);
MS_ERROR_IF_NULL(device_to_host_index);
auto host_hash_map = embedding_host_cache_->host_hash_map_;
@@ -527,6 +530,8 @@ bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
}
device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
} else {
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
while (true) {
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
graph_running_step_, &statistics_info_.host_to_server_size_);
@@ -552,13 +557,13 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size,
auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens);
if (ret != EOK) {
MS_LOG(ERROR) << "LookUpTable task memcpy failed.";
terminated_ = true;
running_ = false;
}
} else {
auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens);
if (ret != EOK) {
MS_LOG(ERROR) << "LookUpTable task memset failed.";
terminated_ = true;
running_ = false;
}
}
output_addr += outer_dim_size;
@@ -592,7 +597,7 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
for (size_t j = 0; j < i; j++) {
threads[j].join();
}
return !terminated_;
return running_;
}
bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
@@ -615,7 +620,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens);
if (ret != EOK) {
MS_LOG(ERROR) << "Insert hash table task memcpy failed.";
terminated_ = true;
running_ = false;
}
}
}
@@ -637,7 +642,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
for (size_t j = 0; j < i; j++) {
threads[j].join();
}
return !terminated_;
return running_;
}
bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
@@ -862,5 +867,25 @@ void PsCacheManager::DumpHashTables(bool dump_device_tables) const {
}
}
}
void PsCacheManager::DumpStatisticsInfo(size_t each_print_step) {
// Default each 1000 step prints ps cache hit rate.
if (data_step_ % each_print_step == 0) {
statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_;
auto repeat_rate = SizeToFloat(statistics_info_.batch_id_count_ - statistics_info_.batch_id_unique_count_) /
statistics_info_.batch_id_count_;
auto device_hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_;
auto host_hit_rate = SizeToFloat(statistics_info_.batch_id_unique_count_ - statistics_info_.server_to_host_size_) /
statistics_info_.batch_id_unique_count_;
MS_LOG(INFO) << "PS embedding cache data statistics info(total id num:" << statistics_info_.batch_id_count_
<< ", unique id num:" << statistics_info_.batch_id_unique_count_
<< ", host swap to device num:" << statistics_info_.host_to_device_size_
<< ", device swap to host num:" << statistics_info_.device_to_host_size_
<< ", host swap to server num:" << statistics_info_.host_to_server_size_
<< ", server swap to host num:" << statistics_info_.server_to_host_size_
<< ", data repeat rate:" << repeat_rate * 100 << "%, device cache hit rate:" << device_hit_rate * 100
<< "%, host cache hit rate:" << host_hit_rate * 100 << ").";
}
}
} // namespace ps
} // namespace mindspore
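
For reference, the rates logged by the new DumpStatisticsInfo() reduce to simple ratios of the counters collected during ParseData(). A standalone worked example is shown below; the counter values are made up purely for illustration.

#include <iostream>

// Worked example of the statistics reported by DumpStatisticsInfo();
// counter values are invented purely for illustration.
int main() {
  const float batch_id_count = 1000.0f;      // total ids in the batch
  const float hash_hit_count = 700.0f;       // ids already resident in the device cache
  const float host_to_device_size = 100.0f;  // ids swapped in from the host cache
  const float server_to_host_size = 20.0f;   // ids fetched from the parameter server

  const float unique = hash_hit_count + host_to_device_size;             // 800
  const float repeat_rate = (batch_id_count - unique) / batch_id_count;  // 0.200
  const float device_hit_rate = hash_hit_count / unique;                 // 0.875
  const float host_hit_rate = (unique - server_to_host_size) / unique;   // 0.975

  std::cout << "data repeat rate: " << repeat_rate * 100 << "%\n"
            << "device cache hit rate: " << device_hit_rate * 100 << "%\n"
            << "host cache hit rate: " << host_hit_rate * 100 << "%\n";
  return 0;
}
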

View File

@@ -94,6 +94,7 @@ struct EmbeddingHostCache {
};
struct PsCacheStatisticsInfo {
size_t batch_id_count_{0};
size_t batch_id_unique_count_{0};
size_t device_to_host_size_{0};
size_t host_to_device_size_{0};
@@ -126,7 +127,6 @@ class PsCacheManager {
bool initialized_ps_cache() const { return initialized_ps_cache_; }
void DoProcessData(uint32_t device_id, void *context);
void IncreaseGraphStep(const std::string &channel_name);
bool terminated() const { return terminated_; }
void Finalize();
void DumpHashTables(bool dump_device_tables = false) const;
@@ -140,13 +140,14 @@ class PsCacheManager {
std::string channel_name();
void set_channel_name(const std::string channel_name);
void InitParameterServer();
void InitDataChannel();
void AllocMemForHashTable();
void SetLocalIdRank();
void ProcessDataTask(uint32_t device_id, void *context);
bool ProcessData();
bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
bool WaitGraphRun();
int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device);
bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index);
bool ParseHostDataHostToDevice(size_t id);
bool ParseHostDataDeviceToHost(size_t id);
bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
@@ -164,6 +165,7 @@ class PsCacheManager {
const int *indices_addr, float *output_addr);
bool CheckFinishInsertInitInfo() const;
void AddEmbeddingTable() const;
void DumpStatisticsInfo(size_t each_print_step = 1000);
bool initialized_ps_cache_{false};
std::string channel_name_;
@@ -189,7 +191,6 @@ class PsCacheManager {
std::atomic_bool finish_insert_init_info_{false};
std::atomic_bool finish_init_parameter_server_{false};
std::atomic_bool running_{false};
std::atomic_bool terminated_{false};
};
static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();