scheduler added client

This commit is contained in:
chendongsheng 2021-05-12 10:22:08 +08:00
parent 9f0d6ec8da
commit 54a331d103
27 changed files with 544 additions and 286 deletions

View File

@ -55,6 +55,10 @@ void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const
}
node_info_.rank_id_ = register_resp_message.rank_id();
// Receive the Register message, indicating that the scheduler is alive, so update the time point at which the
// scheduler is alive
UpdateSchedulerTime();
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_
<< " registered scheduler success!";
}
@ -84,9 +88,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &mess
return Wait(request_id, timeout);
}
void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_message) {
on_node_event_message_ = on_node_event_message;
}
void AbstractNode::set_event_callback(const OnNodeEventMessage &event) { on_node_event_message_ = event; }
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
int command, const uint32_t &timeout) {
@ -211,16 +213,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
return Wait(request_id, timeout);
}
bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
return ret;
});
message_tracker_.erase(request_id);
return res;
}
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(data);
@ -294,19 +286,19 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
} else {
UpdateSchedulerTime();
}
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
}
});
heart_beat_thread_->detach();
}
bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
auto meta = std::make_shared<MessageMeta>();
meta->set_cmd(NodeCommand::HEARTBEAT);
HeartbeatMessage heartbeat_message;
heartbeat_message.set_node_id(node_info_.node_id_);
heartbeat_message.set_is_node_finish(is_node_finish);
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
heartbeat_message.ByteSizeLong())) {
@ -338,32 +330,21 @@ void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.ParseFromArray(data, size);
is_ready_ = heartbeat_resp_message.is_cluster_ready();
if (is_ready_.load()) {
current_cluster_state_ = heartbeat_resp_message.cluster_state();
if (current_cluster_state_ == ClusterState::CLUSTER_READY) {
is_ready_ = true;
wait_start_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
}
if (heartbeat_resp_message.is_cluster_finish()) {
Heartbeat(client_to_scheduler_, true);
is_finish_ = true;
wait_finish_cond_.notify_all();
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
}
is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
if (is_timeout_ && on_node_event_message_) {
if (current_cluster_state_ == ClusterState::CLUSTER_TIMEOUT && on_node_event_message_) {
is_ready_ = true;
wait_start_cond_.notify_all();
on_node_event_message_(NodeEvent::CLUSTER_TIMEOUT);
}
if (heartbeat_resp_message.is_node_timeout() && on_node_event_message_) {
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
}
}
void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
auto meta = std::make_shared<MessageMeta>();
meta->set_cmd(NodeCommand::FETCH_SERVER);
meta->set_cmd(NodeCommand::FETCH_METADATA);
FetchServersMessage fetch_servers;
fetch_servers.set_node_id(node_info_.node_id_);
@ -379,11 +360,39 @@ void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, co
FetchServersRespMessage fetch_servers_resp_message;
fetch_servers_resp_message.ParseFromArray(data, size);
nodes_address_.clear();
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
}
}
MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
// Handles a SEND_METADATA request received on this node's server.
// The payload is parsed as a SendMetadataMessage; the cached address table is cleared and refilled with one
// <SERVER, rank_id> -> <ip, port> entry per server in the message. The raw payload is then sent back on `conn`
// as the reply, the node is flagged ready and every thread blocked on wait_start_cond_ is woken.
//   conn   - connection the request arrived on; the reply is written to it.
//   meta   - metadata of the request, reused unchanged for the reply.
//   protos - unused by this handler.
//   data   - serialized SendMetadataMessage.
//   size   - length of `data` in bytes.
// Throws (via MS_EXCEPTION_IF_NULL) when conn, meta or data is null.
void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
                                       const Protos &protos, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);
  SendMetadataMessage send_meta_message;
  send_meta_message.ParseFromArray(data, size);
  nodes_address_.clear();
  // This is routine progress information, so it is reported at INFO instead of polluting the ERROR log.
  MS_LOG(INFO) << "send metadata size:" << send_meta_message.servers_meta().size();
  for (const auto &it : send_meta_message.servers_meta()) {
    nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
    MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
  }
  // Reply with the payload that was received before signalling readiness.
  server_->SendMessage(conn, meta, Protos::RAW, data, size);
  is_ready_ = true;
  wait_start_cond_.notify_all();
}
// Handles a FINISH request received on this node's server: the raw payload is sent back on `conn` as the reply,
// the node is flagged as finished and every thread blocked on wait_finish_cond_ is woken.
// `protos` is not used. Throws (via MS_EXCEPTION_IF_NULL) when conn, meta or data is null.
void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
                                 const Protos &protos, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(conn);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  // Reply first, then publish the finished state to the waiters.
  server_->SendMessage(conn, meta, Protos::RAW, data, size);
  is_finish_.store(true);
  wait_finish_cond_.notify_all();
}
bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
@ -478,42 +487,6 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &
}
}
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout) {
uint64_t request_id = AddMessageTrack(1);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
uint64_t request_id = AddMessageTrack(1);
meta->set_request_id(request_id);
client->SendMessage(meta, protos, data, size);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return request_id;
}
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
uint64_t request_id = AddMessageTrack(1);
meta->set_request_id(request_id);
client->SendMessage(meta, protos, data, size);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
@ -570,14 +543,6 @@ void AbstractNode::set_message_callback(const uint64_t &request_id, const Messag
message_callbacks_[request_id] = callback;
}
void AbstractNode::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
uint64_t request_id = meta->request_id();
message_tracker_[request_id].second++;
message_tracker_cond_.notify_all();
}
void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
@ -639,20 +604,25 @@ uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
void AbstractNode::InitCommandHandler() {
handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;
handlers_[NodeCommand::FETCH_METADATA] = &AbstractNode::ProcessFetchServersResp;
handlers_[NodeCommand::FINISH] = nullptr;
}
uint64_t AbstractNode::AddMessageTrack(const uint32_t &expected_response) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(expected_response, 0);
return request_id;
// Fills server_handler_, the table that maps a NodeCommand to the member function handling it.
// SEND_METADATA and FINISH are bound to ProcessSendMetadata and ProcessFinish.
// SEND_DATA and COLLECTIVE_SEND_DATA are registered with a null handler.
// NOTE(review): presumably the null entries mark commands that are handled elsewhere - confirm against the dispatcher.
void AbstractNode::InitServerHandler() {
server_handler_[NodeCommand::SEND_METADATA] = &AbstractNode::ProcessSendMetadata;
server_handler_[NodeCommand::FINISH] = &AbstractNode::ProcessFinish;
server_handler_[NodeCommand::SEND_DATA] = nullptr;
server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = nullptr;
}
bool AbstractNode::CheckMessageTrack(const uint64_t &request_id) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
void AbstractNode::InitNode(const NodeRole &role) {
node_info_.node_id_ = CommUtil::GenerateUUID();
node_info_.node_role_ = role;
node_info_.ip_ = server_->BoundIp();
node_info_.port_ = server_->BoundPort();
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " is generate uuid is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
<< ", the port:" << server_->BoundPort();
}
} // namespace core
} // namespace ps

