Merge pull request !28882 from tan-wei-cheng-3260/r1.6
i-robot 2022-01-12 07:03:03 +00:00 committed by Gitee
commit b0cf019553
15 changed files with 221 additions and 200 deletions

View File

@@ -439,13 +439,13 @@ void AbstractNode::RegisterFollowerScalerHandlerAfterScaleIn(const std::string &
PersistentState AbstractNode::persistent_state() const { return persistent_state_; }
void AbstractNode::set_persistent_state(PersistentState persistent_state) { persistent_state_ = persistent_state; }
- int32_t AbstractNode::worker_num() const { return worker_num_; }
+ uint32_t AbstractNode::worker_num() const { return worker_num_; }
- int32_t AbstractNode::server_num() const { return server_num_; }
+ uint32_t AbstractNode::server_num() const { return server_num_; }
- void AbstractNode::set_worker_num(const int32_t &worker_num) { worker_num_ = worker_num; }
+ void AbstractNode::set_worker_num(const uint32_t &worker_num) { worker_num_ = worker_num; }
- void AbstractNode::set_server_num(const int32_t &server_num) { server_num_ = server_num; }
+ void AbstractNode::set_server_num(const uint32_t &server_num) { server_num_ = server_num; }
std::string AbstractNode::scheduler_ip() const { return scheduler_ip_; }
@@ -1063,7 +1063,8 @@ void AbstractNode::ProcessSendData(const std::shared_ptr<TcpConnection> &conn, c
std::shared_ptr<unsigned char> res(new unsigned char[size], std::default_delete<unsigned char[]>());
#else
if (size < 0) {
- MS_LOG(EXCEPTION) << "size < 0";
+ MS_LOG(ERROR) << "size < 0";
+ return;
}
std::shared_ptr<unsigned char[]> res(new unsigned char[size]);
#endif
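
The EXCEPTION-to-ERROR switch above is not just a log-level tweak: ProcessSendData runs inside a transport callback, where a thrown exception would unwind through the event-loop code that invoked it, so the patched code logs and returns instead. A minimal sketch of that guard style, with illustrative names rather than the MindSpore logging API:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>

// Illustrative callback body: on a bad payload size, log and bail out rather
// than throwing, so no exception escapes into the caller's event loop.
void HandlePayload(int64_t size) {
  if (size < 0) {
    std::cerr << "size < 0" << std::endl;  // stands in for MS_LOG(ERROR)
    return;
  }
  // C++17 array specialization, matching the hunk's allocation of res.
  std::shared_ptr<unsigned char[]> res(new unsigned char[static_cast<size_t>(size)]);
}
```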
@@ -1293,7 +1294,7 @@ void AbstractNode::CreateTcpServer() {
}
void AbstractNode::UpdateClusterState(const ClusterState &state) {
- std::lock_guard<std::mutex> lk(cluster_state_mutex_);
+ std::lock_guard<std::mutex> lock(cluster_state_mutex_);
MS_LOG(INFO) << "[state]: Cluster state change from:" << CommUtil::ClusterStateToString(current_cluster_state_)
<< " to " << CommUtil::ClusterStateToString(state);
current_cluster_state_ = state;
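
The worker_num/server_num edits in this file, together with the header change below, move the node counts from int32_t with a -1 sentinel to uint32_t defaulting to 0. A count is never negative, and an unsigned member compares directly against container sizes, which is what lets later hunks drop the SizeToInt/UintToInt adapters. A compilable sketch of the effect, with hypothetical names:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>

// Hypothetical registry: unordered_map::size() returns size_t, so an
// unsigned count compares cleanly, with no sign-conversion helper and no
// -Wsign-compare warning the old int32_t member would trigger.
struct NodeRegistrySketch {
  uint32_t total_node_num = 0;  // was an int32_t initialized to -1
  std::unordered_map<std::string, bool> registered;

  bool AllRegistered() const {
    return static_cast<uint32_t>(registered.size()) == total_node_num;
  }
};
```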

View File

@@ -47,8 +47,8 @@ class AbstractNode : public Node {
client_to_server_(nullptr),
server_(nullptr),
server_thread_(nullptr),
- worker_num_(-1),
- server_num_(-1),
+ worker_num_(0),
+ server_num_(0),
is_connected_to_scheduler_(false),
is_current_node_scale_in_(false),
follower_scaler_(nullptr),
@@ -125,11 +125,11 @@ class AbstractNode : public Node {
PersistentState persistent_state() const;
void set_persistent_state(PersistentState persistent_state);
- int32_t worker_num() const;
- int32_t server_num() const;
+ uint32_t worker_num() const;
+ uint32_t server_num() const;
- void set_worker_num(const int32_t &worker_num);
- void set_server_num(const int32_t &server_num);
+ void set_worker_num(const uint32_t &worker_num);
+ void set_server_num(const uint32_t &server_num);
std::string scheduler_ip() const;
void set_scheduler_ip(const std::string &scheduler_ip);
@@ -264,8 +264,8 @@ class AbstractNode : public Node {
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;
- int32_t worker_num_;
- int32_t server_num_;
+ uint32_t worker_num_;
+ uint32_t server_num_;
std::atomic<bool> is_connected_to_scheduler_;
// Identify whether the current node is a scale in node.
std::atomic<bool> is_current_node_scale_in_;

View File

@@ -61,7 +61,7 @@ struct ClusterConfig {
uint16_t scheduler_port;
// The timeout for worker node and server node sending heartbeat packets to scheduler node is 30 seconds.
uint32_t heartbeat_timeout;
- // Timeout period for cluster preparation is 300 seconds.
+ // Timeout period for cluster preparation is 900 seconds.
uint32_t cluster_available_timeout;
// The timeout period for the client to connect to the server is 3000ms.
uint32_t connect_interval;

View File

@@ -253,7 +253,6 @@ bool CommUtil::CreateDirectory(const std::string &directoryPath) {
}
std::string CommUtil::ClusterStateToString(const ClusterState &state) {
MS_LOG(DEBUG) << "The cluster state:" << state;
if (state < SizeToInt(kClusterState.size())) {
return kClusterState.at(state);
} else {
@@ -430,11 +429,11 @@ bool CommUtil::verifyCertKeyID(const X509 *caCert, const X509 *subCert) {
ASN1_OCTET_STRING *skid =
reinterpret_cast<ASN1_OCTET_STRING *>(X509_get_ext_d2i(caCert, NID_subject_key_identifier, &crit, NULL));
MS_EXCEPTION_IF_NULL(skid);
- const int keyidLen = 512;
+ const size_t keyidLen = 512;
char subject_keyid[keyidLen] = {0};
for (int i = 0; i < skid->length; i++) {
char keyid[8] = {0};
- int base = keyidLen;
+ size_t base = keyidLen;
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)skid->data[i]);
int ret = strcat_s(subject_keyid, base, keyid);
if (ret == -1) {
@@ -449,7 +448,7 @@ bool CommUtil::verifyCertKeyID(const X509 *caCert, const X509 *subCert) {
char issuer_keyid[keyidLen] = {0};
for (int i = 0; i < akeyid->keyid->length; i++) {
char keyid[8] = {0};
- int base = keyidLen;
+ size_t base = keyidLen;
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)(akeyid->keyid->data[i]));
int ret = strcat_s(issuer_keyid, base, keyid);
if (ret == -1) {
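
Both loops in verifyCertKeyID render a key identifier as space-separated hex by appending into a fixed 512-byte buffer; the hunks only widen the capacity argument from int to size_t so it matches the strcat_s signature. The same formatting can be sketched with standard snprintf and a growable std::string in place of the sprintf_s/strcat_s pair (an assumption for illustration, not the code used in the tree):

```cpp
#include <cstdio>
#include <string>

// Render bytes as "de ad be ef " the way the verifyCertKeyID() loops do,
// letting std::string grow instead of strcat_s-ing into char[512].
std::string HexKeyId(const unsigned char *data, int len) {
  std::string out;
  for (int i = 0; i < len; ++i) {
    char chunk[8] = {0};
    (void)snprintf(chunk, sizeof(chunk), "%x ", static_cast<unsigned int>(data[i]));
    out += chunk;
  }
  return out;
}
```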

View File

@@ -43,7 +43,7 @@ void CommunicatorBase::Join() {
return;
}
- bool CommunicatorBase::running() { return running_; }
+ bool CommunicatorBase::running() const { return running_; }
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@@ -85,7 +85,7 @@ class CommunicatorBase {
bool SendResponse(const void *rsp_data, size_t rsp_len, const std::shared_ptr<MessageHandler> &msg_handler);
- bool running();
+ bool running() const;
protected:
std::unordered_map<std::string, MessageCallback> msg_callbacks_;

View File

@@ -68,7 +68,7 @@ bool HttpServer::InitServer() {
return false;
}
- fd_ = ::socket(AF_INET, SOCK_STREAM, 0);
+ fd_ = ::socket(static_cast<int>(AF_INET), static_cast<int>(SOCK_STREAM), 0);
if (fd_ < 0) {
MS_LOG(ERROR) << "Socker error!";
return false;

View File

@@ -159,7 +159,7 @@ void TcpClient::Stop() {
void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
const int one = 1;
- int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
+ int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
if (ret < 0) {
MS_LOG(EXCEPTION) << "Set socket no delay failed!";
}
@@ -193,10 +193,10 @@ void TcpClient::ReadCallbackInner(struct bufferevent *bev, void *const ctx) {
auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
char read_buffer[kMessageChunkLength];
- int read = 0;
+ size_t read = 0;
- while ((read = bufferevent_read(bev, &read_buffer, SizeToInt(sizeof(read_buffer)))) > 0) {
- tcp_client->OnReadHandler(read_buffer, IntToSize(read));
+ while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) {
+ tcp_client->OnReadHandler(read_buffer, read);
}
}
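
bufferevent_read returns size_t, so typing read as size_t lets the loop consume libevent's return value directly and removes the SizeToInt/IntToSize round-trips. A self-contained sketch of the patched loop shape; kChunk is an assumed stand-in for kMessageChunkLength, whose value is not shown in this diff:

```cpp
#include <event2/bufferevent.h>

#include <cstddef>

constexpr std::size_t kChunk = 4096;  // assumed value, stands in for kMessageChunkLength

// Drain all pending bytes from a bufferevent, size_t end to end, as the
// patched ReadCallbackInner() does.
void DrainBuffer(struct bufferevent *bev, void (*on_read)(const char *, std::size_t)) {
  char buf[kChunk];
  std::size_t n = 0;
  while ((n = bufferevent_read(bev, buf, sizeof(buf))) > 0) {
    on_read(buf, n);
  }
}
```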

View File

@@ -21,7 +21,7 @@ namespace mindspore {
namespace ps {
namespace core {
TcpMsgHandler::TcpMsgHandler(AbstractNode *abstract_node, const std::shared_ptr<core::TcpConnection> &conn,
- const std::shared_ptr<MessageMeta> &meta, DataPtr data, size_t size)
+ const std::shared_ptr<MessageMeta> &meta, const DataPtr data, size_t size)
: abstract_node_(abstract_node), tcp_conn_(conn), meta_(meta), data_ptr_(data), data_(nullptr), len_(size) {
if (data_ptr_ != nullptr) {
data_ = data_ptr_.get();

View File

@@ -29,7 +29,7 @@ namespace core {
class TcpMsgHandler : public MessageHandler {
public:
TcpMsgHandler(AbstractNode *abstract_node, const std::shared_ptr<core::TcpConnection> &conn,
- const std::shared_ptr<MessageMeta> &meta, DataPtr data, size_t size);
+ const std::shared_ptr<MessageMeta> &meta, const DataPtr data, size_t size);
~TcpMsgHandler() override = default;
void *data() const override;

View File

@@ -28,8 +28,7 @@ void NodeManager::InitNode() {
total_node_num_ = initial_total_node_num_;
}
- uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message,
- const std::shared_ptr<MessageMeta> &meta) {
+ uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage &register_message) {
uint32_t rank_id = UINT_MAX;
const std::string &node_id = register_message.node_id();
if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) {
@@ -70,11 +69,11 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(meta_data_);
std::lock_guard<std::mutex> lock(assign_rank_id_mutex_);
- uint32_t rank_id = checkIfRankIdExist(register_message, meta);
+ uint32_t rank_id = checkIfRankIdExist(register_message);
if (rank_id != UINT_MAX) {
return rank_id;
}
- if (total_node_num_ == SizeToInt(registered_nodes_info_.size())) {
+ if (total_node_num_ == SizeToUint(registered_nodes_info_.size())) {
MS_LOG(WARNING) << "There are enough nodes registering to scheduler.";
return UINT_MAX;
}
@@ -94,11 +93,12 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
return res;
});
if (rank_it == registered_nodes_info_.end()) {
- if (meta->rank_id() != UINT32_MAX && UintToInt(meta->rank_id()) <= next_server_rank_id_) {
+ if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_server_rank_id_) {
rank_id = meta->rank_id();
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
} else {
- rank_id = ++next_server_rank_id_;
+ rank_id = next_server_rank_id_;
+ ++next_server_rank_id_;
}
} else {
registered_nodes_info_.erase((*rank_it).first);
@@ -134,11 +134,12 @@ uint32_t NodeManager::NextRankId(const RegisterMessage &register_message, const
return res;
});
if (worker_rank_it == registered_nodes_info_.end()) {
- if (meta->rank_id() != UINT32_MAX && UintToInt(meta->rank_id()) <= next_worker_rank_id_) {
+ if (meta->rank_id() != UINT32_MAX && meta->rank_id() < next_worker_rank_id_) {
rank_id = meta->rank_id();
MS_LOG(INFO) << "Use the old rank id:" << rank_id;
} else {
- rank_id = ++next_worker_rank_id_;
+ rank_id = next_worker_rank_id_;
+ ++next_worker_rank_id_;
}
} else {
registered_nodes_info_.erase((*worker_rank_it).first);
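
The two rank-assignment hunks are a paired fix. With the counters now starting at 0 instead of -1 (see the node_manager.h constructor change below), `rank_id = ++next_rank_id_` would skip rank 0, so the patched code hands out the current value and increments afterwards; the reuse guard also tightens from <= to < because the counter now means "next unassigned id" rather than "last assigned id". A minimal sketch of the new sequencing:

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

std::atomic<uint32_t> next_rank_id{0};  // unsigned counters cannot start at -1

// Hands out 0, 1, 2, ... exactly as the patched branches do. NextRankId holds
// assign_rank_id_mutex_ around this, so read-then-increment is safe there; a
// lock-free variant would fuse the two steps with next_rank_id.fetch_add(1).
uint32_t AssignRank() {
  uint32_t rank_id = next_rank_id;
  ++next_rank_id;
  return rank_id;
}

int main() {
  assert(AssignRank() == 0);  // the old pre-increment form would return 1 here
  assert(AssignRank() == 1);
  return 0;
}
```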
@@ -224,7 +225,7 @@ void NodeManager::UpdateCluster() {
if (onPersist) {
onPersist();
}
- } else if (SizeToInt(heartbeats_.size()) == total_node_num_) {
+ } else if (SizeToUint(heartbeats_.size()) == total_node_num_) {
if (cluster_state_ == ClusterState::NODE_TIMEOUT) {
for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) {
if (registered_nodes_info_.count(it->first)) {
@@ -239,23 +240,11 @@ void NodeManager::UpdateCluster() {
}
// 2. update cluster finish state
- if (SizeToInt(finish_nodes_id_.size()) == total_node_num_ ||
- SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
+ if (SizeToUint(finish_nodes_id_.size()) == total_node_num_) {
UpdateClusterState(ClusterState::CLUSTER_EXIT);
}
}
- void NodeManager::CheckClusterTimeout() {
- if (total_node_num_ != SizeToInt(registered_nodes_info_.size())) {
- MS_LOG(WARNING) << "The cluster is not ready after "
- << PSContext::instance()->cluster_config().cluster_available_timeout
- << " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
- << registered_nodes_info_.size();
- current_node_num_ = SizeToInt(registered_nodes_info_.size());
- UpdateClusterState(ClusterState::NODE_TIMEOUT);
- }
- }
void NodeManager::AddFinishNode(const std::string &finish_message) { finish_nodes_id_.insert(finish_message); }
void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_done_nodes_id_.insert(node_id); }
@@ -263,18 +252,20 @@ void NodeManager::AddScaleOutDoneNode(const std::string &node_id) { scale_out_do
void NodeManager::AddScaleInDoneNode(const std::string &node_id) { scale_in_done_nodes_id_.insert(node_id); }
bool NodeManager::IsAllNodesRegistered() const {
- int32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
- [](auto item) { return item.second.is_alive == true; });
+ uint32_t num = std::count_if(registered_nodes_info_.begin(), registered_nodes_info_.end(),
+ [](auto item) { return item.second.is_alive == true; });
return num == total_node_num_;
}
- bool NodeManager::IsAllNodesFinished() const { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
+ bool NodeManager::IsAllNodesFinished() const { return SizeToUint(finish_nodes_id_.size()) == total_node_num_; }
bool NodeManager::IsAllNodesScaleOutDone() const {
- return SizeToInt(scale_out_done_nodes_id_.size()) == total_node_num_;
+ return SizeToUint(scale_out_done_nodes_id_.size()) == total_node_num_;
}
- bool NodeManager::IsAllNodesScaleInDone() const { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; }
+ bool NodeManager::IsAllNodesScaleInDone() const {
+ return SizeToUint(scale_in_done_nodes_id_.size()) == total_node_num_;
+ }
const std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() const { return nodes_info_; }
@@ -362,7 +353,6 @@ bool NodeManager::IsNodeRegistered(const std::string &node_id) {
const NodeInfo NodeManager::QueryNodeInfo(const std::string &node_id) const {
auto iter = registered_nodes_info_.find(node_id);
if (iter == registered_nodes_info_.end()) {
MS_LOG(DEBUG) << "Cannot find node of id: " << node_id;
return NodeInfo();
}
return iter->second;
@@ -376,33 +366,33 @@ void NodeManager::AddPersistingNode(const std::string &node_id) { nodes_persisti
bool NodeManager::IsAllNodeInPersisting() {
// The worker role does not support disaster recovery currently.
- if (nodes_persisting_.size() == IntToSize(server_num())) {
+ if (SizeToUint(nodes_persisting_.size()) == server_num()) {
nodes_persisting_.clear();
return true;
}
return false;
}
- void NodeManager::set_total_node_num(const int32_t &node_num) { total_node_num_ = node_num; }
+ void NodeManager::set_total_node_num(const uint32_t &node_num) { total_node_num_ = node_num; }
- const int32_t &NodeManager::total_node_num() const { return total_node_num_; }
+ const uint32_t &NodeManager::total_node_num() const { return total_node_num_; }
- void NodeManager::set_worker_num(const int32_t &worker_num) { meta_data_->worker_num = IntToUint(worker_num); }
+ void NodeManager::set_worker_num(const uint32_t &worker_num) { meta_data_->worker_num = worker_num; }
- void NodeManager::set_server_num(const int32_t &server_num) { meta_data_->server_num = IntToUint(server_num); }
+ void NodeManager::set_server_num(const uint32_t &server_num) { meta_data_->server_num = server_num; }
- int32_t NodeManager::worker_num() const { return UintToInt(meta_data_->worker_num); }
+ uint32_t NodeManager::worker_num() const { return meta_data_->worker_num; }
- int32_t NodeManager::server_num() const { return UintToInt(meta_data_->server_num); }
+ uint32_t NodeManager::server_num() const { return meta_data_->server_num; }
- int32_t NodeManager::next_worker_rank_id() const { return next_worker_rank_id_.load(); }
+ uint32_t NodeManager::next_worker_rank_id() const { return next_worker_rank_id_.load(); }
- int32_t NodeManager::next_server_rank_id() const { return next_server_rank_id_.load(); }
+ uint32_t NodeManager::next_server_rank_id() const { return next_server_rank_id_.load(); }
- void NodeManager::set_next_worker_rank_id(const int32_t &next_worker_rank_id) {
+ void NodeManager::set_next_worker_rank_id(const uint32_t &next_worker_rank_id) {
this->next_worker_rank_id_ = next_worker_rank_id;
}
- void NodeManager::set_next_server_rank_id(const int32_t &next_server_rank_id) {
+ void NodeManager::set_next_server_rank_id(const uint32_t &next_server_rank_id) {
this->next_server_rank_id_ = next_server_rank_id;
}
void NodeManager::setPersistCallback(const OnPersist &onPersist) { this->onPersist = onPersist; }

View File

@@ -45,10 +45,9 @@ class NodeManager {
public:
NodeManager()
: initial_total_node_num_(0),
- total_node_num_(-1),
- current_node_num_(-1),
- next_worker_rank_id_(-1),
- next_server_rank_id_(-1),
+ total_node_num_(0),
+ next_worker_rank_id_(0),
+ next_server_rank_id_(0),
meta_data_(nullptr),
node_state_(NodeState::NODE_STARTING),
cluster_state_(ClusterState::ClUSTER_STARTING) {}
@@ -57,7 +56,7 @@ class NodeManager {
// When initializing nodes, the initial number of nodes will be assigned to the total number of nodes.
void InitNode();
uint32_t NextRankId(const RegisterMessage &register_message, const std::shared_ptr<MessageMeta> &meta);
- uint32_t checkIfRankIdExist(const RegisterMessage &register_message, const std::shared_ptr<MessageMeta> &meta);
+ uint32_t checkIfRankIdExist(const RegisterMessage &register_message);
void UpdateHeartbeat(const std::string &node_id);
std::vector<ServersMeta> FetchServersMeta();
@@ -65,7 +64,6 @@ class NodeManager {
std::vector<ServersMeta> FetchAllNodesMeta();
void UpdateCluster();
- void CheckClusterTimeout();
void AddFinishNode(const std::string &finish_message);
// After the scheduler receives the scale_out_done node, it will save this node.
@@ -92,15 +90,15 @@ class NodeManager {
// After all the nodes are registered successfully, the nodes info can be updated.
void UpdateNodesInfo();
- void set_total_node_num(const int32_t &node_num);
- const int32_t &total_node_num() const;
- void set_worker_num(const int32_t &worker_num);
- void set_server_num(const int32_t &server_num);
- int32_t worker_num() const;
- int32_t server_num() const;
+ void set_total_node_num(const uint32_t &node_num);
+ const uint32_t &total_node_num() const;
+ void set_worker_num(const uint32_t &worker_num);
+ void set_server_num(const uint32_t &server_num);
+ uint32_t worker_num() const;
+ uint32_t server_num() const;
- int32_t next_worker_rank_id() const;
- int32_t next_server_rank_id() const;
+ uint32_t next_worker_rank_id() const;
+ uint32_t next_server_rank_id() const;
void UpdateNodeState(const NodeState &state);
void UpdateClusterState(const ClusterState &state);
@@ -119,8 +117,8 @@ class NodeManager {
bool IsNodeRegistered(const std::string &node_id);
void set_registered_nodes_info(const std::unordered_map<std::string, NodeInfo> registered_nodes_info);
- void set_next_worker_rank_id(const int32_t &next_worker_rank_id);
- void set_next_server_rank_id(const int32_t &next_server_rank_id);
+ void set_next_worker_rank_id(const uint32_t &next_worker_rank_id);
+ void set_next_server_rank_id(const uint32_t &next_server_rank_id);
void setPersistCallback(const OnPersist &onPersist);
// Query node information by node id.
@@ -140,11 +138,11 @@ class NodeManager {
std::mutex cluster_mutex_;
uint32_t initial_total_node_num_;
- int32_t total_node_num_;
- int32_t current_node_num_;
+ uint32_t total_node_num_;
+ uint32_t current_node_num_;
- std::atomic<int> next_worker_rank_id_;
- std::atomic<int> next_server_rank_id_;
+ std::atomic<uint32_t> next_worker_rank_id_;
+ std::atomic<uint32_t> next_server_rank_id_;
// Whenever a node is registered, it will be stored in this map.
std::unordered_map<std::string, NodeInfo> registered_nodes_info_;

View File

@@ -137,14 +137,13 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
CHECK_RETURN_TYPE(heartbeat_message.ParseFromArray(data, SizeToInt(size)));
std::string node_id = heartbeat_message.node_id();
- node_manager_.UpdateHeartbeat(node_id);
MS_LOG(DEBUG) << "The scheduler get a heartbeat from node id :" << heartbeat_message.node_id();
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.set_persistent_cmd(PersistentCommand::DEFAULT);
NodeInfo nodeInfo = node_manager_.QueryNodeInfo(node_id);
- if (nodeInfo.node_id_ != "") {
+ if (!nodeInfo.node_id_.empty()) {
// The worker role does not support disaster recovery for the time being.
NodeRole node_role = nodeInfo.node_role_;
if (node_role == NodeRole::SERVER && persistent_cmd_ == PersistentCommand::BEGIN_PERSIST) {
@@ -156,6 +155,7 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
persistent_cmd_ = PersistentCommand::DEFAULT;
}
}
+ node_manager_.UpdateHeartbeat(node_id);
}
MS_LOG(DEBUG) << "The cluster state:" << CommUtil::ClusterStateToString(node_manager_.GetClusterState());
@@ -223,14 +223,13 @@ void SchedulerNode::CreateTcpServer() {
const auto client_disconn = [&](const TcpServer &, const TcpConnection &conn) {
int fd = conn.GetFd();
- if (register_connection_fd_.count(fd) <= 0) {
- return;
- }
+ if (register_connection_fd_.count(fd) > 0) {
+ MS_LOG(WARNING) << "remove client fd:" << fd << ", remove client id:" << register_connection_fd_[fd];
+ (void)register_connection_fd_.erase(fd);
+ MS_LOG(INFO) << "Register node number is:" << register_connection_fd_.size()
+ << ", total node num is:" << node_manager_.total_node_num()
+ << ", scale in node size is: " << scale_in_node_ids_.size();
+ }
- MS_LOG(WARNING) << "remove client fd:" << fd << ", remove client id:" << register_connection_fd_[fd];
- register_connection_fd_.erase(fd);
- MS_LOG(WARNING) << "Register node number is:" << register_connection_fd_.size()
- << ", total node num is:" << node_manager_.total_node_num()
- << ", scale in node size is: " << scale_in_node_ids_.size();
};
server_->SetServerCallback(nullptr, client_disconn, nullptr);
server_->Init();
@@ -320,7 +319,7 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server,
auto node_infos = node_manager_.nodes_info();
bool res = SendPrepareBuildingNetwork(node_infos);
if (!res) {
MS_LOG(WARNING) << "Prepare for building network failed!";
MS_LOG(ERROR) << "Prepare for building network failed!";
return;
}
MS_LOG(INFO) << "Prepare for building network success.";
@@ -475,23 +474,22 @@ void SchedulerNode::ProcessSendEvent(const std::shared_ptr<TcpServer> &server,
}
bool SchedulerNode::SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos) {
- std::string timeoutNodeId = "";
+ uint64_t request_id = AddMessageTrack(node_infos.size());
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
MS_ERROR_IF_NULL_W_RET_VAL(client, false);
auto message_meta = std::make_shared<MessageMeta>();
MS_EXCEPTION_IF_NULL(message_meta);
+ message_meta->set_request_id(request_id);
message_meta->set_cmd(NodeCommand::PREPARE_BUILDING_NETWORK);
- SendMetadataMessage send_metadata_message;
- send_metadata_message.set_rank_id(kvs.second.rank_id_);
- if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, send_metadata_message.SerializeAsString().data(),
- send_metadata_message.ByteSizeLong(), kCommTimeoutInThreeSeconds)) {
+ std::string req_data;
+ if (!client->SendMessage(message_meta, Protos::RAW, req_data.data(), req_data.length())) {
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(kvs.second.node_role_)
<< " the node id:" << kvs.second.node_id_ << " prepare building network timeout!";
- timeoutNodeId += kvs.second.node_id_ + " ";
}
}
- return timeoutNodeId.empty();
+ return Wait(request_id);
}
void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client, uint32_t rank_id) {
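
SendPrepareBuildingNetwork now registers one message track covering every destination node, tags each asynchronous send with the request id, and blocks once in Wait, instead of doing a synchronous round-trip per node and concatenating the ids that timed out. AddMessageTrack and Wait are the names used in the hunk; the tracker below is an illustrative reimplementation of the pattern, not MindSpore's own:

```cpp
#include <condition_variable>
#include <cstdint>
#include <mutex>

// Count-down tracker: expect N responses for one request id and wake the
// waiter when the last one arrives. A production version also needs
// timeouts and multiple in-flight request ids.
class MessageTracker {
 public:
  uint64_t AddMessageTrack(uint32_t expected) {
    std::lock_guard<std::mutex> lock(mu_);
    pending_ = expected;
    return ++request_id_;
  }
  void OnResponse(uint64_t /*request_id*/) {
    std::lock_guard<std::mutex> lock(mu_);
    if (pending_ > 0 && --pending_ == 0) {
      cv_.notify_all();
    }
  }
  bool Wait(uint64_t /*request_id*/) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return pending_ == 0; });
    return true;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  uint64_t request_id_ = 0;
  uint32_t pending_ = 0;
};
```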
@@ -602,7 +600,7 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
// 1. update cluster timeout
if (!is_ready_ && (std::chrono::steady_clock::now() - start_time >
std::chrono::seconds(PSContext::instance()->cluster_config().cluster_available_timeout))) {
- node_manager_.CheckClusterTimeout();
+ node_manager_.UpdateClusterState(ClusterState::CLUSTER_EXIT);
}
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
node_manager_.UpdateCluster();
@@ -645,7 +643,7 @@ const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInf
}
std::string ip = node_info.ip_;
uint16_t port = node_info.port_;
MS_LOG(DEBUG) << "ip:" << ip << ", port:" << port << ", node id:" << node_info.node_id_;
MS_LOG(INFO) << "ip:" << ip << ", port:" << port << ", node id:" << node_info.node_id_;
auto client = std::make_shared<TcpClient>(ip, port, config_.get());
MS_EXCEPTION_IF_NULL(client);
client->SetMessageCallback(
@@ -1358,7 +1356,7 @@ void SchedulerNode::PersistMetaData() {
}
bool SchedulerNode::CheckIfNodeDisconnected() const {
- return UintToInt(register_connection_fd_.size()) != node_manager_.total_node_num();
+ return UlongToUint(register_connection_fd_.size()) != node_manager_.total_node_num();
}
void SchedulerNode::BroadcastTimeoutEvent() {

View File

@@ -32,16 +32,16 @@ bool SchedulerRecovery::Recover() {
// 1. recover worker num
if (recovery_storage_->Exists(kRecoveryWorkerNum)) {
- clusterConfig.initial_worker_num =
- std::strtol(recovery_storage_->Get(kRecoveryWorkerNum, "").c_str(), nullptr, kBase);
+ uint32_t initial_worker_num = std::strtol(recovery_storage_->Get(kRecoveryWorkerNum, "").c_str(), nullptr, kBase);
+ clusterConfig.initial_worker_num = IntToUint(initial_worker_num);
} else {
clusterConfig.initial_worker_num = PSContext::instance()->initial_worker_num();
}
// 2. recover server num
if (recovery_storage_->Exists(kRecoveryServerNum)) {
- clusterConfig.initial_server_num =
- std::strtol(recovery_storage_->Get(kRecoveryServerNum, "").c_str(), nullptr, kBase);
+ uint32_t initial_server_num = std::strtol(recovery_storage_->Get(kRecoveryServerNum, "").c_str(), nullptr, kBase);
+ clusterConfig.initial_server_num = IntToUint(initial_server_num);
} else {
clusterConfig.initial_server_num = PSContext::instance()->initial_server_num();
}
@@ -55,8 +55,8 @@ bool SchedulerRecovery::Recover() {
// 4. recover scheduler port
if (recovery_storage_->Exists(kRecoverySchedulerPort)) {
- clusterConfig.scheduler_port =
- std::strtol(recovery_storage_->Get(kRecoverySchedulerPort, "").c_str(), nullptr, kBase);
+ uint16_t scheduler_port = std::strtol(recovery_storage_->Get(kRecoverySchedulerPort, "").c_str(), nullptr, kBase);
+ clusterConfig.scheduler_port = scheduler_port;
} else {
clusterConfig.scheduler_port = PSContext::instance()->scheduler_port();
}
@@ -72,20 +72,23 @@ bool SchedulerRecovery::Recover() {
}
// 5. recover total node num
if (scheduler_recovery_storage_->Exists(kRecoveryTotalNodeNum)) {
- clusterConfig.initial_total_node_num =
+ uint32_t initial_total_node_num =
std::strtol(scheduler_recovery_storage_->Get(kRecoveryTotalNodeNum, "").c_str(), nullptr, kBase);
+ clusterConfig.initial_total_node_num = initial_total_node_num;
}
// 6. recover next worker rank id
if (scheduler_recovery_storage_->Exists(kRecoveryNextWorkerRankId)) {
- clusterConfig.initial_next_worker_rank_id =
+ uint32_t initial_next_worker_rank_id =
std::strtol(scheduler_recovery_storage_->Get(kRecoveryNextWorkerRankId, "").c_str(), nullptr, kBase);
+ clusterConfig.initial_next_worker_rank_id = initial_next_worker_rank_id;
}
// 7. recover next server rank id
if (scheduler_recovery_storage_->Exists(kRecoveryNextServerRankId)) {
- clusterConfig.initial_next_server_rank_id =
+ uint32_t initial_next_server_rank_id =
std::strtol(scheduler_recovery_storage_->Get(kRecoveryNextServerRankId, "").c_str(), nullptr, kBase);
+ clusterConfig.initial_next_server_rank_id = initial_next_server_rank_id;
}
// 8. recover register nodes info
@@ -98,9 +101,11 @@ bool SchedulerRecovery::Recover() {
NodeInfo node_info;
node_info.ip_ = elem.at("ip");
- node_info.port_ = std::strtol(port.c_str(), nullptr, kBase);
+ uint16_t uint_port = std::strtol(port.c_str(), nullptr, kBase);
+ node_info.port_ = uint_port;
node_info.node_id_ = elem.at("node_id");
- node_info.rank_id_ = std::strtol(rank_id.c_str(), nullptr, kBase);
+ uint32_t uint_rank_id = std::strtol(rank_id.c_str(), nullptr, kBase);
+ node_info.rank_id_ = uint_rank_id;
node_info.is_alive = CommUtil::StringToBool(elem.at("alive"));
node_info.node_role_ = CommUtil::StringToNodeRole(elem.at("role"));
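
Every recovery hunk in this file follows the same shape: parse the persisted string into an explicitly typed local, then assign it, so the narrowing from strtol's long to the unsigned destination is visible at the call site instead of implicit. A small helper sketch of the idea (hypothetical, the PR keeps the parses inline; kBase mirrors the constant name in the file, its value is not shown here and 10 is assumed):

```cpp
#include <cstdint>
#include <cstdlib>
#include <string>

constexpr int kBase = 10;  // assumed radix for the persisted decimal strings

// "Typed temporary, then assign": the cast is explicit and local, mirroring
// the shape the hunks above adopt for ports, rank ids, and node counts.
uint32_t ParseUint32(const std::string &text) {
  return static_cast<uint32_t>(std::strtoul(text.c_str(), nullptr, kBase));
}

// Usage sketch: clusterConfig.initial_total_node_num = ParseUint32(stored_value);
```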

View File

@@ -1,5 +1,3 @@
- #!/usr/bin/env python3
- # coding=UTF-8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,30 +22,34 @@ import argparse
import json
import os
import warnings
+ from enum import Enum
import requests
- class Status:
-     success = "0"
-     failed = "1"
+ class Status(Enum):
+     """
+     Response Status
+     """
+     SUCCESS = "0"
+     FAILED = "1"
- class Restful:
-     scale = "scale"
-     scaleout = "scaleout"
-     scalein = "scalein"
-     nodes = "nodes"
-     getInstanceDetail = "getInstanceDetail"
-     newInstance = "newInstance"
-     queryInstance = "queryInstance"
-     enableFLS = "enableFLS"
-     disableFLS = "disableFLS"
-     state = "state"
-     scaleoutRollback = "scaleoutRollback"
+ class Restful(Enum):
+     """
+     Define restful interface constant
+     """
+     SCALE = "scale"
+     SCALE_OUT = "scaleout"
+     SCALE_IN = "scalein"
+     NODES = "nodes"
+     GET_INSTANCE_DETAIL = "getInstanceDetail"
+     NEW_INSTANCE = "newInstance"
+     QUERY_INSTANCE = "queryInstance"
+     ENABLE_FLS = "enableFLS"
+     DISABLE_FLS = "disableFLS"
+     STATE = "state"
+     SCALE_OUT_ROLLBACK = "scaleoutRollback"
warnings.filterwarnings('ignore')
@@ -55,11 +57,9 @@ parser = argparse.ArgumentParser()
parser.add_argument("--http_type", type=str, default="http", help="http or https")
parser.add_argument("--ip", type=str, default="127.0.0.1")
parser.add_argument("--port", type=int, default=6666)
- # scaleout scalein nodes
parser.add_argument("--request_name", type=str, default="")
parser.add_argument("--server_num", type=int, default=0)
# "start_fl_job_threshold=20,start_fl_job_time_window=2000..."
parser.add_argument("--instance_param", type=str, default="")
parser.add_argument("--metrics_file_path", type=str, default="/opt/huawei/mindspore/hybrid_albert/metrics.json")
@@ -82,12 +82,12 @@ def call_scale():
call cluster scale out or scale in
"""
if server_num == 0:
return process_self_define_json(Status.failed, "error. server_num is 0")
return process_self_define_json(Status.FAILED.value, "error. server_num is 0")
node_ids = json.loads(call_nodes())["result"]
cluster_abstract_node_num = len(node_ids)
if cluster_abstract_node_num == 0:
return process_self_define_json(Status.failed, "error. cluster abstract node num is 0")
return process_self_define_json(Status.FAILED.value, "error. cluster abstract node num is 0")
cluster_server_node_num = 0
cluster_worker_node_num = 0
@@ -103,7 +103,7 @@ def call_scale():
else:
pass
if cluster_server_node_num == server_num:
return process_self_define_json(Status.failed, "error. cluster server num is same with server_num.")
return process_self_define_json(Status.FAILED.value, "error. cluster server num is same with server_num.")
if cluster_server_node_num > server_num:
scale_in_len = cluster_server_node_num - server_num
scale_in_node_ids = []
@@ -115,56 +115,67 @@ def call_scale():
def call_scaleout(scale_out_server_num, scale_out_worker_num=0):
url = base_url + "scaleout"
"""
call scaleout
"""
url = base_url + Restful.SCALE_OUT.value
data = {"server_num": scale_out_server_num, "worker_num": scale_out_worker_num}
res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
res_json = json.loads(res.text)
if res_json["code"] == Status.failed:
return process_self_define_json(Status.failed, res_json["error_message"])
if res_json["code"] == Status.FAILED.value:
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
result = "scale out server num is " + str(scale_out_server_num)
- return process_result_json(Status.success, res_json["message"], result)
+ return process_result_json(Status.SUCCESS.value, res_json["message"], result)
def call_scaleout_rollback():
- url = base_url + Restful.scaleoutRollback
+ """
+ call scaleout rollback
+ """
+ url = base_url + Restful.SCALE_OUT_ROLLBACK.value
res = session.get(url, verify=False)
res_json = json.loads(res.text)
if res_json["code"] == Status.failed:
return process_self_define_json(Status.failed, res_json["error_message"])
return process_self_define_json(Status.success, res_json["message"])
if res_json["code"] == Status.FAILED.value:
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
def call_scalein(scale_in_node_ids):
"""
call cluster to scale in
"""
if not scale_in_node_ids:
return process_self_define_json(Status.failed, "error. node ids is empty.")
return process_self_define_json(Status.FAILED.value, "error. node ids is empty.")
url = base_url + "scalein"
url = base_url + Restful.SCALE_IN.value
data = {"node_ids": scale_in_node_ids}
res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
res_json = json.loads(res.text)
if res_json["code"] == Status.failed:
return process_self_define_json(Status.failed, res_json["error_message"])
if res_json["code"] == Status.FAILED.value:
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
result = "scale in node ids is " + str(scale_in_node_ids)
- return process_result_json(Status.success, res_json["message"], result)
+ return process_result_json(Status.SUCCESS.value, res_json["message"], result)
def call_nodes():
- url = base_url + Restful.nodes
+ """
+ get nodes info
+ """
+ url = base_url + Restful.NODES.value
res = session.get(url, verify=False)
res_json = json.loads(res.text)
if res_json["code"] == Status.failed:
return process_self_define_json(Status.failed, res_json["error_message"])
return process_result_json(Status.success, res_json["message"], res_json["nodeIds"])
if res_json["code"] == Status.FAILED.value:
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
return process_result_json(Status.SUCCESS.value, res_json["message"], res_json["nodeIds"])
def call_get_instance_detail():
"""
- call cluster get instance detail
+ get cluster instance detail
"""
if not os.path.exists(metrics_file_path):
return process_self_define_json(Status.failed, "error. metrics file is not existed.")
return process_self_define_json(Status.FAILED.value, "error. metrics file is not existed.")
ans_json_obj = {}
joined_client_num_list = []
@@ -177,7 +188,7 @@ def call_get_instance_detail():
metrics_list = f.readlines()
if not metrics_list:
return process_self_define_json(Status.failed, "error. metrics file has no content")
return process_self_define_json(Status.FAILED.value, "error. metrics file has no content")
for metrics in metrics_list:
json_obj = json.loads(metrics)
@@ -190,18 +201,19 @@ def call_get_instance_detail():
last_metrics = metrics_list[len(metrics_list) - 1]
last_metrics_obj = json.loads(last_metrics)
ans_json_obj["code"] = Status.success
ans_json_obj["code"] = Status.SUCCESS.value
ans_json_obj["describe"] = "get instance metrics detail successful."
ans_json_obj["result"] = {}
ans_json_obj["result"]['currentIteration'] = last_metrics_obj['currentIteration']
ans_json_obj["result"]['flIterationNum'] = last_metrics_obj['flIterationNum']
ans_json_obj["result"]['flName'] = last_metrics_obj['flName']
ans_json_obj["result"]['instanceStatus'] = last_metrics_obj['instanceStatus']
ans_json_obj["result"]['iterationExecutionTime'] = iteration_execution_time_list
ans_json_obj["result"]['joinedClientNum'] = joined_client_num_list
ans_json_obj["result"]['rejectedClientNum'] = rejected_client_num_list
ans_json_obj["result"]['metricsAuc'] = metrics_auc_list
ans_json_obj["result"]['metricsLoss'] = metrics_loss_list
ans_json_result = ans_json_obj.get("result")
ans_json_result['currentIteration'] = last_metrics_obj['currentIteration']
ans_json_result['flIterationNum'] = last_metrics_obj['flIterationNum']
ans_json_result['flName'] = last_metrics_obj['flName']
ans_json_result['instanceStatus'] = last_metrics_obj['instanceStatus']
ans_json_result['iterationExecutionTime'] = iteration_execution_time_list
ans_json_result['joinedClientNum'] = joined_client_num_list
ans_json_result['rejectedClientNum'] = rejected_client_num_list
ans_json_result['metricsAuc'] = metrics_auc_list
ans_json_result['metricsLoss'] = metrics_loss_list
return json.dumps(ans_json_obj)
@@ -211,11 +223,11 @@ def call_new_instance():
call cluster new instance
"""
if instance_param == "":
return process_self_define_json(Status.failed, "error. instance_param is empty.")
return process_self_define_json(Status.FAILED.value, "error. instance_param is empty.")
instance_param_list = instance_param.split(sep=",")
instance_param_json_obj = {}
- url = base_url + Restful.newInstance
+ url = base_url + Restful.NEW_INSTANCE.value
for cur in instance_param_list:
pair = cur.split(sep="=")
instance_param_json_obj[pair[0]] = float(pair[1])
@@ -223,84 +235,102 @@ def call_new_instance():
data = json.dumps(instance_param_json_obj)
res = session.post(url, verify=False, data=data)
res_json = json.loads(res.text)
if res_json["code"] == Status.failed:
return process_self_define_json(Status.failed, res_json["error_message"])
return process_self_define_json(Status.success, res_json["message"])
if res_json["code"] == Status.FAILED.value:
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
def call_query_instance():
- url = base_url + Restful.queryInstance
+ """
+ query cluster instance
+ """
+ url = base_url + Restful.QUERY_INSTANCE.value
res = session.post(url, verify=False)
res_json = json.loads(res.text)
if res_json["code"] == Status.failed:
return process_self_define_json(Status.failed, res_json["error_message"])
return process_result_json(Status.success, res_json["message"], res_json["result"])
if res_json["code"] == Status.FAILED.value:
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
return process_result_json(Status.SUCCESS.value, res_json["message"], res_json["result"])
def call_enable_fls():
- url = base_url + Restful.enableFLS
+ """
+ enable cluster fls
+ """
+ url = base_url + Restful.ENABLE_FLS.value
res = session.post(url, verify=False)
res_json = json.loads(res.text)
if res_json["code"] == Status.failed:
return process_self_define_json(Status.failed, res_json["error_message"])
return process_self_define_json(Status.success, res_json["message"])
if res_json["code"] == Status.FAILED.value:
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
def call_disable_fls():
- url = base_url + Restful.disableFLS
+ """
+ disable cluster fls
+ """
+ url = base_url + Restful.DISABLE_FLS.value
res = session.post(url, verify=False)
res_json = json.loads(res.text)
if res_json["code"] == Status.failed:
return process_self_define_json(Status.failed, res_json["error_message"])
return process_self_define_json(Status.success, res_json["message"])
if res_json["code"] == Status.FAILED.value:
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
return process_self_define_json(Status.SUCCESS.value, res_json["message"])
def call_state():
- url = base_url + Restful.state
+ """
+ get cluster state
+ """
+ url = base_url + Restful.STATE.value
res = session.get(url, verify=False)
res_json = json.loads(res.text)
if res_json["code"] == Status.failed:
return process_self_define_json(Status.failed, res_json["error_message"])
if res_json["code"] == Status.FAILED.value:
return process_self_define_json(Status.FAILED.value, res_json["error_message"])
result = res_json['cluster_state']
- return process_result_json(Status.success, res_json["message"], result)
+ return process_result_json(Status.SUCCESS.value, res_json["message"], result)
def process_result_json(code, describe, result):
"""
process result json
"""
result_dict = {"code": code, "describe": describe, "result": result}
return json.dumps(result_dict)
def process_self_define_json(code, describe):
"""
process self define json
"""
result_dict = {"code": code, "describe": describe}
return json.dumps(result_dict)
if __name__ == '__main__':
- if request_name == Restful.scale:
+ if request_name == Restful.SCALE.value:
print(call_scale())
- elif request_name == Restful.nodes:
+ elif request_name == Restful.NODES.value:
print(call_nodes())
- elif request_name == Restful.getInstanceDetail:
+ elif request_name == Restful.GET_INSTANCE_DETAIL.value:
print(call_get_instance_detail())
- elif request_name == Restful.newInstance:
+ elif request_name == Restful.NEW_INSTANCE.value:
print(call_new_instance())
- elif request_name == Restful.queryInstance:
+ elif request_name == Restful.QUERY_INSTANCE.value:
print(call_query_instance())
- elif request_name == Restful.enableFLS:
+ elif request_name == Restful.ENABLE_FLS.value:
print(call_enable_fls())
- elif request_name == Restful.disableFLS:
+ elif request_name == Restful.DISABLE_FLS.value:
print(call_disable_fls())
- elif request_name == Restful.state:
+ elif request_name == Restful.STATE.value:
print(call_state())
- elif request_name == Restful.scaleoutRollback:
+ elif request_name == Restful.SCALE_OUT_ROLLBACK.value:
print(call_scaleout_rollback())
else: