add node recovery

This commit is contained in:
chendongsheng 2021-06-06 22:34:30 +08:00
parent 26c7d274c9
commit bfaab72934
28 changed files with 540 additions and 53 deletions

View File

@ -362,7 +362,9 @@ PYBIND11_MODULE(_c_expression, m) {
"Set federated learning client learning rate.")
.def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port,
"Set scheduler manage port used to scale out/in.")
.def("set_enable_ssl", &PSContext::enable_ssl, "Set PS SSL mode enabled or disabled.");
.def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.")
.def("set_config_file_path", &PSContext::set_config_file_path,
"Set configuration files required by the communication layer.");
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
.def(py::init())

View File

@ -34,6 +34,9 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _PS_SRC_FILES "core/leader_scaler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/follower_scaler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/file_configuration.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/recovery_base.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/node_recovery.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_recovery.cc")
endif()
if(NOT ENABLE_D)

View File

@ -95,6 +95,19 @@ constexpr uint32_t kCheckRegisteredRetryCount = 30;
// The timeout interval for judging whether all nodes are successfully registered.
constexpr uint32_t kCheckRegisteredIntervalInMs = 1000;
// The type of persistent storage, currently only supports file storage.
constexpr char kStoreType[] = "storge_type";
// The file used to storage metadata.
constexpr char kStoreFilePath[] = "storge_file_path";
// 1 indicates that the persistent storage type is file.
constexpr char kFileStorage[] = "1";
// The recovery key of json_config.
constexpr char kKeyRecovery[] = "recovery";
constexpr char kRecoveryWorkerNum[] = "worker_num";
constexpr char kRecoveryServerNum[] = "server_num";
constexpr char kRecoverySchedulerIp[] = "scheduler_ip";
constexpr char kRecoverySchedulerPort[] = "scheduler_port";
using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = uint64_t;

View File

