support ps cache data process thread exit
This commit is contained in:
parent
d99189f681
commit
01e9ca5922
|
@ -168,7 +168,7 @@ void PsCacheManager::AddEmbeddingTable() const {
|
|||
}
|
||||
|
||||
void PsCacheManager::InitParameterServer() {
|
||||
MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_;
|
||||
MS_LOG(INFO) << "PS embedding cache table init begin:" << finish_insert_init_info_;
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; });
|
||||
if (!running_) {
|
||||
|
@ -197,7 +197,20 @@ void PsCacheManager::InitParameterServer() {
|
|||
|
||||
finish_init_parameter_server_ = true;
|
||||
data_prase_.notify_one();
|
||||
MS_LOG(INFO) << "Embedding table init end.";
|
||||
MS_LOG(INFO) << "PS embedding cache table init end.";
|
||||
}
|
||||
|
||||
void PsCacheManager::InitDataChannel() {
|
||||
MS_LOG(INFO) << "PS embedding cache data channel init begin.";
|
||||
auto channel = channel_name();
|
||||
if (channel.empty()) {
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
|
||||
if (!running_) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "PS embedding cache data channel init end.";
|
||||
}
|
||||
|
||||
void PsCacheManager::AllocMemForHashTable() {
|
||||
|
@ -270,8 +283,8 @@ bool PsCacheManager::IncreaseStep() {
|
|||
}
|
||||
|
||||
void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
|
||||
if (terminated_) {
|
||||
MS_LOG(EXCEPTION) << "ps cache data process thread is terminated.";
|
||||
if (!running_) {
|
||||
MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
|
||||
}
|
||||
if (graph_step_ >= UINT64_MAX) {
|
||||
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
|
||||
|
@ -279,7 +292,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
|
|||
if (graph_step_ == 0) {
|
||||
MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_;
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; });
|
||||
data_prase_.wait(locker, [this] { return ((finish_init_parameter_server_ == true) || (running_ == false)); });
|
||||
if (!running_) {
|
||||
MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running.";
|
||||
}
|
||||
MS_LOG(INFO) << "Graph running waiting embedding table init end.";
|
||||
}
|
||||
graph_step_++;
|
||||
|
@ -300,25 +316,21 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) {
|
|||
}
|
||||
|
||||
void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {
|
||||
embedding_device_cache_->cache_->InitDevice(device_id, context);
|
||||
MS_LOG(INFO) << "PS embedding cache process data task begin.";
|
||||
running_ = true;
|
||||
bool ret = true;
|
||||
embedding_device_cache_->cache_->InitDevice(device_id, context);
|
||||
InitParameterServer();
|
||||
while (ret) {
|
||||
if (!running_) {
|
||||
break;
|
||||
InitDataChannel();
|
||||
while (running_) {
|
||||
if (!ProcessData()) {
|
||||
running_ = false;
|
||||
}
|
||||
ret = ProcessData();
|
||||
}
|
||||
if (!ret) {
|
||||
terminated_ = true;
|
||||
}
|
||||
MS_LOG(INFO) << "PS embedding cache process data task end.";
|
||||
}
|
||||
|
||||
void PsCacheManager::Finalize() {
|
||||
if (running_) {
|
||||
running_ = false;
|
||||
}
|
||||
running_ = false;
|
||||
PsDataPrefetch::GetInstance().NotifyFinalize();
|
||||
insert_init_info_.notify_all();
|
||||
data_prase_.notify_all();
|
||||
|
@ -331,14 +343,6 @@ bool PsCacheManager::ProcessData() {
|
|||
struct timeval start_time, end_time;
|
||||
const uint64_t kUSecondInSecond = 1000000;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
auto channel = channel_name();
|
||||
if (channel.empty()) {
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
|
||||
if (!running_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
auto data = PsDataPrefetch::GetInstance().data(channel_name_);
|
||||
if (data == nullptr) {
|
||||
MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
|
||||
|
@ -361,6 +365,7 @@ bool PsCacheManager::ProcessData() {
|
|||
}
|
||||
// Get hash swap in/out index and ids.
|
||||
RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get()));
|
||||
DumpStatisticsInfo();
|
||||
for (const auto &item : hash_tables_) {
|
||||
auto key = worker.GetParamKey(item.first);
|
||||
auto hash_info = item.second;
|
||||
|
@ -389,6 +394,7 @@ bool PsCacheManager::ProcessData() {
|
|||
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
|
||||
MS_ERROR_IF_NULL(batch_ids);
|
||||
MS_ERROR_IF_NULL(hash_index);
|
||||
statistics_info_.batch_id_count_ = batch_ids_len;
|
||||
for (size_t i = 0; i < batch_ids_len; i++) {
|
||||
bool need_swap_host_to_device = true;
|
||||
bool need_swap_device_to_host = true;
|
||||
|
@ -397,10 +403,8 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
|
|||
hash_index[i] = -1;
|
||||
continue;
|
||||
}
|
||||
auto index = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device);
|
||||
if (index == INVALID_INDEX_VALUE) {
|
||||
return false;
|
||||
}
|
||||
int index = INVALID_INDEX_VALUE;
|
||||
RETURN_IF_FALSE(ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device, &index));
|
||||
hash_index[i] = index;
|
||||
if (need_swap_host_to_device) {
|
||||
RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
|
||||
|
@ -409,12 +413,6 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
|
|||
RETURN_IF_FALSE(ParseHostDataDeviceToHost(id));
|
||||
}
|
||||
}
|
||||
// Each 1000 step prints ps cache hit rate.
|
||||
if (data_step_ % 1000 == 0) {
|
||||
statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_;
|
||||
auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_;
|
||||
MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -430,14 +428,16 @@ bool PsCacheManager::WaitGraphRun() {
|
|||
return true;
|
||||
}
|
||||
|
||||
int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) {
|
||||
int *device_to_host_index = embedding_device_cache_->device_to_host_index.get();
|
||||
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
|
||||
int *host_to_device_index = embedding_device_cache_->host_to_device_index.get();
|
||||
int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get();
|
||||
|
||||
bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device,
|
||||
int *hash_index) {
|
||||
MS_ERROR_IF_NULL(need_swap_device_to_host);
|
||||
MS_ERROR_IF_NULL(need_swap_host_to_device);
|
||||
MS_ERROR_IF_NULL(hash_index);
|
||||
MS_ERROR_IF_NULL(embedding_device_cache_);
|
||||
auto device_hash_map = embedding_device_cache_->device_hash_map_;
|
||||
int index = 0;
|
||||
MS_ERROR_IF_NULL(device_hash_map);
|
||||
|
||||
int index = INVALID_INDEX_VALUE;
|
||||
auto iter = device_hash_map->id_iter(id);
|
||||
if (device_hash_map->IsIdExist(iter)) {
|
||||
*need_swap_device_to_host = false;
|
||||
|
@ -448,13 +448,19 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b
|
|||
device_hash_map->set_hash_step(index, data_step_);
|
||||
}
|
||||
} else {
|
||||
int *device_to_host_index = embedding_device_cache_->device_to_host_index.get();
|
||||
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
|
||||
int *host_to_device_index = embedding_device_cache_->host_to_device_index.get();
|
||||
int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get();
|
||||
MS_ERROR_IF_NULL(host_to_device_index);
|
||||
MS_ERROR_IF_NULL(host_to_device_ids);
|
||||
auto tmp_device_to_host_size = statistics_info_.device_to_host_size_;
|
||||
while (true) {
|
||||
index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_,
|
||||
&(statistics_info_.device_to_host_size_));
|
||||
if (index == INVALID_INDEX_VALUE) {
|
||||
if (!WaitGraphRun()) {
|
||||
return INVALID_INDEX_VALUE;
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
@ -465,23 +471,17 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b
|
|||
break;
|
||||
}
|
||||
}
|
||||
return index;
|
||||
*hash_index = index;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
|
||||
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
||||
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
||||
int *server_to_host_index = embedding_host_cache_->server_to_host_index.get();
|
||||
int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
|
||||
MS_ERROR_IF_NULL(embedding_host_cache_);
|
||||
int *host_to_device_index = embedding_host_cache_->host_to_device_index.get();
|
||||
MS_ERROR_IF_NULL(host_to_server_index);
|
||||
MS_ERROR_IF_NULL(host_to_server_ids);
|
||||
MS_ERROR_IF_NULL(server_to_host_index);
|
||||
MS_ERROR_IF_NULL(server_to_host_ids);
|
||||
MS_ERROR_IF_NULL(host_to_device_index);
|
||||
|
||||
auto host_hash_map = embedding_host_cache_->host_hash_map_;
|
||||
MS_ERROR_IF_NULL(host_hash_map);
|
||||
|
||||
auto iter = host_hash_map->id_iter(id);
|
||||
if (host_hash_map->IsIdExist(iter)) {
|
||||
auto index = iter->second;
|
||||
|
@ -490,6 +490,12 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
|
|||
}
|
||||
host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
|
||||
} else {
|
||||
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
||||
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
||||
int *server_to_host_index = embedding_host_cache_->server_to_host_index.get();
|
||||
int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
|
||||
MS_ERROR_IF_NULL(server_to_host_index);
|
||||
MS_ERROR_IF_NULL(server_to_host_ids);
|
||||
while (true) {
|
||||
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
|
||||
graph_running_step_, &statistics_info_.host_to_server_size_);
|
||||
|
@ -507,13 +513,10 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
|
|||
}
|
||||
|
||||
bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
|
||||
MS_ERROR_IF_NULL(embedding_device_cache_);
|
||||
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
|
||||
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
||||
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
||||
int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
|
||||
MS_ERROR_IF_NULL(device_to_host_ids);
|
||||
MS_ERROR_IF_NULL(host_to_server_index);
|
||||
MS_ERROR_IF_NULL(host_to_server_ids);
|
||||
MS_ERROR_IF_NULL(device_to_host_index);
|
||||
|
||||
auto host_hash_map = embedding_host_cache_->host_hash_map_;
|
||||
|
@ -527,6 +530,8 @@ bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
|
|||
}
|
||||
device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
|
||||
} else {
|
||||
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
||||
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
||||
while (true) {
|
||||
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
|
||||
graph_running_step_, &statistics_info_.host_to_server_size_);
|
||||
|
@ -552,13 +557,13 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size,
|
|||
auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "LookUpTable task memcpy failed.";
|
||||
terminated_ = true;
|
||||
running_ = false;
|
||||
}
|
||||
} else {
|
||||
auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "LookUpTable task memset failed.";
|
||||
terminated_ = true;
|
||||
running_ = false;
|
||||
}
|
||||
}
|
||||
output_addr += outer_dim_size;
|
||||
|
@ -592,7 +597,7 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
|
|||
for (size_t j = 0; j < i; j++) {
|
||||
threads[j].join();
|
||||
}
|
||||
return !terminated_;
|
||||
return running_;
|
||||
}
|
||||
|
||||
bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
|
||||
|
@ -615,7 +620,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
|
|||
auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "Insert hash table task memcpy failed.";
|
||||
terminated_ = true;
|
||||
running_ = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -637,7 +642,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
|
|||
for (size_t j = 0; j < i; j++) {
|
||||
threads[j].join();
|
||||
}
|
||||
return !terminated_;
|
||||
return running_;
|
||||
}
|
||||
|
||||
bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
|
||||
|
@ -862,5 +867,25 @@ void PsCacheManager::DumpHashTables(bool dump_device_tables) const {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PsCacheManager::DumpStatisticsInfo(size_t each_print_step) {
|
||||
// Default each 1000 step prints ps cache hit rate.
|
||||
if (data_step_ % each_print_step == 0) {
|
||||
statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_;
|
||||
auto repeat_rate = SizeToFloat(statistics_info_.batch_id_count_ - statistics_info_.batch_id_unique_count_) /
|
||||
statistics_info_.batch_id_count_;
|
||||
auto device_hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_;
|
||||
auto host_hit_rate = SizeToFloat(statistics_info_.batch_id_unique_count_ - statistics_info_.server_to_host_size_) /
|
||||
statistics_info_.batch_id_unique_count_;
|
||||
MS_LOG(INFO) << "PS embedding cache data statistics info(total id num:" << statistics_info_.batch_id_count_
|
||||
<< ", unique id num:" << statistics_info_.batch_id_unique_count_
|
||||
<< ", host swap to device num:" << statistics_info_.host_to_device_size_
|
||||
<< ", device swap to host num:" << statistics_info_.device_to_host_size_
|
||||
<< ", host swap to server num:" << statistics_info_.host_to_server_size_
|
||||
<< ", server swap to host num:" << statistics_info_.server_to_host_size_
|
||||
<< ", data repeat rate:" << repeat_rate * 100 << "%, device cache hit rate:" << device_hit_rate * 100
|
||||
<< "%, host cache hit rate:" << host_hit_rate * 100 << ").";
|
||||
}
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -94,6 +94,7 @@ struct EmbeddingHostCache {
|
|||
};
|
||||
|
||||
struct PsCacheStatisticsInfo {
|
||||
size_t batch_id_count_{0};
|
||||
size_t batch_id_unique_count_{0};
|
||||
size_t device_to_host_size_{0};
|
||||
size_t host_to_device_size_{0};
|
||||
|
@ -126,7 +127,6 @@ class PsCacheManager {
|
|||
bool initialized_ps_cache() const { return initialized_ps_cache_; }
|
||||
void DoProcessData(uint32_t device_id, void *context);
|
||||
void IncreaseGraphStep(const std::string &channel_name);
|
||||
bool terminated() const { return terminated_; }
|
||||
void Finalize();
|
||||
void DumpHashTables(bool dump_device_tables = false) const;
|
||||
|
||||
|
@ -140,13 +140,14 @@ class PsCacheManager {
|
|||
std::string channel_name();
|
||||
void set_channel_name(const std::string channel_name);
|
||||
void InitParameterServer();
|
||||
void InitDataChannel();
|
||||
void AllocMemForHashTable();
|
||||
void SetLocalIdRank();
|
||||
void ProcessDataTask(uint32_t device_id, void *context);
|
||||
bool ProcessData();
|
||||
bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
|
||||
bool WaitGraphRun();
|
||||
int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device);
|
||||
bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index);
|
||||
bool ParseHostDataHostToDevice(size_t id);
|
||||
bool ParseHostDataDeviceToHost(size_t id);
|
||||
bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
|
||||
|
@ -164,6 +165,7 @@ class PsCacheManager {
|
|||
const int *indices_addr, float *output_addr);
|
||||
bool CheckFinishInsertInitInfo() const;
|
||||
void AddEmbeddingTable() const;
|
||||
void DumpStatisticsInfo(size_t each_print_step = 1000);
|
||||
|
||||
bool initialized_ps_cache_{false};
|
||||
std::string channel_name_;
|
||||
|
@ -189,7 +191,6 @@ class PsCacheManager {
|
|||
std::atomic_bool finish_insert_init_info_{false};
|
||||
std::atomic_bool finish_init_parameter_server_{false};
|
||||
std::atomic_bool running_{false};
|
||||
std::atomic_bool terminated_{false};
|
||||
};
|
||||
|
||||
static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();
|
||||
|
|
Loading…
Reference in New Issue