diff --git a/mindspore/ccsrc/ps/core/communicator/communicator_base.cc b/mindspore/ccsrc/ps/core/communicator/communicator_base.cc index 7ce0a5332aa..94b0b8d3e1b 100644 --- a/mindspore/ccsrc/ps/core/communicator/communicator_base.cc +++ b/mindspore/ccsrc/ps/core/communicator/communicator_base.cc @@ -20,6 +20,11 @@ namespace mindspore { namespace ps { namespace core { +CommunicatorBase::~CommunicatorBase() { + running_ = false; + Join(); +} + bool CommunicatorBase::SendResponse(const void *rsp_data, size_t rsp_len, std::shared_ptr msg_handler) { // The rsp_len could be 0 because of ProtoBuffer's feature. if (rsp_data == nullptr || msg_handler == nullptr) { diff --git a/mindspore/ccsrc/ps/core/communicator/communicator_base.h b/mindspore/ccsrc/ps/core/communicator/communicator_base.h index 0ca32f542da..15733fc993e 100644 --- a/mindspore/ccsrc/ps/core/communicator/communicator_base.h +++ b/mindspore/ccsrc/ps/core/communicator/communicator_base.h @@ -42,9 +42,9 @@ class CommunicatorBase { using OnNodeEventCallback = std::function; using TcpMsgCallback = std::function conn, std::shared_ptr meta, DataPtr data, size_t size)>; - CommunicatorBase() = default; + CommunicatorBase() : running_(false) {} - virtual ~CommunicatorBase() = default; + virtual ~CommunicatorBase(); virtual bool Start() = 0; virtual bool Stop() = 0; @@ -59,6 +59,7 @@ class CommunicatorBase { protected: std::unordered_map msg_callbacks_; std::thread running_thread_; + bool running_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/communicator/http_communicator.cc b/mindspore/ccsrc/ps/core/communicator/http_communicator.cc index cdc8556b00f..fbcc7282ff9 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_communicator.cc +++ b/mindspore/ccsrc/ps/core/communicator/http_communicator.cc @@ -31,10 +31,8 @@ bool HttpCommunicator::Start() { MS_LOG(INFO) << "Http communicator started."; running_thread_ = std::thread([&]() { - try { - http_server_->Wait(); - } catch (const std::exception &e) { - MsException::Instance().SetException(); + while (running_) { + std::this_thread::yield(); } }); return true; @@ -42,7 +40,9 @@ bool HttpCommunicator::Start() { bool HttpCommunicator::Stop() { MS_EXCEPTION_IF_NULL(http_server_); - return http_server_->Stop(); + bool res = http_server_->Stop(); + running_ = false; + return res; } void HttpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) { diff --git a/mindspore/ccsrc/ps/core/communicator/http_server.cc b/mindspore/ccsrc/ps/core/communicator/http_server.cc index 60a78335f32..57a7d39efda 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_server.cc +++ b/mindspore/ccsrc/ps/core/communicator/http_server.cc @@ -115,13 +115,17 @@ bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *functio return true; } -bool HttpServer::Start() { +bool HttpServer::Start(bool is_detach) { MS_LOG(INFO) << "Start http server!"; for (size_t i = 0; i < thread_num_; i++) { auto http_request_handler = std::make_shared(); http_request_handler->Initialize(fd_, request_handlers_); http_request_handlers.push_back(http_request_handler); - worker_threads_.emplace_back(std::make_shared(&HttpRequestHandler::Run, http_request_handler)); + auto thread = std::make_shared(&HttpRequestHandler::Run, http_request_handler); + if (is_detach) { + thread->detach(); + } + worker_threads_.emplace_back(thread); } return true; } diff --git a/mindspore/ccsrc/ps/core/communicator/http_server.h b/mindspore/ccsrc/ps/core/communicator/http_server.h index 62736822ed1..e55c00320d6 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_server.h +++ b/mindspore/ccsrc/ps/core/communicator/http_server.h @@ -64,7 +64,7 @@ class HttpServer { // Return: true if success, false if failed, check log to find failure reason bool RegisterRoute(const std::string &url, OnRequestReceive *func); - bool Start(); + bool Start(bool is_detach = true); bool Wait(); bool Stop(); diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h index befbcdf26f0..a0675ff9a29 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h @@ -65,7 +65,6 @@ class TcpCommunicator : public CommunicatorBase { public: explicit TcpCommunicator(const std::shared_ptr &task_executor, ServerNode *node) : task_executor_(task_executor), - running_(false), server_num_(0), worker_num_(0), scheduler_ip_(""), @@ -109,7 +108,6 @@ class TcpCommunicator : public CommunicatorBase { private: std::shared_ptr task_executor_; - bool running_; TcpMsgCallback tcp_msg_callback_; OnNodeEventCallback event_callback_; diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index d5fb6f54ccb..690fb982ca3 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -429,6 +429,7 @@ void SchedulerNode::ProcessScaleOut(std::shared_ptr resp) { nlohmann::json js; js["message"] = "Cluster begin to scale out."; resp->AddRespString(js.dump()); + resp->AddRespHeadParam("Content_Type", "application/json"); resp->SetRespCode(HTTP_OK); resp->SendResponse(); @@ -507,6 +508,7 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr resp) { nlohmann::json js; js["message"] = "Cluster begin to scale in."; resp->AddRespString(js.dump()); + resp->AddRespHeadParam("Content_Type", "application/json"); resp->SetRespCode(HTTP_OK); resp->SendResponse(); @@ -543,6 +545,7 @@ void SchedulerNode::ProcessGetNodesInfo(std::shared_ptr resp } resp->AddRespString(js.dump()); + resp->AddRespHeadParam("Content_Type", "application/json"); resp->SetRespCode(HTTP_OK); resp->SendResponse(); @@ -562,6 +565,7 @@ void SchedulerNode::ProcessGetClusterState(std::shared_ptr r js["cluster_state"] = CommUtil::ClusterStateToString(cluster_state); resp->AddRespString(js.dump()); + resp->AddRespHeadParam("Content_Type", "application/json"); resp->SetRespCode(HTTP_OK); resp->SendResponse(); @@ -601,6 +605,13 @@ RequestProcessResult SchedulerNode::CheckIfNodeIdLegal(const std::vectorInitServer(); - http_server_->Start(); + http_server_->Start(false); restful_thread_ = std::make_unique([&]() { http_server_->Wait(); }); } diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index fc176d53bf5..09fbc0854a7 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -77,6 +77,7 @@ void ServerNode::CreateTcpServer() { MS_LOG(INFO) << "The server node start a tcp server!"; this->server_->Start(); }); + server_thread_->detach(); } void ServerNode::Initialize() { @@ -158,20 +159,13 @@ bool ServerNode::Stop() { if (!is_already_stopped_.load()) { is_already_stopped_ = true; is_finish_ = true; - if (heart_beat_thread_->joinable()) { - heart_beat_thread_->join(); - } client_to_scheduler_->Stop(); if (!connected_nodes_.empty()) { for (auto &connected_node : connected_nodes_) { connected_node.second->Stop(); } } - if (client_to_scheduler_thread_->joinable()) { - client_to_scheduler_thread_->join(); - } server_->Stop(); - server_thread_->join(); } return true; } diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index 073421d6feb..80446f2a0af 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -68,6 +68,7 @@ void WorkerNode::CreateTcpServer() { MS_LOG(INFO) << "The worker node start a tcp server!"; server_->Start(); }); + server_thread_->detach(); } bool WorkerNode::Stop() { @@ -82,7 +83,6 @@ bool WorkerNode::Stop() { } } server_->Stop(); - server_thread_->join(); is_already_stopped_ = true; } return true;