View File

@ -33,17 +33,25 @@ namespace ps {
namespace core {
class AbstractNode : public Node {
public:
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
AbstractNode()
: heart_beat_thread_(nullptr),
client_to_scheduler_thread_(nullptr),
client_to_scheduler_(nullptr),
server_(nullptr),
server_thread_(nullptr) {}
~AbstractNode() override = default;
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
typedef void (AbstractNode::*ServerHandler)(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size);
using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
const uint32_t &timeout = kCommTimeoutInSeconds);
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
void set_event_callback(const OnNodeEventMessage &event);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
const uint32_t &timeout = kCommTimeoutInSeconds);
@ -54,7 +62,6 @@ class AbstractNode : public Node {
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
@ -63,13 +70,18 @@ class AbstractNode : public Node {
protected:
void Register(const std::shared_ptr<TcpClient> &client);
bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
bool Heartbeat(const std::shared_ptr<TcpClient> &client);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void UpdateSchedulerTime();
bool CheckSchedulerTimeout() const;
@ -77,39 +89,29 @@ class AbstractNode : public Node {
bool WaitForDisconnect(const uint32_t &timeout);
bool InitClientToScheduler();
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size);
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
void RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
void InitCommandHandler();
uint64_t AddMessageTrack(const uint32_t &expected_response);
bool CheckMessageTrack(const uint64_t &request_id);
void InitServerHandler();
void InitNode(const NodeRole &role);
std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_;
std::shared_ptr<TcpClient> client_to_scheduler_;
OnNodeEventMessage on_node_event_message_;
// the key is: <node_role,rank_id>, the value is: <ip, port>
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
std::mutex client_mutex_;
// the map's key is: rank_id
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
// the key is: request_id, the value is: <expected responses, actual responses>
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;
// the key is: request_id, the value is: <rank_id, RecvMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_;
std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_;
@ -132,6 +134,13 @@ class AbstractNode : public Node {
std::mutex rank_request_ids_mutex;
timeval scheduler_time_{0, 0};
std::unordered_map<NodeCommand, ResponseHandler> handlers_;
std::unordered_map<NodeCommand, ServerHandler> server_handler_;
std::unordered_map<NodeEvent, bool> is_event_send_;
std::mutex is_event_send_mutex_;
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;
};
} // namespace core
} // namespace ps

View File

@ -31,24 +31,17 @@ namespace core {
* Configuration information read through environment variables and configuration files, generally immutable
*/
struct ClusterConfig {
ClusterConfig()
: initial_worker_num(0),
initial_server_num(0),
explicit ClusterConfig(const uint32_t &worker_num, const uint32_t &server_num, std::string host, const uint16_t &port)
: initial_worker_num(worker_num),
initial_server_num(server_num),
heartbeat_interval(3),
scheduler_host(""),
scheduler_port(0),
scheduler_host(host),
scheduler_port(port),
heartbeat_timeout(30),
cluster_available_timeout(300),
connect_interval(100),
scheduler_timeout(30) {}
void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string host, const uint16_t &port) {
initial_worker_num = worker_num;
initial_server_num = server_num;
scheduler_host = host;
scheduler_port = port;
}
// Configure through environment variables:MS_WORKER_NUM
uint32_t initial_worker_num;
// Configure through environment variables:MS_SERVER_NUM

View File

@ -31,15 +31,10 @@ namespace core {
* The metadata information of the cluster, stored in the scheduler, is generally used for scale out and scale in.
*/
struct ClusterMetadata {
ClusterMetadata() : worker_num_(0), server_num_(0) {}
ClusterMetadata(const uint32_t &worker, const uint32_t &server) : worker_num(worker), server_num(server) {}
void Init(const uint32_t &worker_num, const uint32_t &server_num) {
worker_num_ = worker_num;
server_num_ = server_num;
}
uint32_t worker_num_;
uint32_t server_num_;
uint32_t worker_num;
uint32_t server_num;
};
} // namespace core
} // namespace ps

View File

@ -51,8 +51,8 @@
#include "proto/ps.pb.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/cluster_config.h"
#include "ps/ps_context.h"
#include "utils/log_adapter.h"
#include "ps/ps_context.h"
namespace mindspore {
namespace ps {

View File

@ -30,7 +30,7 @@ enum class Command {
REGISTER = 1,
HEARTBEAT = 2,
SEND_DATA = 3,
FETCH_SERVER = 4,
FETCH_METADATA = 4,
FINISH = 5,
COLLECTIVE_SEND_DATA = 6
};

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_STATUS_H_
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_STATUS_H_
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_REQUEST_PROCESS_RESULT_CODE_H_
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_REQUEST_PROCESS_RESULT_CODE_H_
#include <string>
#include <memory>
@ -99,4 +99,4 @@ class RequestProcessResult {
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_STATUS_H_
#endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_REQUEST_PROCESS_RESULT_CODE_H_

View File

@ -17,8 +17,6 @@
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_CLIENT_H_
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_CLIENT_H_
#include "ps/core/communicator/tcp_message_handler.h"
#include <event2/event.h>
#include <event2/bufferevent.h>
#include <event2/thread.h>
@ -33,13 +31,13 @@
#include <atomic>
#include <condition_variable>
#include "ps/core/cluster_metadata.h"
#include "ps/core/cluster_config.h"
#include "utils/convert_utils_base.h"
#include "ps/core/comm_util.h"
#include "ps/core/communicator/ssl_wrapper.h"
#include "ps/constants.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/tcp_message_handler.h"
namespace mindspore {
namespace ps {

View File

@ -37,7 +37,6 @@
#include "ps/core/communicator/tcp_message_handler.h"
#include "ps/core/communicator/ssl_wrapper.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/cluster_config.h"
#include "utils/convert_utils_base.h"
#include "ps/core/comm_util.h"

View File

@ -36,6 +36,72 @@ bool Node::WaitForStart(const uint32_t &timeout) {
});
return res;
}
bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout) {
uint64_t request_id = AddMessageTrack(1);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout);
}
// Sends `data` through `client` without waiting for the response.
// A new request id expecting one response is allocated, stored in `meta`, and returned so the caller can later
// pass it to Wait(). Throws (via MS_EXCEPTION_IF_NULL) when client, meta or data is null.
uint64_t Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
                                const Protos &protos, const void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(client);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  // Register the request before it goes on the wire, then tag the outgoing metadata with its id.
  const uint64_t tracked_id = AddMessageTrack(1);
  meta->set_request_id(tracked_id);
  client->SendMessage(meta, protos, data, size);

  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << tracked_id;
  return tracked_id;
}
// Sends `data` through `client` and blocks until the response arrives or `timeout` seconds elapse.
// A new request id expecting one response is allocated and stored in `meta` before sending.
// Returns the result of Wait(): true when the response arrived in time, false on timeout.
// Throws (via MS_EXCEPTION_IF_NULL) when client, meta or data is null.
bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
                           const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(client);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  // Register the request before it goes on the wire, then tag the outgoing metadata with its id.
  const uint64_t tracked_id = AddMessageTrack(1);
  meta->set_request_id(tracked_id);
  client->SendMessage(meta, protos, data, size);

  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << tracked_id;

  const bool answered_in_time = Wait(tracked_id, timeout);
  return answered_in_time;
}
bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
return ret;
});
message_tracker_.erase(request_id);
return res;
}
uint64_t Node::AddMessageTrack(const uint32_t &expected_response) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(expected_response, 0);
return request_id;
}
// Returns true when exactly one more response would complete `request_id`, i.e. the expected count equals the
// received count plus one. The lookup uses operator[], so an unknown id gets a <0, 0> entry and yields false.
bool Node::CheckMessageTrack(const uint64_t &request_id) {
  std::lock_guard<std::mutex> tracker_guard(message_tracker_mutex_);
  const auto &progress = message_tracker_[request_id];
  const bool one_response_missing = (progress.first == progress.second + 1);
  return one_response_missing;
}
void Node::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
uint64_t request_id = meta->request_id();
message_tracker_[request_id].second++;
message_tracker_cond_.notify_all();
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -47,13 +47,20 @@ class Node {
Node()
: is_ready_(false),
is_finish_(false),
is_timeout_(false),
is_node_ready_scale_out_(false),
is_node_ready_scale_in_(false),
is_cluster_ready_scale_out_(false),
is_cluster_ready_scale_in_(false),
update_local_servers_(false),
is_already_stopped_(true),
is_already_finished_(false),
next_request_id_(0) {}
next_request_id_(0),
current_node_state_(NodeState::NODE_STARTING),
current_cluster_state_(ClusterState::ClUSTER_STARTING) {}
virtual ~Node() = default;
using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
using MessageCallback = std::function<void()>;
virtual bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) = 0;
@ -64,13 +71,37 @@ class Node {
uint32_t rank_id() const;
NodeRole role() const;
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
protected:
bool WaitForStart(const uint32_t &timeout);
// Send data synchronously
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
// Send data asynchronously
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size);
uint64_t AddMessageTrack(const uint32_t &expected_response);
bool CheckMessageTrack(const uint64_t &request_id);
void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
NodeInfo node_info_;
std::atomic<bool> is_ready_;
std::atomic<bool> is_finish_;
std::atomic<bool> is_timeout_;
std::atomic<bool> is_node_ready_scale_out_;
std::atomic<bool> is_node_ready_scale_in_;
std::atomic<bool> is_cluster_ready_scale_out_;
std::atomic<bool> is_cluster_ready_scale_in_;
// Determine whether to update the ip and port of the locally cached servers.
std::atomic<bool> update_local_servers_;
std::atomic<bool> is_already_stopped_;
std::atomic<bool> is_already_finished_;
std::atomic_uint64_t next_request_id_;
@ -80,6 +111,14 @@ class Node {
std::mutex wait_finish_mutex_;
std::condition_variable wait_finish_cond_;
std::mutex finish_mutex_;
// the key is: request_id, the value is: <expected responses, actual responses>
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;
NodeState current_node_state_;
ClusterState current_cluster_state_;
};
} // namespace core
} // namespace ps

