forked from mindspore-Ecosystem/mindspore
[bugfix] server core dump after training
This commit is contained in:
parent
deb4f7d46f
commit
7eb49cfce7
|
@ -413,7 +413,7 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
|
|||
RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
|
||||
}
|
||||
if (need_swap_device_to_host) {
|
||||
RETURN_IF_FALSE(ParseHostDataDeviceToHost(id));
|
||||
RETURN_IF_FALSE(ParseHostDataDeviceToHost());
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
@ -515,7 +515,7 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
|
||||
bool PsCacheManager::ParseHostDataDeviceToHost() {
|
||||
MS_ERROR_IF_NULL(embedding_device_cache_);
|
||||
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
|
||||
int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
|
||||
|
@ -536,8 +536,8 @@ bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
|
|||
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
|
||||
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
|
||||
while (true) {
|
||||
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
|
||||
graph_running_step_, &statistics_info_.host_to_server_size_);
|
||||
auto index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids,
|
||||
data_step_, graph_running_step_, &statistics_info_.host_to_server_size_);
|
||||
if (index == INVALID_INDEX_VALUE) {
|
||||
RETURN_IF_FALSE(WaitGraphRun());
|
||||
continue;
|
||||
|
|
|
@ -150,7 +150,7 @@ class PsCacheManager {
|
|||
bool WaitGraphRun();
|
||||
bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index);
|
||||
bool ParseHostDataHostToDevice(size_t id);
|
||||
bool ParseHostDataDeviceToHost(size_t id);
|
||||
bool ParseHostDataDeviceToHost();
|
||||
bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
|
||||
bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
|
||||
bool HashSwapHostToDevice(const HashTableInfo &hash_info);
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace mindspore {
|
|||
namespace device {
|
||||
void KernelRuntimeManager::ClearRuntimeResource() {
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
ps::ps_cache_instance.SyncEmbeddingTable();
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import threading
|
||||
import mindspore.context as context
|
||||
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
|
||||
from mindspore.parallel._ps_context import _is_role_pserver
|
||||
from mindspore._c_expression import AutoParallelContext
|
||||
from mindspore._checkparam import args_type_check
|
||||
|
||||
|
@ -180,6 +181,8 @@ class _AutoParallelContext:
|
|||
def get_parallel_mode(self):
|
||||
"""Get parallel mode."""
|
||||
self.check_context_handle()
|
||||
if _is_role_pserver():
|
||||
return context.ParallelMode.STAND_ALONE
|
||||
return self._context_handle.get_parallel_mode()
|
||||
|
||||
def set_strategy_search_mode(self, auto_parallel_search_mode):
|
||||
|
@ -242,6 +245,8 @@ class _AutoParallelContext:
|
|||
def get_full_batch(self):
|
||||
"""Get whether load full batch on each device."""
|
||||
self.check_context_handle()
|
||||
if _is_role_pserver():
|
||||
return False
|
||||
return self._context_handle.get_full_batch()
|
||||
|
||||
def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
|
||||
|
|
Loading…
Reference in New Issue