[bugfix] server core dump after training

lizhenyu 2020-12-28 18:10:37 +08:00
parent deb4f7d46f
commit 7eb49cfce7
4 changed files with 11 additions and 6 deletions

View File

@@ -413,7 +413,7 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
       RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
     }
     if (need_swap_device_to_host) {
-      RETURN_IF_FALSE(ParseHostDataDeviceToHost(id));
+      RETURN_IF_FALSE(ParseHostDataDeviceToHost());
     }
   }
   return true;
@@ -515,7 +515,7 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
   return true;
 }
 
-bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
+bool PsCacheManager::ParseHostDataDeviceToHost() {
   MS_ERROR_IF_NULL(embedding_device_cache_);
   int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
   int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
@@ -536,8 +536,8 @@ bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
   int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
   int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
   while (true) {
-    auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
-                                          graph_running_step_, &statistics_info_.host_to_server_size_);
+    auto index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids,
+                                          data_step_, graph_running_step_, &statistics_info_.host_to_server_size_);
     if (index == INVALID_INDEX_VALUE) {
       RETURN_IF_FALSE(WaitGraphRun());
       continue;
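Reading of the hunks above: host_hash_map->ParseData is now keyed by swap_device_to_host_id, the id actually being evicted from the device cache, rather than by the batch-level id that ParseHostDataDeviceToHost used to receive, and the now-unused parameter is dropped. Keying the host map with the wrong id makes it bookkeep the wrong occupant for a host slot, so a wrong embedding row would later be flushed to the server, a plausible trigger for the reported core dump. Below is a minimal sketch of that idea in plain Python with invented names (HostHashMap, parse_data); it is a simplification for illustration, not the MindSpore implementation.

# Illustrative only: a toy find-or-assign host hash map showing why the lookup key
# matters. Names are hypothetical; the real logic is the C++ PsCacheManager above.
INVALID_INDEX_VALUE = -1

class HostHashMap:
    """Toy stand-in for host_hash_map->ParseData: embedding id -> host-cache slot."""
    def __init__(self, capacity):
        self.slots = {}          # id -> slot index
        self.capacity = capacity

    def parse_data(self, lookup_id):
        if lookup_id in self.slots:
            return self.slots[lookup_id]
        if len(self.slots) >= self.capacity:
            return INVALID_INDEX_VALUE   # caller must wait for a slot to free up
        slot = len(self.slots)
        self.slots[lookup_id] = slot     # record which id occupies this slot
        return slot

host_map = HostHashMap(capacity=8)
batch_id = 7                   # the id carried by the removed parameter
swap_device_to_host_id = 42    # the id actually evicted from the device cache

host_map.parse_data(batch_id)                 # old behaviour: slot bookkept for id 7
host_map.parse_data(swap_device_to_host_id)   # fixed behaviour: slot bookkept for id 42
print(host_map.slots)                         # {7: 0, 42: 1}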

View File

@@ -150,7 +150,7 @@ class PsCacheManager {
   bool WaitGraphRun();
   bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index);
   bool ParseHostDataHostToDevice(size_t id);
-  bool ParseHostDataDeviceToHost(size_t id);
+  bool ParseHostDataDeviceToHost();
   bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
   bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
   bool HashSwapHostToDevice(const HashTableInfo &hash_info);

View File

@@ -24,7 +24,7 @@ namespace mindspore {
 namespace device {
 void KernelRuntimeManager::ClearRuntimeResource() {
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
+  if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
     ps::ps_cache_instance.SyncEmbeddingTable();
   }
 #endif
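The added role check restricts the embedding-table sync to worker processes. A parameter-server process also clears its runtime resources at shutdown, but it never initializes the worker-side embedding cache, so running SyncEmbeddingTable there is a plausible source of the post-training crash. A toy Python sketch of the guard pattern follows; the names (EmbeddingCache, clear_runtime_resource, is_worker) are invented for illustration and are not MindSpore's API.

# Toy sketch of the role guard: only the worker initializes the embedding cache,
# so only the worker may sync it during runtime-resource cleanup.
class EmbeddingCache:
    def __init__(self, initialized=False):
        self.table = {0: [0.0, 0.0]} if initialized else None

    def sync_embedding_table(self):
        # Touching self.table when it was never initialized mirrors the kind of
        # invalid access that crashes a native process.
        return sum(len(row) for row in self.table.values())

def clear_runtime_resource(is_worker, cache_enable, cache):
    # Mirrors the fixed condition: role check first, then cache_enable.
    if is_worker and cache_enable:
        cache.sync_embedding_table()

clear_runtime_resource(is_worker=True, cache_enable=True, cache=EmbeddingCache(True))    # syncs
clear_runtime_resource(is_worker=False, cache_enable=True, cache=EmbeddingCache(False))  # now skipped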

View File

@@ -16,6 +16,7 @@
 import threading
 import mindspore.context as context
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
+from mindspore.parallel._ps_context import _is_role_pserver
 from mindspore._c_expression import AutoParallelContext
 from mindspore._checkparam import args_type_check
@@ -180,6 +181,8 @@ class _AutoParallelContext:
     def get_parallel_mode(self):
         """Get parallel mode."""
         self.check_context_handle()
+        if _is_role_pserver():
+            return context.ParallelMode.STAND_ALONE
         return self._context_handle.get_parallel_mode()
 
     def set_strategy_search_mode(self, auto_parallel_search_mode):
@@ -242,6 +245,8 @@ class _AutoParallelContext:
     def get_full_batch(self):
         """Get whether load full batch on each device."""
         self.check_context_handle()
+        if _is_role_pserver():
+            return False
         return self._context_handle.get_full_batch()
 
     def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
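On the Python side, the two getters now short-circuit when the process is a parameter server: a pserver does not take part in data-parallel execution, so it reports STAND_ALONE and full_batch=False instead of consulting the native context handle. Below is a stand-alone sketch of that pattern with the handle stubbed out; _StubHandle and AutoParallelContextSketch are invented names, and the real handle comes from mindspore._c_expression.

# Sketch of the short-circuit added to _AutoParallelContext, with a stubbed handle.
def _is_role_pserver():
    # Assume this process was launched in the parameter-server role.
    return True

class _StubHandle:
    def get_parallel_mode(self):
        return "data_parallel"
    def get_full_batch(self):
        return True

class AutoParallelContextSketch:
    STAND_ALONE = "stand_alone"
    def __init__(self, handle):
        self._context_handle = handle
    def get_parallel_mode(self):
        if _is_role_pserver():
            return self.STAND_ALONE
        return self._context_handle.get_parallel_mode()
    def get_full_batch(self):
        if _is_role_pserver():
            return False
        return self._context_handle.get_full_batch()

ctx = AutoParallelContextSketch(_StubHandle())
print(ctx.get_parallel_mode())  # stand_alone on a parameter server
print(ctx.get_full_batch())     # False on a parameter server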