diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h index 2a32904ab05..429cf59c3c1 100755 --- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -70,6 +70,7 @@ class ParameterServer { handler_(nullptr), func_graph_(nullptr), sess_(nullptr), + running_(true), thread_(nullptr) {} ~ParameterServer() = default; ParameterServer(const ParameterServer &) = delete; @@ -106,6 +107,7 @@ class ParameterServer { void InitGrad(const Key &key, const GradPtr &grad); void InitEmbeddingTable(const Key &key, const std::shared_ptr>>> &shapes); + void Finalize(); void UpdateWeights(); void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); WeightPtr weight(const Key &key); @@ -123,6 +125,7 @@ class ParameterServer { std::unique_ptr handler_; FuncGraphPtr func_graph_; std::shared_ptr sess_; + bool running_; std::unordered_map> optimizers_; std::unordered_map optim_inputs_shape_; @@ -261,7 +264,7 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta template void ParameterServer::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { - ::ps::Finalize(0, false); + ps_->Finalize(); } template @@ -381,11 +384,20 @@ void ParameterServer::InitEmbeddingTable( grads_accum_counter_[key] = 0; } +template +void ParameterServer::Finalize() { + running_ = false; + apply_grads_cv_.notify_one(); +} + template void ParameterServer::UpdateWeights() { while (true) { std::unique_lock lock(mutex_); - apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); }); + apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; }); + if (!running_) { + break; + } for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { Key key = iter->first; @@ -550,6 +562,8 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { } Init(func_graph); thread_->join(); + ::ps::Finalize(0, true); + exit(1); } } // namespace ps } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc index 274b7259b09..04c259487fa 100755 --- a/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc @@ -23,9 +23,8 @@ namespace parallel { namespace ps { void Scheduler::Run() { ::ps::Start(0); - while (true) { - sleep(1); - } + ::ps::Finalize(0, true); + exit(1); } } // namespace ps } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h index abdc046ffe0..d7f0bb6df52 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h @@ -54,7 +54,7 @@ class Worker { private: Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {} - ~Worker() { ::ps::Finalize(0, true); } + ~Worker() = default; Worker(const Worker &) = delete; Worker &operator=(const Worker &) = delete; @@ -81,7 +81,6 @@ void Worker::Run() { MS_LOG(INFO) << "'Worker is already running."; return; } - ::ps::Start(0); if (!::ps::IsWorker()) { MS_LOG(EXCEPTION) << "The role is not worker."; @@ -121,7 +120,11 @@ void Worker::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const : template void Worker::Finalize() { - kv_worker_->Finalize(); + if (running_) { + kv_worker_->Finalize(); + kv_worker_.reset(); + running_ = false; + } } template diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h index c8bd27c067a..6ac7f6322da 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h @@ -155,7 +155,7 @@ void WorkerProxy::Finalize() { kvs.vals.push_back(0.0f); Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_); obj_->WaitRequest(ts); - ::ps::Finalize(0, false); + ::ps::Finalize(0, true); } template diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index b1c3f2db036..dee864d085e 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -45,6 +45,7 @@ #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #include "frontend/parallel/ps/common.h" #include "frontend/parallel/ps/util.h" +#include "frontend/parallel/ps/worker.h" #endif #if (ENABLE_GE || ENABLE_D) @@ -949,7 +950,13 @@ void ClearResAtexit() { pynative::ClearPyNativeSession(); session::ClearPythonParasMap(); device::KernelRuntimeManager::Instance().ClearRuntimeResource(); - +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if (mindspore::parallel::ps::Util::IsParamServerMode()) { + if (parallel::ps::Util::IsRoleOfWorker()) { + parallel::ps::Worker::GetInstance().Finalize(); + } + } +#endif ad::g_k_prims.clear(); abstract::ClearPrimEvaluatorMap();