View File

@ -25,7 +25,15 @@
namespace mindspore {
namespace ps {
namespace core {
enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT = 2 };
enum class NodeEvent {
CLUSTER_TIMEOUT = 0,
NODE_TIMEOUT = 1,
SCHEDULER_TIMEOUT = 2,
READY_FOR_SCALE_OUT = 3,
READY_FOR_SCALE_IN = 4,
CLUSTER_SCALE_OUT_DONE = 5,
CLUSTER_SCALE_IN_DONE = 6
};
struct NodeInfo {
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}

View File

@ -19,9 +19,12 @@
namespace mindspore {
namespace ps {
namespace core {
void NodeManager::InitNodeNum() {
total_node_num_ = PSContext::instance()->cluster_config().initial_server_num +
PSContext::instance()->cluster_config().initial_worker_num;
void NodeManager::InitNode() {
initial_total_node_num_ = PSContext::instance()->cluster_config().initial_server_num +
PSContext::instance()->cluster_config().initial_worker_num;
meta_data_ = std::make_unique<ClusterMetadata>(PSContext::instance()->cluster_config().initial_worker_num,
PSContext::instance()->cluster_config().initial_server_num);
total_node_num_ = initial_total_node_num_;
}
int NodeManager::NextRankId(const RegisterMessage &register_message) {
@ -40,8 +43,8 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) {
uint32_t port = register_message.port();
rank_id = ++next_server_rank_id_;
if (IntToUint(rank_id) >= PSContext::instance()->cluster_config().initial_server_num) {
MS_LOG(WARNING) << "The rank id is greater than the number of servers.";
if (IntToUint(rank_id) >= meta_data_->server_num) {
MS_LOG(WARNING) << "The rank id is greater than the number of servers:" << meta_data_->server_num;
rank_id = -1;
--next_server_rank_id_;
}
@ -55,9 +58,11 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) {
MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port
<< " assign rank id:" << rank_id;
} else if (register_message.role() == NodeRole::WORKER) {
const std::string &ip = register_message.ip();
uint32_t port = register_message.port();
rank_id = ++next_worker_rank_id_;
if (IntToUint(rank_id) >= PSContext::instance()->cluster_config().initial_worker_num) {
MS_LOG(WARNING) << "The rank id is greater than the number of workers.";
if (IntToUint(rank_id) >= meta_data_->worker_num) {
MS_LOG(WARNING) << "The rank id is greater than the number of workers:" << meta_data_->worker_num;
rank_id = -1;
--next_worker_rank_id_;
}
@ -65,6 +70,8 @@ int NodeManager::NextRankId(const RegisterMessage &register_message) {
node_info.node_role_ = NodeRole::WORKER;
node_info.node_id_ = node_id;
node_info.rank_id_ = rank_id;
node_info.ip_ = ip;
node_info.port_ = port;
nodes_info_[node_id] = node_info;
MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id;
}
@ -81,9 +88,11 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) {
<< ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
}
void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); }
void NodeManager::UpdateNodeScaleInState(const std::string &node_id) { heartbeats_scale_in_nodes_.insert(node_id); }
bool NodeManager::CheckNodesFinishState() { return heartbeats_finish_nodes_.size() == nodes_info_.size(); }
bool NodeManager::CheckNodesScaluOutState() { return SizeToInt(heartbeats_scale_out_nodes_.size()) == total_node_num_; }
bool NodeManager::CheckNodesScaleInState() { return SizeToInt(heartbeats_scale_in_nodes_.size()) == total_node_num_; }
std::vector<ServersMeta> NodeManager::FetchServersMeta() {
std::vector<ServersMeta> servers_meta_list;
@ -99,7 +108,7 @@ std::vector<ServersMeta> NodeManager::FetchServersMeta() {
return servers_meta_list;
}
void NodeManager::UpdateClusterState() {
void NodeManager::UpdateCluster() {
// 1. update cluster timeout state
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
@ -111,48 +120,61 @@ void NodeManager::UpdateClusterState() {
}
}
if (!timeout_nodes_info_.empty()) {
is_node_timeout_ = true;
UpdateClusterState(ClusterState::CLUSTER_TIMEOUT);
for (auto it = timeout_nodes_info_.begin(); it != timeout_nodes_info_.end(); ++it) {
finish_nodes_id_.insert(it->first);
}
}
// 2. update cluster finish state
if (finish_nodes_id_.size() == total_node_num_ || SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
is_cluster_finish_ = true;
is_cluster_ready_ = true;
}
// 3. update cluster ready state
if (nodes_info_.size() == total_node_num_) {
is_cluster_ready_ = true;
if (SizeToInt(finish_nodes_id_.size()) == total_node_num_ ||
SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
UpdateClusterState(ClusterState::CLUSTER_FINISH);
}
}
void NodeManager::CheckClusterTimeout() {
if (total_node_num_ != nodes_info_.size()) {
if (total_node_num_ != SizeToInt(nodes_info_.size())) {
MS_LOG(WARNING) << "The cluster is not ready after "
<< PSContext::instance()->cluster_config().cluster_available_timeout
<< " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
<< nodes_info_.size();
current_node_num_ = nodes_info_.size();
is_cluster_timeout_ = true;
UpdateClusterState(ClusterState::CLUSTER_TIMEOUT);
}
}
void NodeManager::AddFinishNode(const std::string &finish_message) { finish_nodes_id_.insert(finish_message); }
std::unordered_map<std::string, NodeInfo> NodeManager::nodes_info() { return nodes_info_; }
bool NodeManager::CheckRegisterNum() { return SizeToInt(nodes_info_.size()) == total_node_num_; }
bool NodeManager::is_cluster_finish() { return is_cluster_finish_.load(); }
bool NodeManager::CheckFinishNum() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
bool NodeManager::is_cluster_ready() { return is_cluster_ready_.load(); }
std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() { return nodes_info_; }
bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_.load(); }
// Replaces the stored node state with `state` while holding node_mutex_.
void NodeManager::UpdateNodeState(const NodeState &state) {
  std::lock_guard<std::mutex> node_guard(node_mutex_);
  node_state_ = state;
}
bool NodeManager::is_node_timeout() { return is_node_timeout_.load(); }
// Replaces the stored cluster state with `state` while holding cluster_mutex_.
void NodeManager::UpdateClusterState(const ClusterState &state) {
  std::lock_guard<std::mutex> cluster_guard(cluster_mutex_);
  cluster_state_ = state;
}
void NodeManager::set_cluster_timeout(bool is_cluster_timeout) { is_cluster_timeout_ = is_cluster_timeout; }
// Returns a copy of the stored node state, read while holding node_mutex_.
NodeState NodeManager::GetNodeState() {
  std::lock_guard<std::mutex> node_guard(node_mutex_);
  const NodeState snapshot = node_state_;
  return snapshot;
}
// Returns a copy of the stored cluster state, read while holding cluster_mutex_.
ClusterState NodeManager::GetClusterState() {
  std::lock_guard<std::mutex> cluster_guard(cluster_mutex_);
  const ClusterState snapshot = cluster_state_;
  return snapshot;
}
void NodeManager::set_total_node_num(const int32_t &node_num) { total_node_num_ = node_num; }
const int32_t &NodeManager::total_node_num() { return total_node_num_; }
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -34,6 +34,7 @@
#include "ps/core/node.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
#include "ps/core/cluster_metadata.h"
namespace mindspore {
namespace ps {
@ -41,52 +42,75 @@ namespace core {
// Bookkeeping for the nodes of a cluster: declares rank-id assignment, heartbeat tracking, register/finish
// counting and accessors for the node state and the cluster state.
// NOTE(review): this span interleaves two revisions of the class (two constructor initializer lists, members that
// appear in both an older and a newer form) -- confirm which lines are live before relying on it.
class NodeManager {
public:
// NOTE(review): the two ':'-introduced initializer lists below belong to different revisions of the constructor.
NodeManager()
: is_cluster_ready_(false),
is_cluster_finish_(false),
is_cluster_timeout_(false),
is_node_timeout_(false),
total_node_num_(0),
: initial_total_node_num_(0),
total_node_num_(-1),
current_node_num_(-1),
next_worker_rank_id_(-1),
next_server_rank_id_(-1) {}
next_server_rank_id_(-1),
meta_data_(nullptr),
node_state_(NodeState::NODE_STARTING),
cluster_state_(ClusterState::ClUSTER_STARTING) {}
virtual ~NodeManager() = default;
enum ClusterState { STARTING, STARTED, FAILED, STOPPING, STOPPED };
void InitNodeNum();
// When initializing nodes, the initial number of nodes will be assigned to the total number of nodes.
void InitNode();
int NextRankId(const RegisterMessage &register_message);
void UpdateHeartbeat(const std::string &node_id);
void UpdateNodeFinishState(const std::string &node_id);
bool CheckNodesFinishState();
// NOTE(review): "Scalu" looks like a typo for "Scale"; renaming would have to be done together with all callers.
bool CheckNodesScaluOutState();
void UpdateNodeScaleInState(const std::string &node_id);
bool CheckNodesScaleInState();
std::vector<ServersMeta> FetchServersMeta();
void UpdateClusterState();
void UpdateCluster();
void CheckClusterTimeout();
// Inserts finish_message into finish_nodes_id_.
void AddFinishNode(const std::string &finish_message);
std::unordered_map<std::string, NodeInfo> nodes_info();
bool is_cluster_ready();
bool is_cluster_finish();
bool is_cluster_timeout();
bool is_node_timeout();
void set_cluster_timeout(bool is_cluster_timeout);
// When workers and servers register with the scheduler, the scheduler collects the number of registered
// nodes and determines whether the registered number of workers and servers is equal to total_node_num_.
bool CheckRegisterNum();
// When workers and servers send a finish message to the scheduler, the scheduler collects the number of
// finished nodes and determines whether it is equal to total_node_num_.
bool CheckFinishNum();
// Returns a reference to nodes_info_.
std::unordered_map<std::string, NodeInfo> &nodes_info();
void set_total_node_num(const int32_t &node_num);
const int32_t &total_node_num();
// The state setters/getters below lock node_mutex_ or cluster_mutex_ respectively around the access.
void UpdateNodeState(const NodeState &state);
void UpdateClusterState(const ClusterState &state);
NodeState GetNodeState();
ClusterState GetClusterState();
private:
std::atomic<bool> is_cluster_ready_;
std::atomic<bool> is_cluster_finish_;
std::atomic<bool> is_cluster_timeout_;
std::atomic<bool> is_node_timeout_;
uint32_t total_node_num_;
// Guards node_state_.
std::mutex node_mutex_;
// Guards cluster_state_.
std::mutex cluster_mutex_;
uint32_t initial_total_node_num_;
int32_t total_node_num_;
int32_t current_node_num_;
std::atomic<int> next_worker_rank_id_;
std::atomic<int> next_server_rank_id_;
// worker nodes and server nodes
std::unordered_map<std::string, NodeInfo> nodes_info_;
std::mutex assign_rank_id_mutex_;
std::mutex heartbeat_mutex_;
std::unordered_map<std::string, timeval> heartbeats_;
std::unordered_set<std::string> heartbeats_finish_nodes_;
std::unordered_set<std::string> heartbeats_scale_out_nodes_;
std::unordered_set<std::string> heartbeats_scale_in_nodes_;
// timeout nodes
std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
// Ids added through AddFinishNode; CheckFinishNum compares the size of this set with total_node_num_.
std::unordered_set<std::string> finish_nodes_id_;
std::unique_ptr<ClusterMetadata> meta_data_;
NodeState node_state_;
ClusterState cluster_state_;
};
} // namespace core
} // namespace ps

View File

@ -23,9 +23,12 @@ enum NodeCommand {
REGISTER = 1;
HEARTBEAT = 2;
SEND_DATA = 3;
FETCH_SERVER = 4;
// The worker or server asks the scheduler for metadata
FETCH_METADATA = 4;
FINISH = 5;
COLLECTIVE_SEND_DATA = 6;
// The scheduler actively sends metadata to the worker and server
SEND_METADATA = 7;
}
enum NodeRole {
@ -66,15 +69,27 @@ message RegisterRespMessage {
// Heartbeat sent to the scheduler; the scheduler's heartbeat handler reads node_id (and, in one of the revisions
// shown, is_node_finish) from it.
message HeartbeatMessage {
// the current Node unique id:0,1,2...
string node_id = 1;
// When true, the scheduler marks the sending node as finished (UpdateNodeFinishState).
// NOTE(review): this field appears to belong to the older revision of the message -- confirm it is still live.
bool is_node_finish = 2;
}
// State of a single node. NODE_STARTING (0) is the value NodeManager initializes its node_state_ with.
enum NodeState {
NODE_STARTING = 0;
NODE_FINISH = 1;
NODE_READY = 2;
NODE_TIMEOUT = 3;
}
// State of the whole cluster; carried in HeartbeatRespMessage.cluster_state.
enum ClusterState {
// Initial state of NodeManager::cluster_state_.
// NOTE(review): spelled with a lowercase 'l' ("ClUSTER"); C++ code references this exact spelling, so a rename
// must update every user at the same time.
ClUSTER_STARTING = 0;
// Set by the scheduler once the number of registered nodes matches the expected total.
CLUSTER_READY = 1;
// Set by the scheduler once the number of finished nodes matches the expected total.
CLUSTER_FINISH = 2;
// Set by NodeManager when the cluster is not ready within cluster_available_timeout seconds.
CLUSTER_TIMEOUT = 3;
CLUSTER_SCALE_OUT = 4;
CLUSTER_SCALE_IN = 5;
CLUSTER_FAILURE = 6;
}
// Scheduler's reply to a heartbeat.
// NOTE(review): two revisions are interleaved here -- the four bool flags and the single cluster_state field both
// claim field number 1, so only one of the two forms can be live. Confirm before editing.
message HeartbeatRespMessage {
// Is the entire system ready to use.
bool is_cluster_ready = 1;
bool is_cluster_finish = 2;
bool is_cluster_timeout = 3;
bool is_node_timeout = 4;
// Current cluster state as reported by the scheduler's NodeManager::GetClusterState().
ClusterState cluster_state = 1;
}
message FetchServersMessage {
@ -92,6 +107,10 @@ message ServersMeta {
}
// Payload of the SEND_METADATA command: the scheduler fills servers_meta from NodeManager::FetchServersMeta()
// and pushes it to a node.
message SendMetadataMessage {
repeated ServersMeta servers_meta = 1;
}
message FinishMessage {
// the current Node unique id:0,1,2...
string node_id = 1;

View File

@ -47,21 +47,9 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
if (heartbeat_message.is_node_finish()) {
node_manager_.UpdateNodeFinishState(heartbeat_message.node_id());
}
if (heartbeat_message.is_node_finish() && node_manager_.CheckNodesFinishState()) {
MS_LOG(INFO) << "The scheduler node receive all the finish cmd!";
is_finish_ = true;
wait_finish_cond_.notify_all();
}
HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.set_is_cluster_ready(node_manager_.is_cluster_ready());
heartbeat_resp_message.set_is_cluster_finish(node_manager_.is_cluster_finish());
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
heartbeat_resp_message.set_cluster_state(node_manager_.GetClusterState());
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
heartbeat_resp_message.ByteSizeLong());
@ -81,11 +69,11 @@ void SchedulerNode::InitCommandHandler() {
handlers_[NodeCommand::HEARTBEAT] = &SchedulerNode::ProcessHeartbeat;
handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
handlers_[NodeCommand::FETCH_SERVER] = &SchedulerNode::ProcessFetchServers;
handlers_[NodeCommand::FETCH_METADATA] = &SchedulerNode::ProcessFetchMetadata;
}
void SchedulerNode::CreateTcpServer() {
node_manager_.InitNodeNum();
node_manager_.InitNode();
std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
@ -131,6 +119,17 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
register_resp_message.ByteSizeLong());
if (node_manager_.CheckRegisterNum()) {
is_ready_ = true;
auto node_infos = node_manager_.nodes_info();
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
SendMetadata(client);
}
current_cluster_state_ = ClusterState::CLUSTER_READY;
wait_start_cond_.notify_all();
}
}
void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
@ -143,10 +142,20 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
node_manager_.AddFinishNode(*finish_message);
MS_LOG(INFO) << "Process finish message from node id:" << *finish_message;
server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
if (node_manager_.CheckFinishNum()) {
auto node_infos = node_manager_.nodes_info();
for (const auto &kvs : node_infos) {
auto client = GetOrCreateClient(kvs.second);
SendFinish(client);
}
is_finish_ = true;
current_cluster_state_ = ClusterState::CLUSTER_FINISH;
wait_finish_cond_.notify_all();
}
}
void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
void SchedulerNode::ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
@ -160,26 +169,59 @@ void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::
fetch_servers_message.ByteSizeLong());
}
void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client) {
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_METADATA);
SendMetadataMessage send_metadata_message;
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
MS_LOG(ERROR) << "the list size:" << servers_meta_list.size();
*send_metadata_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
if (!SendMessageAsync(client, message_meta, Protos::PROTOBUF, send_metadata_message.SerializeAsString().data(),
send_metadata_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send metadata timeout!";
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is sending metadata to workers and servers!";
}
void SchedulerNode::SendFinish(const std::shared_ptr<TcpClient> &client) {
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::FINISH);
// The scheduler does not need to bring any data when sending the finish command
std::string resp_data;
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send finish timeout!";
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is sending finish to workers and servers!";
}
void SchedulerNode::StartUpdateClusterStateTimer() {
MS_LOG(WARNING) << "The scheduler start a heartbeat timer!";
update_state_thread_ = std::make_unique<std::thread>([&]() {
auto start_time = std::chrono::steady_clock::now();
while (!is_finish_.load()) {
// 1. update cluster timeout
if (!node_manager_.is_cluster_ready() &&
(std::chrono::steady_clock::now() - start_time >
std::chrono::seconds(PSContext::instance()->cluster_config().cluster_available_timeout))) {
if (!is_ready_ && (std::chrono::steady_clock::now() - start_time >
std::chrono::seconds(PSContext::instance()->cluster_config().cluster_available_timeout))) {
node_manager_.CheckClusterTimeout();
}
// 2. update cluster state
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
node_manager_.UpdateClusterState();
if (node_manager_.is_cluster_ready()) {
is_ready_ = true;
wait_start_cond_.notify_all();
}
if (node_manager_.is_cluster_finish()) {
node_manager_.UpdateCluster();
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_FINISH) {
std::this_thread::sleep_for(
std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval * 2));
is_finish_ = true;
@ -189,6 +231,29 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
});
}
const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInfo &node_info) {
if (connected_nodes_.count(node_info.node_id_)) {
return connected_nodes_[node_info.node_id_];
} else {
std::string ip = node_info.ip_;
uint16_t port = node_info.port_;
auto client = std::make_shared<TcpClient>(ip, port);
MS_LOG(ERROR) << "the ip:" << node_info.ip_ << ", the port:" << node_info.port_;
client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) { NotifyMessageArrival(meta); });
client->Init();
if (is_client_started_ == false) {
is_client_started_ = true;
client_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The node start a tcp client!";
client->Start();
});
}
connected_nodes_[node_info.node_id_] = client;
return connected_nodes_[node_info.node_id_];
}
}
bool SchedulerNode::Stop() {
MS_LOG(INFO) << "Stop scheduler node!";
if (!is_already_stopped_) {
@ -196,6 +261,12 @@ bool SchedulerNode::Stop() {
update_state_thread_->join();
server_->Stop();
scheduler_thread_->join();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
client_thread_->join();
is_ready_ = true;
}
return true;

View File

@ -27,20 +27,31 @@
#include <mutex>
#include <unordered_map>
#include "ps/core/cluster_metadata.h"
#include "ps/core/cluster_config.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/tcp_client.h"
#include "ps/core/communicator/tcp_server.h"
#include "ps/core/node_manager.h"
#include "ps/core/node.h"
#include "ps/core/communicator/request_process_result_code.h"
#include "ps/core/communicator/http_message_handler.h"
#include "ps/constants.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/communicator/http_server.h"
namespace mindspore {
namespace ps {
namespace core {
class SchedulerNode : public Node {
public:
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
SchedulerNode()
: server_(nullptr),
scheduler_thread_(nullptr),
update_state_thread_(nullptr),
restful_thread_(nullptr),
http_server_(nullptr),
client_thread_(nullptr),
is_client_started_(false) {}
~SchedulerNode() override;
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
@ -54,15 +65,22 @@ class SchedulerNode : public Node {
void Initialize();
void InitCommandHandler();
void CreateTcpServer();
void StartUpdateClusterStateTimer();
const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info);
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void StartUpdateClusterStateTimer();
void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
// After the scheduler has collected all register messages, it actively sends the metadata to workers and servers.
void SendMetadata(const std::shared_ptr<TcpClient> &client);
// After the scheduler has collected all finish messages, it actively sends a finish message to workers and servers.
void SendFinish(const std::shared_ptr<TcpClient> &client);
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_;
@ -70,6 +88,14 @@ class SchedulerNode : public Node {
std::unordered_map<NodeCommand, ResponseHandler> handlers_;
NodeManager node_manager_;
std::unique_ptr<std::thread> restful_thread_;
std::shared_ptr<HttpServer> http_server_;
std::unordered_map<std::string, std::shared_ptr<TcpClient>> connected_nodes_;
std::unique_ptr<std::thread> client_thread_;
std::atomic<bool> is_client_started_;
};
} // namespace core
} // namespace ps

View File

@ -32,11 +32,6 @@ bool ServerNode::Start(const uint32_t &timeout) {
}
MS_LOG(INFO) << "The cluster is ready to use!";
// If the cluster is ready to use, then Get the address of all the servers
if (!is_timeout_.load()) {
FetchServers(client_to_scheduler_);
MS_LOG(INFO) << "Server node get all the servers address successful!";
}
MsException::Instance().CheckException();
MS_LOG(INFO) << "Start the node is successful!";
return true;
@ -63,34 +58,32 @@ void ServerNode::CreateTcpServer() {
server_ = std::make_shared<TcpServer>(server_ip, 0);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
switch (meta->cmd()) {
case NodeCommand::SEND_DATA:
ProcessSendData(conn, meta, protos, data, size);
break;
case NodeCommand::COLLECTIVE_SEND_DATA:
ProcessCollectiveSendData(conn, meta, data, size);
RunReceiveCallback(meta, protos, data, size);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
if (server_handler_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
}
if (meta->cmd() == NodeCommand::COLLECTIVE_SEND_DATA) {
ProcessCollectiveSendData(conn, meta, data, size);
RunReceiveCallback(meta, protos, data, size);
} else if (meta->cmd() == NodeCommand::SEND_DATA) {
ProcessSendData(conn, meta, protos, data, size);
} else {
const auto &handler_ptr = server_handler_[meta->cmd()];
(this->*handler_ptr)(conn, meta, protos, data, size);
}
});
server_->Init();
server_thread_ = std::make_unique<std::thread>([&]() {
server_thread_ = std::make_unique<std::thread>([this]() {
MS_LOG(INFO) << "The server node start a tcp server!";
server_->Start();
this->server_->Start();
});
}
void ServerNode::Initialize() {
InitServerHandler();
CreateTcpServer();
is_already_stopped_ = false;
node_info_.node_id_ = CommUtil::GenerateUUID();
node_info_.node_role_ = NodeRole::SERVER;
node_info_.ip_ = server_->BoundIp();
node_info_.port_ = server_->BoundPort();
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " is generate uuid is:" << node_info_.node_id_;
InitNode(NodeRole::SERVER);
InitCommandHandler();
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Server node init client timeout!";

View File

@ -43,7 +43,7 @@ constexpr char kHttpCommunicator[] = "HTTP";
class ServerNode : public AbstractNode {
public:
ServerNode() : server_(nullptr), server_thread_(nullptr) {}
ServerNode() = default;
~ServerNode() override = default;
@ -71,8 +71,6 @@ class ServerNode : public AbstractNode {
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const void *data, size_t size);
std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;
RequestHandler request_handler_;
std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_;
std::mutex communicator_mutex_;

View File

@ -32,11 +32,6 @@ bool WorkerNode::Start(const uint32_t &timeout) {
}
MS_LOG(INFO) << "The node is ready to fetch servers!";
// If the cluster is ready to use, then Get the address of all the servers
if (!is_timeout_.load()) {
FetchServers(client_to_scheduler_);
MS_LOG(INFO) << "Worker node get all the servers address successful!";
}
MsException::Instance().CheckException();
MS_LOG(INFO) << "The Worker node has successfully started.";
return true;
@ -44,10 +39,9 @@ bool WorkerNode::Start(const uint32_t &timeout) {
void WorkerNode::Initialize() {
is_already_stopped_ = false;
node_info_.node_id_ = CommUtil::GenerateUUID();
node_info_.node_role_ = NodeRole::WORKER;
MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_;
InitServerHandler();
CreateTcpServer();
InitNode(NodeRole::WORKER);
InitCommandHandler();
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Worker node init client timeout!";
@ -55,11 +49,30 @@ void WorkerNode::Initialize() {
MS_LOG(INFO) << "Worker node init client successful!";
}
void WorkerNode::CreateTcpServer() {
std::string interface;
std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
if (server_handler_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
}
const auto &handler_ptr = server_handler_[meta->cmd()];
(this->*handler_ptr)(conn, meta, protos, data, size);
});
server_->Init();
server_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The worker node start a tcp server!";
server_->Start();
});
}
bool WorkerNode::Stop() {
if (!is_already_stopped_.load()) {
MS_LOG(INFO) << "Stop worker node!";
is_ready_ = true;
is_timeout_ = true;
is_finish_ = true;
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
@ -67,6 +80,8 @@ bool WorkerNode::Stop() {
connected_node.second->Stop();
}
}
server_->Stop();
server_thread_->join();
is_already_stopped_ = true;
}
return true;

View File

@ -24,9 +24,7 @@
#include <utility>
#include <algorithm>
#include "ps/core/cluster_metadata.h"
#include "ps/core/cluster_config.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/tcp_client.h"
#include "ps/core/communicator/tcp_server.h"
#include "ps/core/abstract_node.h"
@ -45,6 +43,7 @@ class WorkerNode : public AbstractNode {
private:
void Initialize();
void CreateTcpServer();
};
} // namespace core
} // namespace ps

View File

@ -52,7 +52,7 @@ void PSContext::SetPSEnable(bool enabled) {
server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10);
cluster_config_.Init(worker_num_, server_num_, scheduler_host_, scheduler_port_);
cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
} else {
MS_LOG(INFO) << "PS mode is disabled.";
is_worker_ = false;
@ -312,6 +312,11 @@ bool PSContext::enable_ssl() const { return enable_ssl_; }
void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
// Returns a reference to the cluster configuration held by this context.
core::ClusterConfig &PSContext::cluster_config() {
  return cluster_config_;
}
// Returns the cluster configuration owned by this context.
// Raises (MS_LOG(EXCEPTION)) when no configuration has been created yet, i.e. cluster_config_ is still null.
core::ClusterConfig &PSContext::cluster_config() {
  if (!cluster_config_) {
    MS_LOG(EXCEPTION) << "The cluster config is empty.";
  }
  return *cluster_config_;
}
} // namespace ps
} // namespace mindspore

View File

@ -176,7 +176,8 @@ class PSContext {
client_epoch_num_(25),
client_batch_size_(32),
client_learning_rate_(0.001),
secure_aggregation_(false) {}
secure_aggregation_(false),
cluster_config_(nullptr) {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@ -234,9 +235,8 @@ class PSContext {
bool secure_aggregation_;
// The cluster config read through environment variables, the value does not change.
core::ClusterConfig cluster_config_;
std::unique_ptr<core::ClusterConfig> cluster_config_;
};
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CONTEXT_H_

View File

@ -127,7 +127,7 @@ bool Server::InitCommunicatorWithServer() {
auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
MS_EXCEPTION_IF_NULL(tcp_comm);
tcp_comm->RegisterEventCallback(core::CLUSTER_TIMEOUT, [&]() {
tcp_comm->RegisterEventCallback(core::NodeEvent::CLUSTER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes(Scheduler/Server/Worker) are not "
"started during network building phase.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
@ -135,14 +135,14 @@ bool Server::InitCommunicatorWithServer() {
communicator_with_server_->Stop();
});
tcp_comm->RegisterEventCallback(core::SCHEDULER_TIMEOUT, [&]() {
tcp_comm->RegisterEventCallback(core::NodeEvent::SCHEDULER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
});
tcp_comm->RegisterEventCallback(core::NODE_TIMEOUT, [&]() {
tcp_comm->RegisterEventCallback(core::NodeEvent::NODE_TIMEOUT, [&]() {
MS_LOG(ERROR)
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
"network building phase.";

View File

@ -17,6 +17,7 @@
#include "common/common_test.h"
#include "ps/core/node.h"
#include "ps/core/scheduler_node.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace ps {
@ -31,7 +32,16 @@ class TestClusterAvailableTimeout : public UT::Common {
};
// Configures a 1-worker / 1-server cluster with the scheduler at 127.0.0.1:9999, lowers
// cluster_available_timeout to 3 and constructs a SchedulerNode.
// NOTE(review): two revisions are interleaved -- the direct cluster_config().Init(...) call and the
// environment-variable + SetPSEnable(true) sequence both configure the cluster; confirm which one is live.
TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) {
PSContext::instance()->cluster_config().Init(1, 1, "127.0.0.1", 9999);
std::string worker_num = "1";
std::string server_num = "1";
std::string host = "127.0.0.1";
std::string port = "9999";
// SetPSEnable(true) reads these environment variables when it builds the cluster config.
common::SetEnv(kEnvWorkerNum, worker_num.c_str());
common::SetEnv(kEnvPServerNum, server_num.c_str());
common::SetEnv(kEnvSchedulerHost, host.c_str());
common::SetEnv(kEnvSchedulerPort, port.c_str());
PSContext::instance()->SetPSEnable(true);
PSContext::instance()->cluster_config().cluster_available_timeout = 3;
MS_LOG(INFO) << "The timeout is:" << PSContext::instance()->cluster_config().cluster_available_timeout;
SchedulerNode node;
}

View File

@ -18,9 +18,9 @@
#include <string>
#include "common/common_test.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/cluster_config.h"
#include "ps/ps_context.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace ps {
@ -29,14 +29,21 @@ class TestClusterConfig : public UT::Common {
public:
TestClusterConfig() = default;
virtual ~TestClusterConfig() = default;
void SetUp() override {}
void TearDown() override {}
};
// Builds the cluster configuration and checks the value of cluster_available_timeout.
// NOTE(review): two revisions are interleaved -- the direct Init(...)/heartbeat_interval statements and the
// environment-variable + SetPSEnable(true) sequence; confirm which one is live.
TEST_F(TestClusterConfig, HeartbeatInterval) {
PSContext::instance()->cluster_config().Init(2, 2, "127.0.0.1", 8080);
PSContext::instance()->cluster_config().heartbeat_interval = 100;
std::string worker_num = "1";
std::string server_num = "1";
std::string host = "127.0.0.1";
std::string port = "9999";
// SetPSEnable(true) reads these environment variables when it builds the cluster config.
common::SetEnv(kEnvWorkerNum, worker_num.c_str());
common::SetEnv(kEnvPServerNum, server_num.c_str());
common::SetEnv(kEnvSchedulerHost, host.c_str());
common::SetEnv(kEnvSchedulerPort, port.c_str());
PSContext::instance()->SetPSEnable(true);
// 300 is never assigned in this test, so it is presumably the ClusterConfig default -- TODO confirm.
EXPECT_EQ(300, PSContext::instance()->cluster_config().cluster_available_timeout);
}
} // namespace core
} // namespace ps

View File

@ -52,14 +52,6 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
EXPECT_TRUE(!ip.empty());
}
// With a cluster of 3 workers and 2 servers, ValidateRankId must accept the highest rank of each role
// (worker 2, server 1) and reject the next one (worker 3, server 2).
TEST_F(TestCommUtil, ValidateRankId) {
  auto &cluster_config = PSContext::instance()->cluster_config();
  cluster_config.Init(3, 2, "127.0.0.1", 9999);

  // Worker ranks.
  EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2));
  EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3));

  // Server ranks.
  EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));
  EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2));
}
TEST_F(TestCommUtil, Retry) {
bool const ret = CommUtil::Retry([]() -> bool { return false; }, 5, 100);
EXPECT_FALSE(ret);