fix process can not exit normaly when exception has been thrown

This commit is contained in:
lizhenyu 2022-02-24 17:37:05 +08:00
parent bbcfbce9e0
commit f378fd3302
3 changed files with 35 additions and 15 deletions

View File

@ -360,6 +360,10 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, const void *context) {
}
void PsCacheManager::Finalize() {
if (finalized_) {
return;
}
SyncEmbeddingTable();
running_ = false;
@ -369,6 +373,8 @@ void PsCacheManager::Finalize() {
if (process_data_thread_.joinable()) {
process_data_thread_.join();
}
finalized_ = true;
}
bool PsCacheManager::ProcessData() {

View File

@ -207,6 +207,7 @@ class PsCacheManager {
std::atomic_bool finish_insert_init_info_{false};
std::atomic_bool finish_init_parameter_server_{false};
std::atomic_bool running_{false};
std::atomic_bool finalized_{false};
bool finish_embedding_table_sync_{false};
bool device_need_wait_graph_{false};
bool host_need_wait_graph_{false};

View File

@ -1685,6 +1685,14 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
}
#if ((defined ENABLE_CPU) && (!defined _WIN32))
namespace {
// Finalize ps cache module before throw an exception.
void FinalizePsCache(const std::string &exception) {
ps::ps_cache_instance.Finalize();
MS_LOG(EXCEPTION) << exception;
}
} // namespace
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph &graph,
AnfNodePtr *const first_cache_input_index,
size_t *const first_cache_size) {
@ -1711,8 +1719,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph &graph,
AnfAlgo::IsGraphKernel(input_index.first) ? AnfAlgo::GetOutputOfGraphkernel(input_index) : input_index.first;
MS_EXCEPTION_IF_NULL(cnode);
if (!cnode->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index should be a CNode but got "
<< cnode->fullname_with_scope();
FinalizePsCache("The embeddingLookup whose input index should be a CNode but got " +
cnode->fullname_with_scope());
}
auto input_index_node_name = AnfAlgo::GetCNodeName(cnode);
if (input_index_node_name != kGetNextOpName) {
@ -1721,8 +1729,9 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph &graph,
(full_batch && (input_index_node_name != kMinimumOpName))) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
<< ") cache is from " << cnode->fullname_with_scope();
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
"parameter server training mode.";
FinalizePsCache(
"The embeddingLookup whose input index isn't from dataset doesn't support cache in parameter server training "
"mode.");
}
}
*first_cache_input_index = cnode;
@ -1742,7 +1751,7 @@ void KernelRuntime::CheckSparsePSEmbeddingCache(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(pre_node.first);
}
if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
MS_LOG(EXCEPTION) << "The input_indices of kernel[SparseGatherV2] must be unique in parameter server cache mode";
FinalizePsCache("The input_indices of kernel[SparseGatherV2] must be unique in parameter server cache mode");
}
pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
@ -1752,8 +1761,9 @@ void KernelRuntime::CheckSparsePSEmbeddingCache(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(pre_node.first);
}
if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kGetNextOpName)) {
MS_LOG(EXCEPTION) << "The input indices of kernel[Unique] must be produced from dataset directly and the indices "
"value can not be changed before delivering to kernel[Unique] in parameter server cache mode.";
FinalizePsCache(
"The input indices of kernel[Unique] must be produced from dataset directly and the indices value can not be "
"changed before delivering to kernel[Unique] in parameter server cache mode.");
}
}
@ -1789,25 +1799,28 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph &g
if (cnode == first_cache_input_index) {
if (!ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the "
"same time when one of them enables cache in parameter server training mode.";
FinalizePsCache(
"All the embeddingLookups whose input indices are from dataset must enable cache at the same time when one "
"of them enables cache in parameter server training mode.");
}
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
if (size != first_cache_size) {
MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope()
<< ") is not the same as other embeddingLookup cache size(" << first_cache_size << ").";
MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode.";
FinalizePsCache("The cache sizes of embeddingLookups are not the same in parameter server training mode.");
}
} else if (ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
<< cnode->fullname_with_scope();
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
"parameter server training mode.";
FinalizePsCache(
"The embeddingLookup whose input index isn't from dataset doesn't support cache in parameter server training "
"mode.");
} else if (cnode->isa<CNode>() && (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName)) {
MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
MS_LOG(EXCEPTION) << "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at "
"the same time and parameter 'sparse' must be equal to the value of 'enable_sparse' in "
"context setting in parameter server training mode.";
FinalizePsCache(
"All EmbeddingLookup kernels whose input indices are from dataset must enable cache at "
"the same time and parameter 'sparse' must be equal to the value of 'enable_sparse' in "
"context setting in parameter server training mode.");
}
}
}