!31466 code sync and fix bug

Merge pull request !31466 from tan-wei-cheng-3260/develop-twc-master
This commit is contained in:
i-robot 2022-03-21 06:09:09 +00:00 committed by Gitee
commit a05c9816c3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
31 changed files with 236 additions and 152 deletions

View File

@ -23,10 +23,10 @@ namespace mindspore {
namespace fl {
namespace server {
namespace {
const char *kCollectivePhaseRing = "ring";
const char *kCollectivePhaseGather = "gather";
const char *kCollectivePhaseReduce = "reduce";
const char *kCollectivePhaseBroadcast = "broadcast";
const char kCollectivePhaseRing[] = "ring";
const char kCollectivePhaseGather[] = "gather";
const char kCollectivePhaseReduce[] = "reduce";
const char kCollectivePhaseBroadcast[] = "broadcast";
} // namespace
void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
@ -292,7 +292,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const std::string &data_name, c
}
template <typename T>
bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff, size_t send_count) {
bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count) {
MS_ERROR_IF_NULL_W_RET_VAL(node_, false);
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
@ -411,7 +411,7 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c
}
template <typename T>
bool CollectiveOpsImpl::AllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count) {
bool CollectiveOpsImpl::AllReduce(const std::string &data_name, void *sendbuff, void *recvbuff, size_t count) {
// The collective communication API does not support calling Send and Recv concurrently with multiple threads;
std::unique_lock<std::mutex> lock(mtx_);
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
@ -441,7 +441,7 @@ bool CollectiveOpsImpl::AllReduce(const std::string &data_name, const void *send
}
template <typename T>
bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *const recvbuff, size_t send_count,
bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t send_count,
const std::shared_ptr<ps::core::AbstractNode> &node) {
std::unique_lock<std::mutex> lock(mtx_);
MS_ERROR_IF_NULL_W_RET_VAL(node, false);
@ -476,7 +476,7 @@ bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *const recvbuff, si
}
template <typename T>
bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *const recvbuff, size_t count, uint32_t root,
bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
const std::shared_ptr<ps::core::AbstractNode> &node,
const CommunicationGroupInfo &group_info) {
std::unique_lock<std::mutex> lock(mtx_);
@ -515,11 +515,11 @@ bool CollectiveOpsImpl::ReInitForScaling() {
return true;
}
template bool CollectiveOpsImpl::AllReduce<float>(const std::string &data_name, const void *sendbuff, void *recvbuff,
template bool CollectiveOpsImpl::AllReduce<float>(const std::string &data_name, void *sendbuff, void *recvbuff,
size_t count);
template bool CollectiveOpsImpl::AllReduce<size_t>(const std::string &data_name, const void *sendbuff, void *recvbuff,
template bool CollectiveOpsImpl::AllReduce<size_t>(const std::string &data_name, void *sendbuff, void *recvbuff,
size_t count);
template bool CollectiveOpsImpl::AllReduce<int>(const std::string &data_name, const void *sendbuff, void *recvbuff,
template bool CollectiveOpsImpl::AllReduce<int>(const std::string &data_name, void *sendbuff, void *recvbuff,
size_t count);
template bool CollectiveOpsImpl::AllGather<float>(const void *sendbuff, void *recvbuff, size_t send_count,

View File

@ -64,7 +64,7 @@ class CollectiveOpsImpl {
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);
template <typename T>
bool AllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count);
bool AllReduce(const std::string &data_name, void *sendbuff, void *recvbuff, size_t count);
template <typename T>
bool AllGather(const void *sendbuff, void *recvbuff, size_t send_count,

View File

@ -242,6 +242,7 @@ constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
constexpr auto kCtxCipherPrimer = "cipher_primer";
constexpr auto kCurrentIteration = "current_iteration";
constexpr auto kInstanceState = "instance_state";
const char PYTHON_MOD_SERIALIZE_MODULE[] = "mindspore.train.serialization";
const char PYTHON_MOD_SAFE_WEIGHT[] = "_save_weight";
@ -295,6 +296,31 @@ enum class ResultCode {
// If there's error happened, return kFail.
kFail
};
// Maps an InstanceState value to its textual name: "kRunning", "kFinish" or "kDisable".
// Any other value is reported through MS_LOG(EXCEPTION).
inline std::string GetInstanceStateStr(const InstanceState &instance_state) {
  if (instance_state == InstanceState::kRunning) {
    return "kRunning";
  }
  if (instance_state == InstanceState::kFinish) {
    return "kFinish";
  }
  if (instance_state == InstanceState::kDisable) {
    return "kDisable";
  }
  // None of the known states matched.
  MS_LOG(EXCEPTION) << "InstanceState " << instance_state << " is not supported.";
}
// Parses a textual state name ("kRunning", "kFinish" or "kDisable") into the matching
// InstanceState value. Any other string is reported through MS_LOG(EXCEPTION).
inline InstanceState GetInstanceState(const std::string &instance_state) {
  struct NamedState {
    const char *name;
    InstanceState state;
  };
  const NamedState known_states[] = {
    {"kRunning", InstanceState::kRunning},
    {"kFinish", InstanceState::kFinish},
    {"kDisable", InstanceState::kDisable},
  };
  for (const auto &candidate : known_states) {
    if (instance_state == candidate.name) {
      return candidate.state;
    }
  }
  // The string did not name any known state.
  MS_LOG(EXCEPTION) << "InstanceState " << instance_state << " is not supported.";
}
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -165,7 +165,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name, const
std::string reason = "Sending querying whether count reaches " + name +
" threshold message to leader server failed" + (fl_id.empty() ? "" : " for fl id " + fl_id);
MS_LOG(WARNING) << reason;
return false;
return true;
}
MS_ERROR_IF_NULL_W_RET_VAL(query_cnt_enough_rsp_msg, false);

View File

@ -151,7 +151,7 @@ void Iteration::SetIterationRunning() {
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::UserDefineEvent::kIterationRunning));
if (server_recovery_ != nullptr) {
// Save data to the persistent storage in case the recovery happens at the beginning.
if (!server_recovery_->Save(iteration_num_)) {
if (!server_recovery_->Save(iteration_num_, instance_state_)) {
MS_LOG(WARNING) << "Save recovery data failed.";
}
}
@ -207,7 +207,8 @@ bool Iteration::ReInitForUpdatingHyperParams(const std::vector<RoundConfig> &upd
for (const auto &round : rounds_) {
if (updated_round.name == round->name()) {
MS_LOG(INFO) << "Reinitialize for round " << round->name();
if (!round->ReInitForUpdatingHyperParams(updated_round.threshold_count, updated_round.time_window)) {
if (!round->ReInitForUpdatingHyperParams(updated_round.threshold_count, updated_round.time_window,
server_node_->server_num())) {
MS_LOG(ERROR) << "Reinitializing for round " << round->name() << " failed.";
return false;
}
@ -735,7 +736,7 @@ void Iteration::EndLastIter() {
if (server_node_->rank_id() == kLeaderServerRank) {
// Save current iteration number for recovery.
MS_ERROR_IF_NULL_WO_RET_VAL(server_recovery_);
if (!server_recovery_->Save(iteration_num_)) {
if (!server_recovery_->Save(iteration_num_, instance_state_)) {
MS_LOG(WARNING) << "Can't save current iteration number into persistent storage.";
}
}
@ -748,7 +749,11 @@ void Iteration::EndLastIter() {
set_loss(0.0f);
Server::GetInstance().CancelSafeMode();
iteration_state_cv_.notify_all();
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
if (iteration_num_ > ps::PSContext::instance()->fl_iteration_num()) {
MS_LOG(WARNING) << "The server's training job is finished.";
} else {
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
}
}
bool Iteration::ForciblyMoveToNextIteration() {
@ -914,6 +919,11 @@ void Iteration::UpdateRoundClientUploadLoss(const std::shared_ptr<std::vector<un
set_loss(loss_ + end_last_iter_rsp.upload_loss());
}
// Stores the given state in instance_state_ and logs its textual name at INFO level.
void Iteration::set_instance_state(InstanceState state) {
instance_state_ = state;
// The logged name is rendered from the stored member through GetInstanceStateStr.
MS_LOG(INFO) << "Server instance state is " << GetInstanceStateStr(instance_state_);
}
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -132,6 +132,8 @@ class Iteration {
// The round kernels whose Launch method has not returned yet.
std::atomic_uint32_t running_round_num_;
// Sets the instance state held by this Iteration.
void set_instance_state(InstanceState state);
private:
Iteration()
: running_round_num_(0),

View File

@ -28,9 +28,8 @@ bool IterationMetrics::Initialize() {
config_ = std::make_unique<ps::core::FileConfiguration>(config_file_path_);
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(WARNING) << "Initializing for metrics failed. Config file path " << config_file_path_
<< " may be invalid or not exist.";
return false;
MS_LOG(EXCEPTION) << "Initializing for metrics failed. Config file path " << config_file_path_
<< " may be invalid or not exist.";
}
// Read the metrics file path. If the file is not set or does not exist, create one.

View File

@ -57,17 +57,7 @@ void RoundKernel::set_stop_timer_cb(const StopTimerCb &timer_stopper) { stop_tim
void RoundKernel::SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
size_t len) {
if (message == nullptr) {
MS_LOG(WARNING) << "The message handler is nullptr.";
return;
}
if (data == nullptr || len == 0) {
std::string reason = "The output of the round " + name_ + " is empty.";
MS_LOG(WARNING) << reason;
if (!message->SendResponse(reason.c_str(), reason.size())) {
MS_LOG(WARNING) << "Sending response failed.";
return;
}
if (!verifyResponse(message, data, len)) {
return;
}
IncreaseTotalClientNum();
@ -79,17 +69,7 @@ void RoundKernel::SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler
void RoundKernel::SendResponseMsgInference(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
size_t len, ps::core::RefBufferRelCallback cb) {
if (message == nullptr) {
MS_LOG(WARNING) << "The message handler is nullptr.";
return;
}
if (data == nullptr || len == 0) {
std::string reason = "The output of the round " + name_ + " is empty.";
MS_LOG(WARNING) << reason;
if (!message->SendResponse(reason.c_str(), reason.size())) {
MS_LOG(WARNING) << "Sending response failed.";
return;
}
if (!verifyResponse(message, data, len)) {
return;
}
IncreaseTotalClientNum();
@ -99,6 +79,23 @@ void RoundKernel::SendResponseMsgInference(const std::shared_ptr<ps::core::Messa
}
}
// Checks whether a response can be sent for the given message and payload.
// Returns false when the message handler is null. Also returns false when the payload is
// null or has zero length; in that case the reason text is sent back through the message
// handler. Returns true only when the handler is non-null and the payload is non-empty.
bool RoundKernel::verifyResponse(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data,
                                 size_t len) {
  if (message == nullptr) {
    MS_LOG(WARNING) << "The message handler is nullptr.";
    return false;
  }

  const bool has_payload = (data != nullptr) && (len != 0);
  if (has_payload) {
    return true;
  }

  // Empty output: tell the peer why, then reject.
  std::string reason = "The output of the round " + name_ + " is empty.";
  MS_LOG(WARNING) << reason;
  const bool sent = message->SendResponse(reason.c_str(), reason.size());
  if (!sent) {
    MS_LOG(WARNING) << "Sending response failed.";
  }
  return false;
}
// Adds one to total_client_num_.
void RoundKernel::IncreaseTotalClientNum() { total_client_num_ += 1; }
// Adds one to accept_client_num_.
void RoundKernel::IncreaseAcceptClientNum() { accept_client_num_ += 1; }

View File

@ -100,6 +100,8 @@ class RoundKernel {
float upload_loss() const;
bool verifyResponse(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);
protected:
// Send response to client, and the data can be released after the call.
void SendResponseMsg(const std::shared_ptr<ps::core::MessageHandler> &message, const void *data, size_t len);

View File

@ -41,6 +41,10 @@ const size_t LocalMetaStore::curr_iter_num() {
return curr_iter_num_;
}
// Stores the given instance state in instance_state_.
// NOTE(review): no lock is taken here; confirm whether this member needs to be guarded by mtx_.
void LocalMetaStore::set_curr_instance_state(InstanceState instance_state) { instance_state_ = instance_state; }
// Returns the stored instance state by value.
// NOTE(review): no lock is taken here; confirm whether this member needs to be guarded by mtx_.
const InstanceState LocalMetaStore::curr_instance_state() { return instance_state_; }
const void LocalMetaStore::put_aggregation_feature_map(const std::string &name, const Feature &feature) {
if (aggregation_feature_map_.count(name) > 0) {
MS_LOG(WARNING) << "Put feature " << name << " failed.";

View File

@ -79,6 +79,10 @@ class LocalMetaStore {
void set_curr_iter_num(size_t num);
const size_t curr_iter_num();
void set_curr_instance_state(InstanceState instance_state);
const InstanceState curr_instance_state();
const void put_aggregation_feature_map(const std::string &name, const Feature &feature);
std::unordered_map<std::string, Feature> &aggregation_feature_map();
@ -96,7 +100,7 @@ class LocalMetaStore {
// This mutex makes sure that the operations on key_to_meta_ are thread-safe.
std::mutex mtx_;
size_t curr_iter_num_{0};
// Instance state written by set_curr_instance_state() and read by curr_instance_state().
// NOTE(review): unlike curr_iter_num_ there is no default member initializer, so the value
// appears to be indeterminate until set_curr_instance_state() is called — confirm that every
// reader runs after a successful set, or add an explicit default.
InstanceState instance_state_;
// aggregation_feature_map_ stores model meta data with weight name and size which will be Aggregated.
std::unordered_map<std::string, Feature> aggregation_feature_map_;
};

View File

@ -94,7 +94,8 @@ bool Round::ReInitForScaling(uint32_t server_num) {
return true;
}
bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window) {
bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window,
uint32_t server_num) {
time_window_ = updated_time_window;
threshold_count_ = updated_threshold_count;
if (check_count_) {
@ -105,7 +106,11 @@ bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t
}
MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
kernel_->InitKernel(threshold_count_);
if (name_ == "reconstructSecrets") {
kernel_->InitKernel(server_num);
} else {
kernel_->InitKernel(threshold_count_);
}
return true;
}
@ -189,7 +194,11 @@ bool Round::IsServerAvailable(std::string *reason) {
// If the server state is Disable or Finish, refuse the request.
if (Iteration::GetInstance().instance_state() == InstanceState::kDisable ||
Iteration::GetInstance().instance_state() == InstanceState::kFinish) {
MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
if (kPrintTimes % kPrintTimesThreshold == 0) {
MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
kPrintTimes = 0;
}
kPrintTimes += 1;
*reason = ps::kJobNotAvailable;
return false;
}

View File

@ -44,7 +44,7 @@ class Round {
bool ReInitForScaling(uint32_t server_num);
// After hyper-parameters are updated, some rounds and kernels should be reinitialized.
bool ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window);
bool ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window, uint32_t server_num);
// Bind a round kernel to this Round. This method should be called after Initialize.
void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel);

View File

@ -515,6 +515,7 @@ void Server::Recover() {
// Set the recovery handler to Iteration.
MS_EXCEPTION_IF_NULL(iteration_);
iteration_->set_recovery_handler(server_recovery_);
iteration_->set_instance_state(LocalMetaStore::GetInstance().curr_instance_state());
}
void Server::ProcessBeforeScalingOut() {

View File

@ -75,13 +75,18 @@ bool ServerRecovery::Recover() {
return false;
}
uint64_t current_iter = JsonGetKeyWithException<uint64_t>(server_recovery_json, kCurrentIteration);
std::string instance_state = JsonGetKeyWithException<std::string>(server_recovery_json, kInstanceState);
LocalMetaStore::GetInstance().set_curr_iter_num(current_iter);
MS_LOG(INFO) << "Recover from persistent storage: current iteration number is " << current_iter;
LocalMetaStore::GetInstance().set_curr_instance_state(GetInstanceState(instance_state));
MS_LOG(INFO) << "Recover from persistent storage: current iteration number is " << current_iter
<< ", instance state is " << instance_state;
server_recovery_file_.close();
return true;
}
bool ServerRecovery::Save(uint64_t current_iter) {
bool ServerRecovery::Save(uint64_t current_iter, InstanceState instance_state) {
std::unique_lock<std::mutex> lock(server_recovery_file_mtx_);
server_recovery_file_.open(server_recovery_file_path_, std::ios::out | std::ios::ate);
if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) {
@ -92,6 +97,7 @@ bool ServerRecovery::Save(uint64_t current_iter) {
nlohmann::json server_metadata_json;
server_metadata_json[kCurrentIteration] = current_iter;
server_metadata_json[kInstanceState] = GetInstanceStateStr(instance_state);
server_recovery_file_ << server_metadata_json;
server_recovery_file_.close();
return true;

View File

@ -23,6 +23,7 @@
#include <string>
#include <vector>
#include <mutex>
#include "fl/server/common.h"
#include "ps/core/recovery_base.h"
#include "ps/core/file_configuration.h"
#include "ps/core/communicator/tcp_communicator.h"
@ -45,7 +46,7 @@ class ServerRecovery : public ps::core::RecoveryBase {
bool Recover() override;
// Save server's metadata to persistent storage.
bool Save(uint64_t current_iter);
bool Save(uint64_t current_iter, InstanceState instance_state);
// If this server recovers, need to notify cluster to reach consistency.
bool SyncAfterRecovery(const std::shared_ptr<ps::core::TcpCommunicator> &communicator, uint32_t rank_id);

View File

@ -693,12 +693,7 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
HeartbeatRespMessage heartbeat_resp_message;
CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));
if (heartbeat_resp_message.cluster_state() != current_cluster_state_) {
MS_LOG(INFO) << "cluster change state from:" << CommUtil::ClusterStateToString(current_cluster_state_) << " to "
<< CommUtil::ClusterStateToString(heartbeat_resp_message.cluster_state());
}
current_cluster_state_ = heartbeat_resp_message.cluster_state();
UpdateClusterState(heartbeat_resp_message.cluster_state());
MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
<< CommUtil::ClusterStateToString(current_cluster_state_);
@ -1338,11 +1333,10 @@ bool AbstractNode::Recover() {
MS_LOG(INFO) << "The node is support recovery.";
node_recovery_ = std::make_unique<NodeRecovery>(this);
MS_EXCEPTION_IF_NULL(node_recovery_);
if (!node_recovery_->Initialize(config_->Get(kKeyRecovery, ""))) {
MS_LOG(ERROR) << "Initializing node recovery failed.";
return false;
if (node_recovery_->Initialize(config_->Get(kKeyRecovery, ""))) {
MS_LOG(INFO) << "Initializing node recovery success.";
return node_recovery_->Recover();
}
return node_recovery_->Recover();
}
return false;
}
@ -1419,8 +1413,16 @@ void AbstractNode::CreateTcpServer() {
void AbstractNode::UpdateClusterState(const ClusterState &state) {
std::lock_guard<std::mutex> lock(cluster_state_mutex_);
std::string state_str = CommUtil::ClusterStateToString(state);
if (state_str.empty()) {
return;
}
if (state == current_cluster_state_) {
return;
}
MS_LOG(INFO) << "[state]: Cluster state change from:" << CommUtil::ClusterStateToString(current_cluster_state_)
<< " to " << CommUtil::ClusterStateToString(state);
<< " to " << state_str;
current_cluster_state_ = state;
}

View File

@ -218,15 +218,9 @@ void CommUtil::LogCallback(int severity, const char *msg) {
}
}
bool CommUtil::IsFileExists(const std::string &file) {
std::ifstream f(file.c_str());
if (!f.good()) {
return false;
} else {
f.close();
return true;
}
}
// Returns true when access() with F_OK does not fail for the given path.
bool CommUtil::IsFileExists(const std::string &file) {
  const int probe_result = access(file.c_str(), F_OK);
  return probe_result != -1;
}
// Returns true when access() with R_OK does not fail for the given path.
bool CommUtil::IsFileReadable(const std::string &file) {
  const int probe_result = access(file.c_str(), R_OK);
  return probe_result != -1;
}
bool CommUtil::IsFileEmpty(const std::string &file) {
if (!IsFileExists(file)) {

View File

@ -136,6 +136,7 @@ class CommUtil {
static bool checkCRLTime(const std::string &crlPath);
static bool CreateDirectory(const std::string &directoryPath);
static bool CheckHttpUrl(const std::string &http_url);
static bool IsFileReadable(const std::string &file);
private:
static std::random_device rd;

View File

@ -69,6 +69,8 @@ class Configuration {
// storage meta data without nodes
virtual void PersistNodes(const core::ClusterConfig &clusterConfig) const = 0;
virtual std::string file_path() const = 0;
};
} // namespace core
} // namespace ps

View File

@ -29,9 +29,12 @@ bool FileConfiguration::Initialize() {
return false;
}
if (!CommUtil::IsFileReadable(file_path_)) {
MS_LOG(EXCEPTION) << "The file path: " << file_path_ << " is not readable.";
}
if (CommUtil::IsFileEmpty(file_path_)) {
MS_LOG(INFO) << "The file: " << file_path_ << " is empty.";
return true;
MS_LOG(EXCEPTION) << "The file path: " << file_path_ << " content is empty.";
}
std::ifstream json_file(file_path_);
@ -142,6 +145,8 @@ void FileConfiguration::PersistFile(const core::ClusterConfig &clusterConfig) co
output_file.close();
MS_LOG(INFO) << "The meta data persist to " << file_path_;
}
// Returns a copy of the configuration file path held in file_path_.
std::string FileConfiguration::file_path() const { return file_path_; }
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -89,6 +89,8 @@ class BACKEND_EXPORT FileConfiguration : public Configuration {
void PersistNodes(const core::ClusterConfig &clusterConfig) const override;
std::string file_path() const override;
private:
// The path of the configuration file.
std::string file_path_;

View File

@ -45,8 +45,30 @@ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message
<< ", the node_role:" << CommUtil::NodeRoleToString(registered_nodes_info_[node_id].node_role_);
return rank_id;
}
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
std::unordered_map<std::string, NodeInfo> recovery_node_infos = clusterConfig.initial_registered_nodes_infos;
// This is for scheduler recovery
(void)ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port(), &rank_id);
if (PSContext::instance()->server_mode() == kServerModeFL ||
PSContext::instance()->server_mode() == kServerModeHybrid) {
if (recovery_node_infos.find(node_id) != recovery_node_infos.end()) {
const std::string &new_ip = register_message.ip();
uint32_t new_port = register_message.port();
rank_id = recovery_node_infos[node_id].rank_id_;
recovery_node_infos[node_id].is_alive = true;
recovery_node_infos[node_id].ip_ = new_ip;
recovery_node_infos[node_id].port_ = static_cast<uint16_t>(new_port);
registered_nodes_info_[node_id] = recovery_node_infos[node_id];
MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!"
<< ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_
<< ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive
<< ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_);
return rank_id;
}
} else {
ReAddNodeIfNotExists(node_id, register_message.ip(), register_message.port());
}
return rank_id;
}
@ -90,24 +112,12 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
const std::string &ip = register_message.ip();
uint32_t port = register_message.port();
auto rank_it = std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::SERVER;
if (res) {
MS_LOG(INFO) << "The server node id:" << item.first << " rank id:" << item.second.rank_id_ << " is not alive.";
rank_id = item.second.rank_id_;
}
return res;
});
if (rank_it == registered_nodes_info_.end()) {
if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_server_rank_id_) {
rank_id = meta->rank_id();
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
} else {
rank_id = next_server_rank_id_;
next_server_rank_id_ += 1;
}
if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_server_rank_id_) {
rank_id = meta->rank_id();
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
} else {
registered_nodes_info_.erase((*rank_it).first);
rank_id = next_server_rank_id_;
next_server_rank_id_ += 1;
}
if (rank_id >= meta_data_->server_num) {
@ -131,25 +141,12 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
const std::string &ip = register_message.ip();
uint32_t port = register_message.port();
auto worker_rank_it =
std::find_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [&rank_id](auto item) {
bool res = item.second.is_alive == false && item.second.node_role_ == NodeRole::WORKER;
if (res) {
MS_LOG(INFO) << "The worker node id:" << item.first << " rank id:" << rank_id << " is not alive.";
rank_id = item.second.rank_id_;
}
return res;
});
if (worker_rank_it == registered_nodes_info_.end()) {
if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_worker_rank_id_) {
rank_id = meta->rank_id();
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
} else {
rank_id = next_worker_rank_id_;
next_worker_rank_id_ += 1;
}
if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_worker_rank_id_) {
rank_id = meta->rank_id();
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
} else {
registered_nodes_info_.erase((*worker_rank_it).first);
rank_id = next_worker_rank_id_;
next_worker_rank_id_ += 1;
}
if (rank_id >= meta_data_->worker_num) {
@ -265,12 +262,19 @@ void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_do
void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
bool NodeManager::IsAllNodesRegistered() const {
uint32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
[](auto item) { return item.second.is_alive == true; });
// Returns true when the number of registered nodes flagged as alive equals total_node_num_.
// Every registered node that is not alive is reported at ERROR level.
bool NodeManager::IsAllNodesAlive() const {
  // The predicate only reads each entry, so take it by const reference instead of copying
  // the whole map entry for every node.
  const auto alive_count =
    std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(), [](const auto &item) {
      if (!item.second.is_alive) {
        MS_LOG(ERROR) << item.second.node_id_ << " is not alive.";
        return false;
      }
      return true;
    });
  // count_if yields a signed difference type; narrow explicitly to the 32-bit width the
  // comparison has always used.
  return static_cast<uint32_t>(alive_count) == total_node_num_;
}
// Returns true when the number of entries in registered_nodes_info_ equals total_node_num_.
bool NodeManager::IsAllNodesRegistered() const {
  const auto registered_count = SizeToUint(registered_nodes_info_.size());
  return registered_count == total_node_num_;
}
// Returns true when the number of entries in finish_nodes_id_ equals total_node_num_.
bool NodeManager::IsAllNodesFinished() const {
  const auto finished_count = SizeToUint(finish_nodes_id_.size());
  return finished_count == total_node_num_;
}
bool NodeManager::IsAllNodesScaleOutDone() const {

View File

@ -137,6 +137,8 @@ class NodeManager {
// Determine whether all nodes that need to be persisted are in persistence.
bool IsAllNodeInPersisting();
bool IsAllNodesAlive() const;
private:
std::mutex node_mutex_;
std::mutex cluster_mutex_;

View File

@ -31,7 +31,7 @@ bool NodeRecovery::Recover() {
int32_t worker_num = std::strtol(recovery_storage_->Get(kRecoveryWorkerNum, "").c_str(), nullptr, kBase);
node_->set_worker_num(worker_num);
} else {
node_->set_worker_num(PSContext::instance()->cluster_config().initial_worker_num);
MS_LOG(EXCEPTION) << kRecoveryWorkerNum << " is not contained in " << recovery_storage_->file_path();
}
// 2. recover server num
@ -39,14 +39,14 @@ bool NodeRecovery::Recover() {
int32_t server_num = std::strtol(recovery_storage_->Get(kRecoveryServerNum, "").c_str(), nullptr, kBase);
node_->set_server_num(server_num);
} else {
node_->set_server_num(PSContext::instance()->cluster_config().initial_server_num);
MS_LOG(EXCEPTION) << kRecoveryServerNum << " is not contained in " << recovery_storage_->file_path();
}
// 3. recover scheduler ip
if (recovery_storage_->Exists(kRecoverySchedulerIp)) {
node_->set_scheduler_ip(recovery_storage_->GetString(kRecoverySchedulerIp, ""));
} else {
node_->set_scheduler_ip(PSContext::instance()->cluster_config().scheduler_host);
MS_LOG(EXCEPTION) << kRecoverySchedulerIp << " is not contained in " << recovery_storage_->file_path();
}
// 4. recover scheduler port
@ -54,7 +54,7 @@ bool NodeRecovery::Recover() {
uint16_t scheduler_port = std::strtol(recovery_storage_->Get(kRecoverySchedulerPort, "").c_str(), nullptr, kBase);
node_->set_scheduler_port(scheduler_port);
} else {
node_->set_scheduler_port(PSContext::instance()->cluster_config().scheduler_port);
MS_LOG(EXCEPTION) << kRecoverySchedulerPort << " is not contained in " << recovery_storage_->file_path();
}
MS_LOG(INFO) << "The worker num:" << node_->worker_num() << ", the server num:" << node_->server_num()
<< ", the scheduler ip:" << node_->scheduler_ip() << ", the scheduler port:" << node_->scheduler_port();

View File

@ -50,8 +50,10 @@ bool RecoveryBase::Initialize(const std::string &config_json) {
}
recovery_storage_ = std::make_unique<FileConfiguration>(storage_file_path);
MS_EXCEPTION_IF_NULL(recovery_storage_);
if (!recovery_storage_->Initialize()) {
MS_LOG(WARNING) << "The storage file path " << storage_file_path << " is empty.";
if (recovery_storage_->Initialize()) {
MS_LOG(INFO) << "The storage file path " << storage_file_path << " initialize success.";
} else {
return false;
}
}
@ -80,8 +82,10 @@ bool RecoveryBase::InitializeNodes(const std::string &config_json) {
}
scheduler_recovery_storage_ = std::make_unique<FileConfiguration>(scheduler_storage_file_path);
MS_EXCEPTION_IF_NULL(scheduler_recovery_storage_);
if (!scheduler_recovery_storage_->Initialize()) {
MS_LOG(WARNING) << "The scheduler storage file path " << scheduler_storage_file_path << " is empty.";
if (scheduler_recovery_storage_->Initialize()) {
MS_LOG(INFO) << "The scheduler storage file path " << scheduler_storage_file_path << " initialize success.";
} else {
return false;
}
MS_LOG(INFO) << "the scheduler storage file path is:" << scheduler_storage_file_path;

View File

@ -31,12 +31,12 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
MS_LOG(INFO) << "[Scheduler start]: 1. Begin to start scheduler node!";
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
InitNodeMetaData();
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
InitNodeMetaData();
MS_LOG(WARNING) << "The config file is empty.";
} else {
if (!RecoverScheduler()) {
MS_LOG(WARNING) << "Recover the server node is failed.";
MS_LOG(DEBUG) << "Recover the server node is failed.";
}
}
@ -338,6 +338,10 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
SetRegisterConnectionFd(conn, node_id);
if (node_manager_.IsAllNodesRegistered()) {
if (!node_manager_.IsAllNodesAlive()) {
MS_LOG(ERROR) << "Do not broadcast nodes info because some server nodes are not alive.";
return;
}
is_ready_ = true;
MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
@ -1422,17 +1426,22 @@ bool SchedulerNode::RecoverScheduler() {
MS_LOG(INFO) << "The scheduler node is support recovery.";
scheduler_recovery_ = std::make_unique<SchedulerRecovery>();
MS_EXCEPTION_IF_NULL(scheduler_recovery_);
(void)scheduler_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
(void)scheduler_recovery_->InitializeNodes(config_->Get(kKeyRecovery, ""));
return scheduler_recovery_->Recover();
bool ret = scheduler_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
bool ret_node = scheduler_recovery_->InitializeNodes(config_->Get(kKeyRecovery, ""));
if (ret && ret_node) {
MS_LOG(INFO) << "Scheduler recovery initialize successful.";
return scheduler_recovery_->Recover();
}
}
return false;
}
void SchedulerNode::PersistMetaData() {
if (scheduler_recovery_ == nullptr) {
MS_LOG(WARNING) << "scheduler recovery is null, so don't persist meta data";
MS_LOG(WARNING) << "scheduler recovery is null, do not persist meta data";
return;
}
if (!is_ready_) {
return;
}
if (config_->Exists(kKeyRecovery)) {

View File

@ -27,9 +27,7 @@ std::string SchedulerRecovery::GetMetadata(const std::string &key) {
bool SchedulerRecovery::Recover() {
std::unique_lock<std::mutex> lock(recovery_mtx_);
if (recovery_storage_ == nullptr) {
return false;
}
MS_ERROR_IF_NULL_W_RET_VAL(recovery_storage_, false);
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
// 1. recover worker num
@ -38,7 +36,7 @@ bool SchedulerRecovery::Recover() {
UlongToUint(std::strtoul(recovery_storage_->Get(kRecoveryWorkerNum, "").c_str(), nullptr, kBase));
clusterConfig.initial_worker_num = initial_worker_num;
} else {
clusterConfig.initial_worker_num = PSContext::instance()->initial_worker_num();
MS_LOG(EXCEPTION) << kRecoveryWorkerNum << " is not contained in " << recovery_storage_->file_path();
}
// 2. recover server num
@ -47,14 +45,14 @@ bool SchedulerRecovery::Recover() {
UlongToUint(std::strtoul(recovery_storage_->Get(kRecoveryServerNum, "").c_str(), nullptr, kBase));
clusterConfig.initial_server_num = initial_server_num;
} else {
clusterConfig.initial_server_num = PSContext::instance()->initial_server_num();
MS_LOG(EXCEPTION) << kRecoveryServerNum << " is not contained in " << recovery_storage_->file_path();
}
// 3. recover scheduler ip
if (recovery_storage_->Exists(kRecoverySchedulerIp)) {
clusterConfig.scheduler_host = recovery_storage_->GetString(kRecoverySchedulerIp, "");
} else {
clusterConfig.scheduler_host = PSContext::instance()->scheduler_host();
MS_LOG(EXCEPTION) << kRecoverySchedulerIp << " is not contained in " << recovery_storage_->file_path();
}
// 4. recover scheduler port
@ -62,7 +60,7 @@ bool SchedulerRecovery::Recover() {
uint16_t scheduler_port = std::strtol(recovery_storage_->Get(kRecoverySchedulerPort, "").c_str(), nullptr, kBase);
clusterConfig.scheduler_port = scheduler_port;
} else {
clusterConfig.scheduler_port = PSContext::instance()->scheduler_port();
MS_LOG(EXCEPTION) << kRecoverySchedulerPort << " is not contained in " << recovery_storage_->file_path();
}
MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num
@ -70,15 +68,14 @@ bool SchedulerRecovery::Recover() {
<< ", the scheduler ip:" << clusterConfig.scheduler_host
<< ", the scheduler port:" << clusterConfig.scheduler_port;
if (scheduler_recovery_storage_ == nullptr) {
MS_LOG(WARNING) << "scheduler recovery storage is null. return false";
return false;
}
MS_ERROR_IF_NULL_W_RET_VAL(scheduler_recovery_storage_, false);
// 5. recover total node num
// NOTE(review): the key is looked up in scheduler_recovery_storage_, but the failure message below prints
// recovery_storage_->file_path(). If the two storages are backed by different files, the message names the
// wrong file -- TODO confirm and report the scheduler storage's path instead.
if (scheduler_recovery_storage_->Exists(kRecoveryTotalNodeNum)) {
uint32_t initial_total_node_num =
UlongToUint(std::strtoul(scheduler_recovery_storage_->Get(kRecoveryTotalNodeNum, "").c_str(), nullptr, kBase));
clusterConfig.initial_total_node_num = initial_total_node_num;
} else {
MS_LOG(EXCEPTION) << kRecoveryTotalNodeNum << " is not contained in " << recovery_storage_->file_path();
}
// 6. recover next worker rank id
@ -86,6 +83,8 @@ bool SchedulerRecovery::Recover() {
uint32_t initial_next_worker_rank_id = UlongToUint(
std::strtoul(scheduler_recovery_storage_->Get(kRecoveryNextWorkerRankId, "").c_str(), nullptr, kBase));
clusterConfig.initial_next_worker_rank_id = initial_next_worker_rank_id;
} else {
MS_LOG(EXCEPTION) << kRecoveryNextWorkerRankId << " is not contained in " << recovery_storage_->file_path();
}
// 7. recover next server rank id
@ -93,6 +92,8 @@ bool SchedulerRecovery::Recover() {
uint32_t initial_next_server_rank_id = UlongToUint(
std::strtoul(scheduler_recovery_storage_->Get(kRecoveryNextServerRankId, "").c_str(), nullptr, kBase));
clusterConfig.initial_next_server_rank_id = initial_next_server_rank_id;
} else {
MS_LOG(EXCEPTION) << kRecoveryNextServerRankId << " is not contained in " << recovery_storage_->file_path();
}
// 8. recover register nodes info
@ -122,13 +123,11 @@ bool SchedulerRecovery::Recover() {
<< recovery_server_num << " initial server num is:" << clusterConfig.initial_server_num;
}
clusterConfig.initial_registered_nodes_infos = nodes_infos;
} else {
MS_LOG(EXCEPTION) << kRecoveryRegisteredNodesInfos << " is not contained in " << recovery_storage_->file_path();
}
MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num
<< ", the server num:" << clusterConfig.initial_server_num
<< ", the scheduler ip:" << clusterConfig.scheduler_host
<< ", the scheduler port:" << clusterConfig.scheduler_port
<< ", the initial total node num:" << clusterConfig.initial_total_node_num
MS_LOG(INFO) << ", the initial total node num:" << clusterConfig.initial_total_node_num
<< ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id
<< ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id;

View File

@ -42,12 +42,12 @@ bool ServerNode::Start(const uint32_t &timeout) {
void ServerNode::Initialize() {
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
InitNodeNum();
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
InitNodeNum();
MS_LOG(WARNING) << "The config file is empty.";
} else {
if (!Recover()) {
MS_LOG(WARNING) << "Recover the server node is failed.";
MS_LOG(DEBUG) << "Recover the server node is failed.";
}
}
InitServerHandler();

View File

@ -42,12 +42,12 @@ bool WorkerNode::Start(const uint32_t &timeout) {
void WorkerNode::Initialize() {
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
InitNodeNum();
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
InitNodeNum();
MS_LOG(WARNING) << "The config file is empty.";
} else {
if (!Recover()) {
MS_LOG(WARNING) << "Recover the worker node is failed.";
MS_LOG(DEBUG) << "Recover the worker node is failed.";
}
}
InitServerHandler();

View File

@ -209,9 +209,8 @@ void PSContext::set_rank_id(uint32_t rank_id) const {
}
void PSContext::set_server_mode(const std::string &server_mode) {
if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
<< " or " << kServerModeHybrid;
if (server_mode != kServerModePS && server_mode != kServerModeFL) {
MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL;
return;
}
MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";