forked from OSSInnovation/mindspore
graceful shutdown in ps mode
This commit is contained in:
parent
7be664fa85
commit
241e980f06
|
@ -70,6 +70,7 @@ class ParameterServer {
|
|||
handler_(nullptr),
|
||||
func_graph_(nullptr),
|
||||
sess_(nullptr),
|
||||
running_(true),
|
||||
thread_(nullptr) {}
|
||||
~ParameterServer() = default;
|
||||
ParameterServer(const ParameterServer &) = delete;
|
||||
|
@ -106,6 +107,7 @@ class ParameterServer {
|
|||
void InitGrad(const Key &key, const GradPtr &grad);
|
||||
void InitEmbeddingTable(const Key &key,
|
||||
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
|
||||
void Finalize();
|
||||
void UpdateWeights();
|
||||
void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
|
||||
WeightPtr weight(const Key &key);
|
||||
|
@ -123,6 +125,7 @@ class ParameterServer {
|
|||
std::unique_ptr<ServerHandler> handler_;
|
||||
FuncGraphPtr func_graph_;
|
||||
std::shared_ptr<session::SessionBasic> sess_;
|
||||
bool running_;
|
||||
|
||||
std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
|
||||
std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
|
||||
|
@ -261,7 +264,7 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta
|
|||
template <typename T>
|
||||
void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
|
||||
::ps::KVPairs<T> *res) {
|
||||
::ps::Finalize(0, false);
|
||||
ps_->Finalize();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -381,11 +384,20 @@ void ParameterServer<T>::InitEmbeddingTable(
|
|||
grads_accum_counter_[key] = 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ParameterServer<T>::Finalize() {
|
||||
running_ = false;
|
||||
apply_grads_cv_.notify_one();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ParameterServer<T>::UpdateWeights() {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); });
|
||||
apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
|
||||
if (!running_) {
|
||||
break;
|
||||
}
|
||||
|
||||
for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
|
||||
Key key = iter->first;
|
||||
|
@ -550,6 +562,8 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
Init(func_graph);
|
||||
thread_->join();
|
||||
::ps::Finalize(0, true);
|
||||
exit(1);
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace parallel
|
||||
|
|
|
@ -23,9 +23,8 @@ namespace parallel {
|
|||
namespace ps {
|
||||
void Scheduler::Run() {
|
||||
::ps::Start(0);
|
||||
while (true) {
|
||||
sleep(1);
|
||||
}
|
||||
::ps::Finalize(0, true);
|
||||
exit(1);
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace parallel
|
||||
|
|
|
@ -54,7 +54,7 @@ class Worker {
|
|||
|
||||
private:
|
||||
Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {}
|
||||
~Worker() { ::ps::Finalize(0, true); }
|
||||
~Worker() = default;
|
||||
Worker(const Worker &) = delete;
|
||||
Worker &operator=(const Worker &) = delete;
|
||||
|
||||
|
@ -81,7 +81,6 @@ void Worker<T>::Run() {
|
|||
MS_LOG(INFO) << "'Worker is already running.";
|
||||
return;
|
||||
}
|
||||
|
||||
::ps::Start(0);
|
||||
if (!::ps::IsWorker()) {
|
||||
MS_LOG(EXCEPTION) << "The role is not worker.";
|
||||
|
@ -121,7 +120,11 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const :
|
|||
|
||||
template <typename T>
|
||||
void Worker<T>::Finalize() {
|
||||
kv_worker_->Finalize();
|
||||
if (running_) {
|
||||
kv_worker_->Finalize();
|
||||
kv_worker_.reset();
|
||||
running_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -155,7 +155,7 @@ void WorkerProxy<T>::Finalize() {
|
|||
kvs.vals.push_back(0.0f);
|
||||
Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
|
||||
obj_->WaitRequest(ts);
|
||||
::ps::Finalize(0, false);
|
||||
::ps::Finalize(0, true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -45,6 +45,7 @@
|
|||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "frontend/parallel/ps/common.h"
|
||||
#include "frontend/parallel/ps/util.h"
|
||||
#include "frontend/parallel/ps/worker.h"
|
||||
#endif
|
||||
|
||||
#if (ENABLE_GE || ENABLE_D)
|
||||
|
@ -949,7 +950,13 @@ void ClearResAtexit() {
|
|||
pynative::ClearPyNativeSession();
|
||||
session::ClearPythonParasMap();
|
||||
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
if (mindspore::parallel::ps::Util::IsParamServerMode()) {
|
||||
if (parallel::ps::Util::IsRoleOfWorker()) {
|
||||
parallel::ps::Worker<float>::GetInstance().Finalize();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
ad::g_k_prims.clear();
|
||||
|
||||
abstract::ClearPrimEvaluatorMap();
|
||||
|
|
Loading…
Reference in New Issue