@ -15,6 +15,7 @@
*/
#include "ps/core/abstract_node.h"
#include "ps/core/node_recovery.h"
namespace mindspore {
namespace ps {
@ -54,8 +55,7 @@ void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const
// scheduler is alive
UpdateSchedulerTime();
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_
<< " registered scheduler success!";
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << " registered scheduler success!";
}
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
@ -143,8 +143,9 @@ void AbstractNode::BroadcastEvent(const uint32_t &event) {
if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF, event_message.SerializeAsString().data(),
event_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send event timeout!";
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send event timeout!";
return;
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@ -377,6 +378,18 @@ int32_t AbstractNode::worker_num() const { return worker_num_; }
int32_t AbstractNode::server_num() const { return server_num_; }
void AbstractNode::set_worker_num(const int32_t &worker_num) { worker_num_ = worker_num; }
void AbstractNode::set_server_num(const int32_t &server_num) { server_num_ = server_num; }
std::string AbstractNode::scheduler_ip() const { return scheduler_ip_; }
void AbstractNode::set_scheduler_ip(const std::string &scheduler_ip) { scheduler_ip_ = scheduler_ip; }
uint16_t AbstractNode::scheduler_port() const { return scheduler_port_; }
void AbstractNode::set_scheduler_port(const uint16_t &scheduler_port) { scheduler_port_ = scheduler_port; }
ClusterState AbstractNode::cluster_state() const { return current_cluster_state_; }
void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
@ -450,10 +463,13 @@ void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const
wait_start_cond_.notify_all();
}
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
is_ready_ = true;
wait_start_cond_.notify_all();
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
if (node_recovery_ == nullptr) {
MS_LOG(INFO) << "The recovery is disable.";
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
is_ready_ = true;
wait_start_cond_.notify_all();
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
}
}
}
@ -489,6 +505,7 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
MS_EXCEPTION_IF_NULL(data);
if (is_current_node_scale_in_) {
MS_LOG(WARNING) << "Trigger cluster scale in done event.";
node_info_.rank_id_ = UINT32_MAX;
OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
return;
}
@ -639,9 +656,7 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
}
bool AbstractNode::InitClientToScheduler() {
std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
uint16_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_);
client_to_scheduler_->SetMessageCallback(
[&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
try {
@ -851,6 +866,7 @@ void AbstractNode::InitNodeInfo(const NodeRole &role) {
node_info_.node_role_ = role;
node_info_.ip_ = server_->BoundIp();
node_info_.port_ = server_->BoundPort();
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " is generate uuid is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
<< ", the port:" << server_->BoundPort();
@ -859,6 +875,20 @@ void AbstractNode::InitNodeInfo(const NodeRole &role) {
void AbstractNode::InitNodeNum() {
worker_num_ = PSContext::instance()->cluster_config().initial_worker_num;
server_num_ = PSContext::instance()->cluster_config().initial_server_num;
scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host;
scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port;
MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_
<< ", the scheduler ip:" << scheduler_ip_ << ", the scheduler port:" << scheduler_port_;
}
bool AbstractNode::Recover() {
if (config_->Exists(kKeyRecovery)) {
MS_LOG(INFO) << "The node is support recovery.";
node_recovery_ = std::make_unique<NodeRecovery>(this);
node_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
return node_recovery_->Recover();
}
return false;
}
void AbstractNode::OnEventCallback(const ClusterEvent &event) {

View File

@ -29,6 +29,7 @@
#include "ps/core/follower_scaler.h"
#include "utils/ms_exception.h"
#include "ps/constants.h"
#include "ps/core/recovery_base.h"
namespace mindspore {
namespace ps {
@ -44,7 +45,11 @@ class AbstractNode : public Node {
server_thread_(nullptr),
worker_num_(-1),
server_num_(-1),
is_current_node_scale_in_(false) {}
is_current_node_scale_in_(false),
follower_scaler_(nullptr),
node_recovery_(nullptr),
scheduler_ip_(""),
scheduler_port_(0) {}
~AbstractNode() override = default;
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
@ -105,6 +110,15 @@ class AbstractNode : public Node {
int32_t worker_num() const;
int32_t server_num() const;
void set_worker_num(const int32_t &worker_num);
void set_server_num(const int32_t &server_num);
std::string scheduler_ip() const;
void set_scheduler_ip(const std::string &scheduler_ip);
uint16_t scheduler_port() const;
void set_scheduler_port(const uint16_t &scheduler_port);
ClusterState cluster_state() const;
protected:
@ -159,6 +173,8 @@ class AbstractNode : public Node {
void InitNodeInfo(const NodeRole &role);
// Initialize worker num and server num by cluster config.
void InitNodeNum();
// Node recover by cluster config.
bool Recover();
// Trigger the callback corresponding to the event.
void OnEventCallback(const ClusterEvent &event);
@ -221,6 +237,14 @@ class AbstractNode : public Node {
// Scaler for worker/server node.
std::unique_ptr<FollowerScaler> follower_scaler_;
// Recovery for worker/server node.
std::unique_ptr<RecoveryBase> node_recovery_;
// The ip of scheduler.
std::string scheduler_ip_;
// The port of scheduler.
uint16_t scheduler_port_;
};
} // namespace core
} // namespace ps

View File

@ -47,6 +47,9 @@ class Configuration {
// Put configuration data to database or config file.
virtual void Put(const std::string &key, const std::string &defaultvalue) = 0;
// Determine whether the configuration item exists.
virtual bool Exists(const std::string &key) = 0;
};
} // namespace core
} // namespace ps

View File

@ -39,7 +39,7 @@ bool FileConfiguration::Initialize() {
std::string FileConfiguration::Get(const std::string &key, const std::string &defaultvalue) const {
if (!js.contains(key)) {
MS_LOG(WARNING) << "The key:" << key << " is not exit.";
MS_LOG(WARNING) << "The key:" << key << " is not exist.";
return defaultvalue;
}
std::string res = js.at(key);
@ -54,6 +54,12 @@ void FileConfiguration::Put(const std::string &key, const std::string &value) {
output_file.close();
}
bool FileConfiguration::Exists(const std::string &key) {
if (!js.contains(key)) {
return false;
}
return true;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -56,6 +56,8 @@ class FileConfiguration : public Configuration {
void Put(const std::string &key, const std::string &value) override;
bool Exists(const std::string &key) override;
private:
// The path of the configuration file.
std::string file_path_;

View File

@ -36,6 +36,7 @@
#include "ps/core/node_info.h"
#include "ps/core/communicator/tcp_client.h"
#include "ps/core/communicator/tcp_server.h"
#include "ps/core/file_configuration.h"
namespace mindspore {
namespace ps {
@ -67,7 +68,7 @@ class Node {
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
const void *, size_t size, const uint32_t &timeout = kTimeoutInSeconds);
protected:
bool WaitForStart(const uint32_t &timeout);
@ -105,6 +106,15 @@ class Node {
// Worker and server receive the node state and cluster state from the scheduler.
NodeState current_node_state_;
ClusterState current_cluster_state_;
// Configuration file,The format is as follows
//{
// "recovery": {
// "storage_type": 1,
// "storge_file_path": "/home/cds/config.json"
// }
// }
std::unique_ptr<Configuration> config_;
};
} // namespace core
} // namespace ps

View File

@ -36,7 +36,7 @@ enum class ClusterEvent {
};
struct NodeInfo {
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0), is_alive(false) {}
// ip
std::string ip_;
// the port of this node
@ -47,6 +47,9 @@ struct NodeInfo {
NodeRole node_role_;
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
uint32_t rank_id_;
// After the node registration is successful, it is alive.If the node's heartbeat times out, then it is not alive
bool is_alive;
};
} // namespace core
} // namespace ps

View File

@ -42,7 +42,21 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
const std::string &ip = register_message.ip();
uint32_t port = register_message.port();
rank_id = ++next_server_rank_id_;
auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::SERVER;
if (res) {
MS_LOG(INFO) << "The server node id:" << item.first << " rank id:" << rank_id << " is not alive.";
rank_id = item.second.rank_id_;
}
return res;
});
if (rank_it == registered_nodes_info_.end()) {
rank_id = ++next_server_rank_id_;
} else {
registered_nodes_info_.erase((*rank_it).first);
}
if (rank_id >= meta_data_->server_num) {
MS_LOG(WARNING) << "The rank id is greater than the number of servers:" << meta_data_->server_num;
rank_id = UINT_MAX;
@ -54,13 +68,29 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
node_info.rank_id_ = rank_id;
node_info.ip_ = ip;
node_info.port_ = port;
node_info.is_alive = true;
registered_nodes_info_[node_id] = node_info;
MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port
<< " assign rank id:" << rank_id;
} else if (register_message.role() == NodeRole::WORKER) {
const std::string &ip = register_message.ip();
uint32_t port = register_message.port();
rank_id = ++next_worker_rank_id_;
auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
if (res) {
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
rank_id = item.second.rank_id_;
}
return res;
});
if (rank_it == registered_nodes_info_.end()) {
rank_id = ++next_worker_rank_id_;
} else {
registered_nodes_info_.erase((*rank_it).first);
}
if (rank_id >= meta_data_->worker_num) {
MS_LOG(WARNING) << "The rank id is greater than the number of workers:" << meta_data_->worker_num;
rank_id = UINT_MAX;
@ -72,6 +102,7 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message) {
node_info.rank_id_ = rank_id;
node_info.ip_ = ip;
node_info.port_ = port;
node_info.is_alive = true;
registered_nodes_info_[node_id] = node_info;
MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id;
}
@ -85,12 +116,6 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) {
heartbeats_[node_id] = current_time;
}
void NodeManager::UpdateNodeScaleInState(const std::string &node_id) { heartbeats_scale_in_nodes_.insert(node_id); }
bool NodeManager::CheckNodesScaluOutState() { return SizeToInt(heartbeats_scale_out_nodes_.size()) == total_node_num_; }
bool NodeManager::CheckNodesScaleInState() { return SizeToInt(heartbeats_scale_in_nodes_.size()) == total_node_num_; }
std::vector<ServersMeta> NodeManager::FetchServersMeta() {
std::vector<ServersMeta> servers_meta_list;
for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
@ -115,12 +140,15 @@ void NodeManager::UpdateCluster() {
if (registered_nodes_info_.count(it->first)) {
MS_LOG(WARNING) << "The node id:" << it->first << " is timeout!";
timeout_nodes_info_[it->first] = registered_nodes_info_[it->first];
registered_nodes_info_[it->first].is_alive = false;
}
}
}
if (!timeout_nodes_info_.empty()) {
UpdateClusterState(ClusterState::NODE_TIMEOUT);
for (auto it = timeout_nodes_info_.begin(); it != timeout_nodes_info_.end(); ++it) {
heartbeats_.erase(it->first);
finish_nodes_id_.insert(it->first);
}
}
@ -149,7 +177,11 @@ void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_do
void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
bool NodeManager::IsAllNodesRegistered() { return SizeToInt(registered_nodes_info_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesRegistered() {
int32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
[](auto item) { return item.second.is_alive == true; });
return num == total_node_num_;
}
bool NodeManager::IsAllNodesFinished() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
@ -175,7 +207,7 @@ void NodeManager::UpdateNodeState(const NodeState &state) {
void NodeManager::UpdateClusterState(const ClusterState &state) {
std::lock_guard<std::mutex> lk(cluster_mutex_);
MS_LOG(INFO) << "[state]: Scheduler change state from:" << CommUtil::ClusterStateToString(cluster_state_) << " to "
<< state;
<< CommUtil::ClusterStateToString(state);
cluster_state_ = state;
}

View File

@ -30,6 +30,8 @@
#include <vector>
#include <condition_variable>
#include <unordered_set>
#include <deque>
#include <algorithm>
#include "ps/core/node.h"
#include "utils/log_adapter.h"
@ -57,9 +59,6 @@ class NodeManager {
uint32_t NextRankId(const RegisterMessage &register_message);
void UpdateHeartbeat(const std::string &node_id);
bool CheckNodesScaluOutState();
void UpdateNodeScaleInState(const std::string &node_id);
bool CheckNodesScaleInState();
std::vector<ServersMeta> FetchServersMeta();
void UpdateCluster();
@ -126,9 +125,6 @@ class NodeManager {
std::mutex heartbeat_mutex_;
std::unordered_map<std::string, timeval> heartbeats_;
std::unordered_set<std::string> heartbeats_finish_nodes_;
std::unordered_set<std::string> heartbeats_scale_out_nodes_;
std::unordered_set<std::string> heartbeats_scale_in_nodes_;
// timeout nodes
std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
std::unordered_set<std::string> finish_nodes_id_;

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/node_recovery.h"
namespace mindspore {
namespace ps {
namespace core {
bool NodeRecovery::Recover() {
if (recovery_storage_ == nullptr) {
return false;
}
// 1. recover worker num
if (recovery_storage_->Exists(kRecoveryWorkerNum)) {
int32_t worker_num = std::strtol(recovery_storage_->Get(kRecoveryWorkerNum, "").c_str(), nullptr, 10);
node_->set_worker_num(worker_num);
} else {
node_->set_worker_num(PSContext::instance()->cluster_config().initial_worker_num);
}
// 2. recover server num
if (recovery_storage_->Exists(kRecoveryServerNum)) {
int32_t server_num = std::strtol(recovery_storage_->Get(kRecoveryServerNum, "").c_str(), nullptr, 10);
node_->set_server_num(server_num);
} else {
node_->set_server_num(PSContext::instance()->cluster_config().initial_server_num);
}
// 3. recover scheduler ip
if (recovery_storage_->Exists(kRecoverySchedulerIp)) {
std::string scheduler_ip = recovery_storage_->Get(kRecoverySchedulerIp, "");
node_->set_scheduler_ip(scheduler_ip);
} else {
node_->set_scheduler_ip(PSContext::instance()->cluster_config().scheduler_host);
}
// 4. recover scheduler port
if (recovery_storage_->Exists(kRecoverySchedulerPort)) {
uint16_t scheduler_port = std::strtol(recovery_storage_->Get(kRecoverySchedulerPort, "").c_str(), nullptr, 10);
node_->set_scheduler_port(scheduler_port);
} else {
node_->set_scheduler_port(PSContext::instance()->cluster_config().scheduler_port);
}
MS_LOG(INFO) << "The worker num:" << node_->worker_num() << ", the server num:" << node_->server_num()
<< ", the scheduler ip:" << node_->scheduler_ip() << ", the scheduler port:" << node_->scheduler_port();
return true;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_NODE_RECOVERY_H_
#define MINDSPORE_CCSRC_PS_CORE_NODE_RECOVERY_H_
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "ps/core/recovery_base.h"
#include "ps/constants.h"
#include "utils/log_adapter.h"
#include "ps/core/file_configuration.h"
#include "ps/ps_context.h"
#include "ps/core/abstract_node.h"
namespace mindspore {
namespace ps {
namespace core {
// The class helps worker/server node to do recovery operation for the cluster.
class NodeRecovery : public RecoveryBase {
public:
explicit NodeRecovery(AbstractNode *const node) : node_(node) {}
~NodeRecovery() override = default;
bool Recover() override;
private:
// The node_ will only be instantiated with worker/server node.
AbstractNode *const node_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_NODE_RECOVERY_H_

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/recovery_base.h"
namespace mindspore {
namespace ps {
namespace core {
void RecoveryBase::Initialize(const std::string &config_json) {
nlohmann::json recovery_config;
try {
recovery_config = nlohmann::json::parse(config_json);
} catch (nlohmann::json::exception &e) {
MS_LOG(ERROR) << "Parse the json:" << config_json;
}
MS_LOG(INFO) << "The scheduelr is support recovery.";
std::string storage_file_path = "";
std::string type = recovery_config.at(kStoreType);
if (type == kFileStorage) {
storage_type_ = StorageType::kFileStorage;
storage_file_path = recovery_config.at(kStoreFilePath);
if (storage_file_path == "") {
MS_LOG(EXCEPTION) << "If the scheduler support recovery, and if the persistent storage is a file, the path of "
"the file must be configured";
}
recovery_storage_ = std::make_unique<FileConfiguration>(storage_file_path);
if (!recovery_storage_->Initialize()) {
MS_LOG(INFO) << "The storage file path " << storage_file_path << " is empty.";
}
}
MS_LOG(INFO) << "The storage type is:" << storage_type_ << ", the storage file path is:" << storage_file_path;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_RECOVERY_BASE_H_
#define MINDSPORE_CCSRC_PS_CORE_RECOVERY_BASE_H_
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "ps/constants.h"
#include "utils/log_adapter.h"
#include "ps/core/file_configuration.h"
#include "ps/ps_context.h"
namespace mindspore {
namespace ps {
namespace core {
enum class StorageType : int { kFileStorage = 1 };
// RecoveryBase is used to parse configuration items related to recovery.
// It is the base class of SchedulerRecovery and NodeRecovery.
class RecoveryBase {
public:
RecoveryBase() : recovery_storage_(nullptr), storage_type_(StorageType::kFileStorage) {}
virtual ~RecoveryBase() = default;
// Initialize the recovery configuration item and get the storage type of recovery.
void Initialize(const std::string &json_config);
// The node needs to recover metadata information when it starts.
virtual bool Recover() = 0;
protected:
// Persistent storage used to save metadata.
std::unique_ptr<Configuration> recovery_storage_;
// Storage type for recovery,Currently only supports storage of file types
StorageType storage_type_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_RECOVERY_BASE_H_

View File

@ -133,17 +133,21 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
if (node_manager_.IsAllNodesRegistered()) {
is_ready_ = true;
MS_LOG(INFO) << "Scheduler send meta data to worker/server.";
auto node_infos = node_manager_.nodes_info();
if (node_manager_.GetClusterState() == ClusterState::ClUSTER_STARTING) {
node_infos.clear();
node_infos = node_manager_.registered_nodes_info();
MS_LOG(INFO) << "All nodes is registered, scheduler send meta data to worker/server.";
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
auto nodes = node_manager_.nodes_info();
for (const auto &id : scale_in_node_ids_) {
MS_LOG(INFO) << "The scheduler send metadata to scale in node:" << id;
auto scale_in_client = GetOrCreateClient(nodes[id]);
SendMetadata(scale_in_client, nodes[id].rank_id_);
}
}
node_manager_.UpdateNodesInfo();
auto node_infos = node_manager_.nodes_info();
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
SendMetadata(client, kvs.second.rank_id_);
}
node_manager_.UpdateNodesInfo();
wait_start_cond_.notify_all();
}
}
@ -341,8 +345,9 @@ void SchedulerNode::SendEvent(const std::shared_ptr<TcpClient> &client, const ui
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(),
event_resp_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send event resp timeout!";
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send event resp timeout!";
return;
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@ -463,6 +468,7 @@ void SchedulerNode::ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp) {
node_manager_.set_worker_num(total_worker_num);
node_manager_.set_server_num(total_server_num);
node_manager_.set_total_node_num(total_worker_num + total_server_num);
node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
auto node_infos = node_manager_.nodes_info();
node_manager_.ResetMetadata();
@ -500,20 +506,20 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
return;
}
std::vector<std::string> scale_in_node_ids;
status = resp->ParseNodeIdsFromKey(kNodesIds, &scale_in_node_ids);
scale_in_node_ids_.clear();
status = resp->ParseNodeIdsFromKey(kNodesIds, &scale_in_node_ids_);
if (status != RequestProcessResultCode::kSuccess) {
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
status = CheckIfNodeIdLegal(scale_in_node_ids);
status = CheckIfNodeIdLegal(scale_in_node_ids_);
if (status != RequestProcessResultCode::kSuccess) {
resp->ErrorResponse(HTTP_BADREQUEST, status);
return;
}
MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids;
MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids_;
std::unordered_map<std::string, bool> scale_in_nodes;
@ -521,7 +527,7 @@ void SchedulerNode::ProcessScaleIn(std::shared_ptr<HttpMessageHandler> resp) {
int32_t scale_server_num = 0;
auto node_infos = node_manager_.nodes_info();
node_manager_.ResetMetadata();
for (auto const &val : scale_in_node_ids) {
for (auto const &val : scale_in_node_ids_) {
if (node_infos.count(val)) {
scale_in_nodes[val] = true;
NodeInfo info = node_infos[val];

View File

@ -39,6 +39,7 @@
#include "ps/core/cluster_metadata.h"
#include "ps/core/communicator/http_server.h"
#include "ps/core/leader_scaler.h"
#include "ps/core/recovery_base.h"
namespace mindspore {
namespace ps {
@ -53,7 +54,8 @@ class SchedulerNode : public Node {
http_server_(nullptr),
client_thread_(nullptr),
is_client_started_(false),
leader_scaler_(nullptr) {}
leader_scaler_(nullptr),
scheduler_recovery_(nullptr) {}
~SchedulerNode() override;
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
@ -65,6 +67,7 @@ class SchedulerNode : public Node {
private:
void Initialize();
void InitCommandHandler();
void CreateTcpServer();
void StartUpdateClusterStateTimer();
@ -144,6 +147,12 @@ class SchedulerNode : public Node {
std::unique_ptr<LeaderScaler> leader_scaler_;
std::unordered_map<std::string, OnRequestReceive> callbacks_;
// Used to persist and obtain metadata information for scheduler.
std::unique_ptr<RecoveryBase> scheduler_recovery_;
// The node id of scale in nodes.
std::vector<std::string> scale_in_node_ids_;
};
} // namespace core
} // namespace ps

View File

@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/scheduler_recovery.h"
namespace mindspore {
namespace ps {
namespace core {
void SchedulerRecovery::Persist(const std::string &key, const std::string &value) {
recovery_storage_->Put(key, value);
}
std::string SchedulerRecovery::GetMetadata(const std::string &key) { return recovery_storage_->Get(key, ""); }
bool SchedulerRecovery::Recover() { return true; }
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_SCHEDULER_RECOVERY_H_
#define MINDSPORE_CCSRC_PS_CORE_SCHEDULER_RECOVERY_H_
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "ps/constants.h"
#include "utils/log_adapter.h"
#include "ps/core/file_configuration.h"
#include "ps/ps_context.h"
#include "ps/core/recovery_base.h"
#include "ps/core/scheduler_node.h"
namespace mindspore {
namespace ps {
namespace core {
// The class helps scheduler node to do recovery operation for the cluster.
class SchedulerRecovery : public RecoveryBase {
public:
SchedulerRecovery() = default;
~SchedulerRecovery() override = default;
// Persist metadata to storage.
void Persist(const std::string &key, const std::string &value);
bool Recover() override;
// Get metadata from storage.
std::string GetMetadata(const std::string &key);
private:
// The node_ will only be instantiated with worker/server node.
SchedulerNode *const node_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_SCHEDULER_RECOVERY_H_

View File

@ -84,6 +84,15 @@ void ServerNode::CreateTcpServer() {
}
void ServerNode::Initialize() {
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
InitNodeNum();
} else {
if (!Recover()) {
MS_LOG(WARNING) << "Recover the server node is failed.";
}
}
InitServerHandler();
CreateTcpServer();
is_already_stopped_ = false;
@ -91,7 +100,6 @@ void ServerNode::Initialize() {
MS_LOG(INFO) << "[Server start]: 2. Server node create tcp server successful!";
InitNodeNum();
InitCommandHandler();
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Server node connect to scheduler timedout!";

View File

@ -78,5 +78,4 @@ class ServerNode : public AbstractNode {
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_SERVER_NODE_H_

View File

@ -42,13 +42,21 @@ bool WorkerNode::Start(const uint32_t &timeout) {
void WorkerNode::Initialize() {
is_already_stopped_ = false;
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
InitNodeNum();
} else {
if (!Recover()) {
MS_LOG(WARNING) << "Recover the worker node is failed.";
}
}
InitServerHandler();
CreateTcpServer();
InitNodeInfo(NodeRole::WORKER);
MS_LOG(INFO) << "[Worker start]: 2. Worker node create tcp server successful!";
InitNodeNum();
InitCommandHandler();
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Worker node connect to scheduler timeout!";

View File

@ -317,5 +317,9 @@ core::ClusterConfig &PSContext::cluster_config() {
void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; }
std::string PSContext::config_file_path() const { return config_file_path_; }
} // namespace ps
} // namespace mindspore

View File

@ -148,6 +148,9 @@ class PSContext {
void set_scheduler_manage_port(uint16_t sched_port);
uint16_t scheduler_manage_port() const;
void set_config_file_path(const std::string &path);
std::string config_file_path() const;
private:
PSContext()
: ps_enabled_(false),
@ -176,7 +179,8 @@ class PSContext {
client_learning_rate_(0.001),
secure_aggregation_(false),
cluster_config_(nullptr),
scheduler_manage_port_(0) {}
scheduler_manage_port_(0),
config_file_path_("") {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@ -238,6 +242,9 @@ class PSContext {
// The port used by scheduler to receive http requests for scale out or scale in.
uint16_t scheduler_manage_port_;
// The path of the configuration file, used to configure the certification path and persistent storage type, etc.
std::string config_file_path_;
};
} // namespace ps
} // namespace mindspore

View File

@ -53,7 +53,8 @@ _set_ps_context_func_map = {
"client_batch_size": ps_context().set_client_batch_size,
"client_learning_rate": ps_context().set_client_learning_rate,
"enable_ps_ssl": ps_context().set_enable_ssl,
"scheduler_manage_port": ps_context().set_scheduler_manage_port
"scheduler_manage_port": ps_context().set_scheduler_manage_port,
"config_file_path": ps_context().set_config_file_path
}
_get_ps_context_func_map = {

View File

@ -0,0 +1,6 @@
{
"recovery": {
"storge_type": 1,
"kStoreFilePath": "recovery.json"
}
}

View File

@ -163,8 +163,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
list(REMOVE_ITEM MINDSPORE_SRC_LIST
"../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/worker.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/parameter_server.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")