forked from mindspore-Ecosystem/mindspore
scheduler added client
This commit is contained in:
parent
9f0d6ec8da
commit
54a331d103
|
@ -55,6 +55,10 @@ void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const
|
|||
}
|
||||
node_info_.rank_id_ = register_resp_message.rank_id();
|
||||
|
||||
// Receive the Register message, indicating that the scheduler is alive, so update the time point at which the
|
||||
// scheduler is alive
|
||||
UpdateSchedulerTime();
|
||||
|
||||
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_
|
||||
<< " registered scheduler success!";
|
||||
}
|
||||
|
@ -84,9 +88,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &mess
|
|||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_message) {
|
||||
on_node_event_message_ = on_node_event_message;
|
||||
}
|
||||
void AbstractNode::set_event_callback(const OnNodeEventMessage &event) { on_node_event_message_ = event; }
|
||||
|
||||
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
|
||||
int command, const uint32_t &timeout) {
|
||||
|
@ -211,16 +213,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
|||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
|
||||
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
|
||||
return ret;
|
||||
});
|
||||
message_tracker_.erase(request_id);
|
||||
return res;
|
||||
}
|
||||
|
||||
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
@ -294,19 +286,19 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
|
|||
} else {
|
||||
UpdateSchedulerTime();
|
||||
}
|
||||
|
||||
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
|
||||
}
|
||||
});
|
||||
heart_beat_thread_->detach();
|
||||
}
|
||||
|
||||
bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
|
||||
bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
|
||||
auto meta = std::make_shared<MessageMeta>();
|
||||
meta->set_cmd(NodeCommand::HEARTBEAT);
|
||||
|
||||
HeartbeatMessage heartbeat_message;
|
||||
heartbeat_message.set_node_id(node_info_.node_id_);
|
||||
heartbeat_message.set_is_node_finish(is_node_finish);
|
||||
|
||||
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
|
||||
heartbeat_message.ByteSizeLong())) {
|
||||
|
@ -338,32 +330,21 @@ void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const
|
|||
HeartbeatRespMessage heartbeat_resp_message;
|
||||
heartbeat_resp_message.ParseFromArray(data, size);
|
||||
|
||||
is_ready_ = heartbeat_resp_message.is_cluster_ready();
|
||||
if (is_ready_.load()) {
|
||||
current_cluster_state_ = heartbeat_resp_message.cluster_state();
|
||||
if (current_cluster_state_ == ClusterState::CLUSTER_READY) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is ready!";
|
||||
}
|
||||
if (heartbeat_resp_message.is_cluster_finish()) {
|
||||
Heartbeat(client_to_scheduler_, true);
|
||||
is_finish_ = true;
|
||||
wait_finish_cond_.notify_all();
|
||||
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " is finish!";
|
||||
}
|
||||
is_timeout_ = heartbeat_resp_message.is_cluster_timeout();
|
||||
if (is_timeout_ && on_node_event_message_) {
|
||||
if (current_cluster_state_ == ClusterState::CLUSTER_TIMEOUT && on_node_event_message_) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
on_node_event_message_(NodeEvent::CLUSTER_TIMEOUT);
|
||||
}
|
||||
|
||||
if (heartbeat_resp_message.is_node_timeout() && on_node_event_message_) {
|
||||
on_node_event_message_(NodeEvent::NODE_TIMEOUT);
|
||||
}
|
||||
}
|
||||
|
||||
void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
|
||||
auto meta = std::make_shared<MessageMeta>();
|
||||
meta->set_cmd(NodeCommand::FETCH_SERVER);
|
||||
meta->set_cmd(NodeCommand::FETCH_METADATA);
|
||||
|
||||
FetchServersMessage fetch_servers;
|
||||
fetch_servers.set_node_id(node_info_.node_id_);
|
||||
|
@ -379,11 +360,39 @@ void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, co
|
|||
FetchServersRespMessage fetch_servers_resp_message;
|
||||
fetch_servers_resp_message.ParseFromArray(data, size);
|
||||
|
||||
nodes_address_.clear();
|
||||
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
|
||||
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
|
||||
MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "The all server host size is:" << nodes_address_.size();
|
||||
void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
SendMetadataMessage send_meta_message;
|
||||
send_meta_message.ParseFromArray(data, size);
|
||||
nodes_address_.clear();
|
||||
MS_LOG(ERROR) << "send metadata size:" << send_meta_message.servers_meta().size();
|
||||
for (const auto &it : send_meta_message.servers_meta()) {
|
||||
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
|
||||
MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
|
||||
}
|
||||
server_->SendMessage(conn, meta, Protos::RAW, data, size);
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
server_->SendMessage(conn, meta, Protos::RAW, data, size);
|
||||
is_finish_ = true;
|
||||
wait_finish_cond_.notify_all();
|
||||
}
|
||||
|
||||
bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
|
||||
|
@ -478,42 +487,6 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &
|
|||
}
|
||||
}
|
||||
|
||||
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout) {
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
meta->set_request_id(request_id);
|
||||
client->SendMessage(meta, protos, data, size);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
return request_id;
|
||||
}
|
||||
|
||||
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
meta->set_request_id(request_id);
|
||||
client->SendMessage(meta, protos, data, size);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -570,14 +543,6 @@ void AbstractNode::set_message_callback(const uint64_t &request_id, const Messag
|
|||
message_callbacks_[request_id] = callback;
|
||||
}
|
||||
|
||||
void AbstractNode::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
uint64_t request_id = meta->request_id();
|
||||
|
||||
message_tracker_[request_id].second++;
|
||||
message_tracker_cond_.notify_all();
|
||||
}
|
||||
|
||||
void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -639,20 +604,25 @@ uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
|
|||
void AbstractNode::InitCommandHandler() {
|
||||
handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp;
|
||||
handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp;
|
||||
handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;
|
||||
handlers_[NodeCommand::FETCH_METADATA] = &AbstractNode::ProcessFetchServersResp;
|
||||
handlers_[NodeCommand::FINISH] = nullptr;
|
||||
}
|
||||
|
||||
uint64_t AbstractNode::AddMessageTrack(const uint32_t &expected_response) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(expected_response, 0);
|
||||
return request_id;
|
||||
void AbstractNode::InitServerHandler() {
|
||||
server_handler_[NodeCommand::SEND_METADATA] = &AbstractNode::ProcessSendMetadata;
|
||||
server_handler_[NodeCommand::FINISH] = &AbstractNode::ProcessFinish;
|
||||
server_handler_[NodeCommand::SEND_DATA] = nullptr;
|
||||
server_handler_[NodeCommand::COLLECTIVE_SEND_DATA] = nullptr;
|
||||
}
|
||||
|
||||
bool AbstractNode::CheckMessageTrack(const uint64_t &request_id) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
|
||||
void AbstractNode::InitNode(const NodeRole &role) {
|
||||
node_info_.node_id_ = CommUtil::GenerateUUID();
|
||||
node_info_.node_role_ = role;
|
||||
node_info_.ip_ = server_->BoundIp();
|
||||
node_info_.port_ = server_->BoundPort();
|
||||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " is generate uuid is:" << node_info_.node_id_ << ", the ip:" << server_->BoundIp()
|
||||
<< ", the port:" << server_->BoundPort();
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -33,17 +33,25 @@ namespace ps {
|
|||
namespace core {
|
||||
class AbstractNode : public Node {
|
||||
public:
|
||||
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
|
||||
AbstractNode()
|
||||
: heart_beat_thread_(nullptr),
|
||||
client_to_scheduler_thread_(nullptr),
|
||||
client_to_scheduler_(nullptr),
|
||||
server_(nullptr),
|
||||
server_thread_(nullptr) {}
|
||||
~AbstractNode() override = default;
|
||||
|
||||
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
typedef void (AbstractNode::*ServerHandler)(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
|
||||
using DataPtr = std::shared_ptr<unsigned char[]>;
|
||||
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
|
||||
|
||||
bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
void set_event_callback(const OnNodeEventMessage &on_node_event_message);
|
||||
|
||||
void set_event_callback(const OnNodeEventMessage &event);
|
||||
|
||||
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
@ -54,7 +62,6 @@ class AbstractNode : public Node {
|
|||
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
|
||||
const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
|
||||
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
|
||||
|
@ -63,13 +70,18 @@ class AbstractNode : public Node {
|
|||
|
||||
protected:
|
||||
void Register(const std::shared_ptr<TcpClient> &client);
|
||||
bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
|
||||
bool Heartbeat(const std::shared_ptr<TcpClient> &client);
|
||||
void FetchServers(const std::shared_ptr<TcpClient> &client);
|
||||
|
||||
void ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
|
||||
void ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
void ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
|
||||
const void *data, size_t size);
|
||||
|
||||
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
|
||||
void UpdateSchedulerTime();
|
||||
bool CheckSchedulerTimeout() const;
|
||||
|
@ -77,39 +89,29 @@ class AbstractNode : public Node {
|
|||
bool WaitForDisconnect(const uint32_t &timeout);
|
||||
bool InitClientToScheduler();
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
|
||||
const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
|
||||
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
|
||||
void RunMessageCallback(const uint64_t &request_id);
|
||||
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
|
||||
void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
|
||||
void RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
|
||||
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
|
||||
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
|
||||
void InitCommandHandler();
|
||||
uint64_t AddMessageTrack(const uint32_t &expected_response);
|
||||
bool CheckMessageTrack(const uint64_t &request_id);
|
||||
void InitServerHandler();
|
||||
void InitNode(const NodeRole &role);
|
||||
|
||||
std::unique_ptr<std::thread> heart_beat_thread_;
|
||||
std::unique_ptr<std::thread> client_to_scheduler_thread_;
|
||||
std::shared_ptr<TcpClient> client_to_scheduler_;
|
||||
|
||||
OnNodeEventMessage on_node_event_message_;
|
||||
|
||||
// the key is: <node_role,rank_id>, the value is: <ip, port>
|
||||
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
|
||||
std::mutex client_mutex_;
|
||||
// the map's key is: rank_id
|
||||
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;
|
||||
|
||||
// the key is: request_id, the value is: <expected responses, actual responses>
|
||||
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
|
||||
std::mutex message_tracker_mutex_;
|
||||
std::condition_variable message_tracker_cond_;
|
||||
|
||||
// the key is: request_id, the value is: <rank_id, RecvMessage>
|
||||
std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_;
|
||||
std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_;
|
||||
|
@ -132,6 +134,13 @@ class AbstractNode : public Node {
|
|||
std::mutex rank_request_ids_mutex;
|
||||
timeval scheduler_time_{0, 0};
|
||||
std::unordered_map<NodeCommand, ResponseHandler> handlers_;
|
||||
std::unordered_map<NodeCommand, ServerHandler> server_handler_;
|
||||
|
||||
std::unordered_map<NodeEvent, bool> is_event_send_;
|
||||
std::mutex is_event_send_mutex_;
|
||||
|
||||
std::shared_ptr<TcpServer> server_;
|
||||
std::unique_ptr<std::thread> server_thread_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -31,24 +31,17 @@ namespace core {
|
|||
* Configuration information read through environment variables and configuration files, generally immutable
|
||||
*/
|
||||
struct ClusterConfig {
|
||||
ClusterConfig()
|
||||
: initial_worker_num(0),
|
||||
initial_server_num(0),
|
||||
explicit ClusterConfig(const uint32_t &worker_num, const uint32_t &server_num, std::string host, const uint16_t &port)
|
||||
: initial_worker_num(worker_num),
|
||||
initial_server_num(server_num),
|
||||
heartbeat_interval(3),
|
||||
scheduler_host(""),
|
||||
scheduler_port(0),
|
||||
scheduler_host(host),
|
||||
scheduler_port(port),
|
||||
heartbeat_timeout(30),
|
||||
cluster_available_timeout(300),
|
||||
connect_interval(100),
|
||||
scheduler_timeout(30) {}
|
||||
|
||||
void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string host, const uint16_t &port) {
|
||||
initial_worker_num = worker_num;
|
||||
initial_server_num = server_num;
|
||||
scheduler_host = host;
|
||||
scheduler_port = port;
|
||||
}
|
||||
|
||||
// Configure through environment variables:MS_WORKER_NUM
|
||||
uint32_t initial_worker_num;
|
||||
// Configure through environment variables:MS_SERVER_NUM
|
||||
|
|
|
@ -31,15 +31,10 @@ namespace core {
|
|||
* The metadata information of the cluster, stored in the scheduler, is generally used for scale out and scale in.
|
||||
*/
|
||||
struct ClusterMetadata {
|
||||
ClusterMetadata() : worker_num_(0), server_num_(0) {}
|
||||
ClusterMetadata(const uint32_t &worker, const uint32_t &server) : worker_num(worker), server_num(server) {}
|
||||
|
||||
void Init(const uint32_t &worker_num, const uint32_t &server_num) {
|
||||
worker_num_ = worker_num;
|
||||
server_num_ = server_num;
|
||||
}
|
||||
|
||||
uint32_t worker_num_;
|
||||
uint32_t server_num_;
|
||||
uint32_t worker_num;
|
||||
uint32_t server_num;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -51,8 +51,8 @@
|
|||
#include "proto/ps.pb.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
|
|
@ -30,7 +30,7 @@ enum class Command {
|
|||
REGISTER = 1,
|
||||
HEARTBEAT = 2,
|
||||
SEND_DATA = 3,
|
||||
FETCH_SERVER = 4,
|
||||
FETCH_METADATA = 4,
|
||||
FINISH = 5,
|
||||
COLLECTIVE_SEND_DATA = 6
|
||||
};
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_STATUS_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_STATUS_H_
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_REQUEST_PROCESS_RESULT_CODE_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_REQUEST_PROCESS_RESULT_CODE_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
@ -99,4 +99,4 @@ class RequestProcessResult {
|
|||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_STATUS_H_
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_REQUEST_PROCESS_RESULT_CODE_H_
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_CLIENT_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_CLIENT_H_
|
||||
|
||||
#include "ps/core/communicator/tcp_message_handler.h"
|
||||
|
||||
#include <event2/event.h>
|
||||
#include <event2/bufferevent.h>
|
||||
#include <event2/thread.h>
|
||||
|
@ -33,13 +31,13 @@
|
|||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "ps/core/comm_util.h"
|
||||
#include "ps/core/communicator/ssl_wrapper.h"
|
||||
#include "ps/constants.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/communicator/tcp_message_handler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
|
|
@ -37,7 +37,6 @@
|
|||
|
||||
#include "ps/core/communicator/tcp_message_handler.h"
|
||||
#include "ps/core/communicator/ssl_wrapper.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "ps/core/comm_util.h"
|
||||
|
|
|
@ -36,6 +36,72 @@ bool Node::WaitForStart(const uint32_t &timeout) {
|
|||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout) {
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
|
||||
client->SendMessage(message);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
uint64_t Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
meta->set_request_id(request_id);
|
||||
client->SendMessage(meta, protos, data, size);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
return request_id;
|
||||
}
|
||||
|
||||
bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
uint64_t request_id = AddMessageTrack(1);
|
||||
meta->set_request_id(request_id);
|
||||
client->SendMessage(meta, protos, data, size);
|
||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
|
||||
return Wait(request_id, timeout);
|
||||
}
|
||||
|
||||
bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
|
||||
std::unique_lock<std::mutex> lock(message_tracker_mutex_);
|
||||
bool res = message_tracker_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
|
||||
bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
|
||||
return ret;
|
||||
});
|
||||
message_tracker_.erase(request_id);
|
||||
return res;
|
||||
}
|
||||
|
||||
uint64_t Node::AddMessageTrack(const uint32_t &expected_response) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
uint64_t request_id = ++next_request_id_;
|
||||
message_tracker_[request_id] = std::make_pair(expected_response, 0);
|
||||
return request_id;
|
||||
}
|
||||
|
||||
bool Node::CheckMessageTrack(const uint64_t &request_id) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
|
||||
}
|
||||
|
||||
void Node::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) {
|
||||
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
|
||||
uint64_t request_id = meta->request_id();
|
||||
|
||||
message_tracker_[request_id].second++;
|
||||
message_tracker_cond_.notify_all();
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,13 +47,20 @@ class Node {
|
|||
Node()
|
||||
: is_ready_(false),
|
||||
is_finish_(false),
|
||||
is_timeout_(false),
|
||||
is_node_ready_scale_out_(false),
|
||||
is_node_ready_scale_in_(false),
|
||||
is_cluster_ready_scale_out_(false),
|
||||
is_cluster_ready_scale_in_(false),
|
||||
update_local_servers_(false),
|
||||
is_already_stopped_(true),
|
||||
is_already_finished_(false),
|
||||
next_request_id_(0) {}
|
||||
next_request_id_(0),
|
||||
current_node_state_(NodeState::NODE_STARTING),
|
||||
current_cluster_state_(ClusterState::ClUSTER_STARTING) {}
|
||||
virtual ~Node() = default;
|
||||
|
||||
using OnNodeEventMessage = std::function<void(const NodeEvent &event)>;
|
||||
|
||||
using MessageCallback = std::function<void()>;
|
||||
|
||||
virtual bool Start(const uint32_t &timeout = PSContext::instance()->cluster_config().cluster_available_timeout) = 0;
|
||||
|
@ -64,13 +71,37 @@ class Node {
|
|||
uint32_t rank_id() const;
|
||||
NodeRole role() const;
|
||||
|
||||
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
|
||||
protected:
|
||||
bool WaitForStart(const uint32_t &timeout);
|
||||
|
||||
// Send data synchronously
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
|
||||
const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||
// Send data asynchronously
|
||||
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size);
|
||||
|
||||
uint64_t AddMessageTrack(const uint32_t &expected_response);
|
||||
bool CheckMessageTrack(const uint64_t &request_id);
|
||||
void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
|
||||
|
||||
NodeInfo node_info_;
|
||||
std::atomic<bool> is_ready_;
|
||||
std::atomic<bool> is_finish_;
|
||||
std::atomic<bool> is_timeout_;
|
||||
|
||||
std::atomic<bool> is_node_ready_scale_out_;
|
||||
std::atomic<bool> is_node_ready_scale_in_;
|
||||
|
||||
std::atomic<bool> is_cluster_ready_scale_out_;
|
||||
std::atomic<bool> is_cluster_ready_scale_in_;
|
||||
|
||||
// Determine whether to update the ip and port of the locally cached servers.
|
||||
std::atomic<bool> update_local_servers_;
|
||||
|
||||
std::atomic<bool> is_already_stopped_;
|
||||
std::atomic<bool> is_already_finished_;
|
||||
std::atomic_uint64_t next_request_id_;
|
||||
|
@ -80,6 +111,14 @@ class Node {
|
|||
std::mutex wait_finish_mutex_;
|
||||
std::condition_variable wait_finish_cond_;
|
||||
std::mutex finish_mutex_;
|
||||
|
||||
// the key is: request_id, the value is: <expected responses, actual responses>
|
||||
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
|
||||
std::mutex message_tracker_mutex_;
|
||||
std::condition_variable message_tracker_cond_;
|
||||
|
||||
NodeState current_node_state_;
|
||||
ClusterState current_cluster_state_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -25,7 +25,15 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT = 2 };
|
||||
enum class NodeEvent {
|
||||
CLUSTER_TIMEOUT = 0,
|
||||
NODE_TIMEOUT = 1,
|
||||
SCHEDULER_TIMEOUT = 2,
|
||||
READY_FOR_SCALE_OUT = 3,
|
||||
READY_FOR_SCALE_IN = 4,
|
||||
CLUSTER_SCALE_OUT_DONE = 5,
|
||||
CLUSTER_SCALE_IN_DONE = 6
|
||||
};
|
||||
|
||||
struct NodeInfo {
|
||||
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}
|
||||
|
|
|
@ -19,9 +19,12 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
void NodeManager::InitNodeNum() {
|
||||
total_node_num_ = PSContext::instance()->cluster_config().initial_server_num +
|
||||
PSContext::instance()->cluster_config().initial_worker_num;
|
||||
void NodeManager::InitNode() {
|
||||
initial_total_node_num_ = PSContext::instance()->cluster_config().initial_server_num +
|
||||
PSContext::instance()->cluster_config().initial_worker_num;
|
||||
meta_data_ = std::make_unique<ClusterMetadata>(PSContext::instance()->cluster_config().initial_worker_num,
|
||||
PSContext::instance()->cluster_config().initial_server_num);
|
||||
total_node_num_ = initial_total_node_num_;
|
||||
}
|
||||
|
||||
int NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
||||
|
@ -40,8 +43,8 @@ int NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
|||
uint32_t port = register_message.port();
|
||||
|
||||
rank_id = ++next_server_rank_id_;
|
||||
if (IntToUint(rank_id) >= PSContext::instance()->cluster_config().initial_server_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of servers.";
|
||||
if (IntToUint(rank_id) >= meta_data_->server_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of servers:" << meta_data_->server_num;
|
||||
rank_id = -1;
|
||||
--next_server_rank_id_;
|
||||
}
|
||||
|
@ -55,9 +58,11 @@ int NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
|||
MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port
|
||||
<< " assign rank id:" << rank_id;
|
||||
} else if (register_message.role() == NodeRole::WORKER) {
|
||||
const std::string &ip = register_message.ip();
|
||||
uint32_t port = register_message.port();
|
||||
rank_id = ++next_worker_rank_id_;
|
||||
if (IntToUint(rank_id) >= PSContext::instance()->cluster_config().initial_worker_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of workers.";
|
||||
if (IntToUint(rank_id) >= meta_data_->worker_num) {
|
||||
MS_LOG(WARNING) << "The rank id is greater than the number of workers:" << meta_data_->worker_num;
|
||||
rank_id = -1;
|
||||
--next_worker_rank_id_;
|
||||
}
|
||||
|
@ -65,6 +70,8 @@ int NodeManager::NextRankId(const RegisterMessage ®ister_message) {
|
|||
node_info.node_role_ = NodeRole::WORKER;
|
||||
node_info.node_id_ = node_id;
|
||||
node_info.rank_id_ = rank_id;
|
||||
node_info.ip_ = ip;
|
||||
node_info.port_ = port;
|
||||
nodes_info_[node_id] = node_info;
|
||||
MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id;
|
||||
}
|
||||
|
@ -81,9 +88,11 @@ void NodeManager::UpdateHeartbeat(const std::string &node_id) {
|
|||
<< ", the node rank id:" << node_info.rank_id_ << " the current time is: " << current_time.tv_sec;
|
||||
}
|
||||
|
||||
void NodeManager::UpdateNodeFinishState(const std::string &node_id) { heartbeats_finish_nodes_.insert(node_id); }
|
||||
void NodeManager::UpdateNodeScaleInState(const std::string &node_id) { heartbeats_scale_in_nodes_.insert(node_id); }
|
||||
|
||||
bool NodeManager::CheckNodesFinishState() { return heartbeats_finish_nodes_.size() == nodes_info_.size(); }
|
||||
bool NodeManager::CheckNodesScaluOutState() { return SizeToInt(heartbeats_scale_out_nodes_.size()) == total_node_num_; }
|
||||
|
||||
bool NodeManager::CheckNodesScaleInState() { return SizeToInt(heartbeats_scale_in_nodes_.size()) == total_node_num_; }
|
||||
|
||||
std::vector<ServersMeta> NodeManager::FetchServersMeta() {
|
||||
std::vector<ServersMeta> servers_meta_list;
|
||||
|
@ -99,7 +108,7 @@ std::vector<ServersMeta> NodeManager::FetchServersMeta() {
|
|||
return servers_meta_list;
|
||||
}
|
||||
|
||||
void NodeManager::UpdateClusterState() {
|
||||
void NodeManager::UpdateCluster() {
|
||||
// 1. update cluster timeout state
|
||||
struct timeval current_time {};
|
||||
(void)gettimeofday(¤t_time, nullptr);
|
||||
|
@ -111,48 +120,61 @@ void NodeManager::UpdateClusterState() {
|
|||
}
|
||||
}
|
||||
if (!timeout_nodes_info_.empty()) {
|
||||
is_node_timeout_ = true;
|
||||
UpdateClusterState(ClusterState::CLUSTER_TIMEOUT);
|
||||
for (auto it = timeout_nodes_info_.begin(); it != timeout_nodes_info_.end(); ++it) {
|
||||
finish_nodes_id_.insert(it->first);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. update cluster finish state
|
||||
if (finish_nodes_id_.size() == total_node_num_ || SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
|
||||
is_cluster_finish_ = true;
|
||||
is_cluster_ready_ = true;
|
||||
}
|
||||
|
||||
// 3. update cluster ready state
|
||||
if (nodes_info_.size() == total_node_num_) {
|
||||
is_cluster_ready_ = true;
|
||||
if (SizeToInt(finish_nodes_id_.size()) == total_node_num_ ||
|
||||
SizeToInt(finish_nodes_id_.size()) == current_node_num_) {
|
||||
UpdateClusterState(ClusterState::CLUSTER_FINISH);
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::CheckClusterTimeout() {
|
||||
if (total_node_num_ != nodes_info_.size()) {
|
||||
if (total_node_num_ != SizeToInt(nodes_info_.size())) {
|
||||
MS_LOG(WARNING) << "The cluster is not ready after "
|
||||
<< PSContext::instance()->cluster_config().cluster_available_timeout
|
||||
<< " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to "
|
||||
<< nodes_info_.size();
|
||||
current_node_num_ = nodes_info_.size();
|
||||
is_cluster_timeout_ = true;
|
||||
UpdateClusterState(ClusterState::CLUSTER_TIMEOUT);
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::AddFinishNode(const std::string &finish_message) { finish_nodes_id_.insert(finish_message); }
|
||||
|
||||
std::unordered_map<std::string, NodeInfo> NodeManager::nodes_info() { return nodes_info_; }
|
||||
bool NodeManager::CheckRegisterNum() { return SizeToInt(nodes_info_.size()) == total_node_num_; }
|
||||
|
||||
bool NodeManager::is_cluster_finish() { return is_cluster_finish_.load(); }
|
||||
bool NodeManager::CheckFinishNum() { return SizeToInt(finish_nodes_id_.size()) == total_node_num_; }
|
||||
|
||||
bool NodeManager::is_cluster_ready() { return is_cluster_ready_.load(); }
|
||||
std::unordered_map<std::string, NodeInfo> &NodeManager::nodes_info() { return nodes_info_; }
|
||||
|
||||
bool NodeManager::is_cluster_timeout() { return is_cluster_timeout_.load(); }
|
||||
void NodeManager::UpdateNodeState(const NodeState &state) {
|
||||
std::lock_guard<std::mutex> lk(node_mutex_);
|
||||
node_state_ = state;
|
||||
}
|
||||
|
||||
bool NodeManager::is_node_timeout() { return is_node_timeout_.load(); }
|
||||
void NodeManager::UpdateClusterState(const ClusterState &state) {
|
||||
std::lock_guard<std::mutex> lk(cluster_mutex_);
|
||||
cluster_state_ = state;
|
||||
}
|
||||
|
||||
void NodeManager::set_cluster_timeout(bool is_cluster_timeout) { is_cluster_timeout_ = is_cluster_timeout; }
|
||||
NodeState NodeManager::GetNodeState() {
|
||||
std::lock_guard<std::mutex> lk(node_mutex_);
|
||||
return node_state_;
|
||||
}
|
||||
|
||||
ClusterState NodeManager::GetClusterState() {
|
||||
std::lock_guard<std::mutex> lk(cluster_mutex_);
|
||||
return cluster_state_;
|
||||
}
|
||||
|
||||
void NodeManager::set_total_node_num(const int32_t &node_num) { total_node_num_ = node_num; }
|
||||
|
||||
const int32_t &NodeManager::total_node_num() { return total_node_num_; }
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include "ps/core/node.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -41,52 +42,75 @@ namespace core {
|
|||
class NodeManager {
|
||||
public:
|
||||
NodeManager()
|
||||
: is_cluster_ready_(false),
|
||||
is_cluster_finish_(false),
|
||||
is_cluster_timeout_(false),
|
||||
is_node_timeout_(false),
|
||||
total_node_num_(0),
|
||||
: initial_total_node_num_(0),
|
||||
total_node_num_(-1),
|
||||
current_node_num_(-1),
|
||||
next_worker_rank_id_(-1),
|
||||
next_server_rank_id_(-1) {}
|
||||
next_server_rank_id_(-1),
|
||||
meta_data_(nullptr),
|
||||
node_state_(NodeState::NODE_STARTING),
|
||||
cluster_state_(ClusterState::ClUSTER_STARTING) {}
|
||||
virtual ~NodeManager() = default;
|
||||
|
||||
enum ClusterState { STARTING, STARTED, FAILED, STOPPING, STOPPED };
|
||||
|
||||
void InitNodeNum();
|
||||
// When initializing nodes, the initial number of nodes will be assigned to the total number of nodes.
|
||||
void InitNode();
|
||||
int NextRankId(const RegisterMessage ®ister_message);
|
||||
|
||||
void UpdateHeartbeat(const std::string &node_id);
|
||||
void UpdateNodeFinishState(const std::string &node_id);
|
||||
bool CheckNodesFinishState();
|
||||
bool CheckNodesScaluOutState();
|
||||
void UpdateNodeScaleInState(const std::string &node_id);
|
||||
bool CheckNodesScaleInState();
|
||||
|
||||
std::vector<ServersMeta> FetchServersMeta();
|
||||
void UpdateClusterState();
|
||||
void UpdateCluster();
|
||||
void CheckClusterTimeout();
|
||||
void AddFinishNode(const std::string &finish_message);
|
||||
std::unordered_map<std::string, NodeInfo> nodes_info();
|
||||
bool is_cluster_ready();
|
||||
bool is_cluster_finish();
|
||||
bool is_cluster_timeout();
|
||||
bool is_node_timeout();
|
||||
void set_cluster_timeout(bool is_cluster_timeout);
|
||||
|
||||
// When workers and servers registered to scheduler, the scheduler will collect the number of registered
|
||||
// nodes and Determine whether the registered number of worker and server is equal to total_node_num_.
|
||||
bool CheckRegisterNum();
|
||||
// When workers and servers send a finish message to the scheduler, the scheduler will collect the number of
|
||||
// finish nodes and Determine whether the finished nodes are equal to total_node_num_.
|
||||
bool CheckFinishNum();
|
||||
|
||||
std::unordered_map<std::string, NodeInfo> &nodes_info();
|
||||
|
||||
void set_total_node_num(const int32_t &node_num);
|
||||
const int32_t &total_node_num();
|
||||
|
||||
void UpdateNodeState(const NodeState &state);
|
||||
void UpdateClusterState(const ClusterState &state);
|
||||
NodeState GetNodeState();
|
||||
ClusterState GetClusterState();
|
||||
|
||||
private:
|
||||
std::atomic<bool> is_cluster_ready_;
|
||||
std::atomic<bool> is_cluster_finish_;
|
||||
std::atomic<bool> is_cluster_timeout_;
|
||||
std::atomic<bool> is_node_timeout_;
|
||||
uint32_t total_node_num_;
|
||||
std::mutex node_mutex_;
|
||||
std::mutex cluster_mutex_;
|
||||
|
||||
uint32_t initial_total_node_num_;
|
||||
int32_t total_node_num_;
|
||||
int32_t current_node_num_;
|
||||
|
||||
std::atomic<int> next_worker_rank_id_;
|
||||
std::atomic<int> next_server_rank_id_;
|
||||
|
||||
// worker nodes and server nodes
|
||||
std::unordered_map<std::string, NodeInfo> nodes_info_;
|
||||
std::mutex assign_rank_id_mutex_;
|
||||
std::mutex heartbeat_mutex_;
|
||||
|
||||
std::unordered_map<std::string, timeval> heartbeats_;
|
||||
std::unordered_set<std::string> heartbeats_finish_nodes_;
|
||||
std::unordered_set<std::string> heartbeats_scale_out_nodes_;
|
||||
std::unordered_set<std::string> heartbeats_scale_in_nodes_;
|
||||
// timeout nodes
|
||||
std::unordered_map<std::string, NodeInfo> timeout_nodes_info_;
|
||||
std::unordered_set<std::string> finish_nodes_id_;
|
||||
|
||||
std::unique_ptr<ClusterMetadata> meta_data_;
|
||||
|
||||
NodeState node_state_;
|
||||
ClusterState cluster_state_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -23,9 +23,12 @@ enum NodeCommand {
|
|||
REGISTER = 1;
|
||||
HEARTBEAT = 2;
|
||||
SEND_DATA = 3;
|
||||
FETCH_SERVER = 4;
|
||||
// The worker or server asks the scheduler for metadata
|
||||
FETCH_METADATA = 4;
|
||||
FINISH = 5;
|
||||
COLLECTIVE_SEND_DATA = 6;
|
||||
// The scheduler actively sends metadata to the worker and server
|
||||
SEND_METADATA = 7;
|
||||
}
|
||||
|
||||
enum NodeRole {
|
||||
|
@ -66,15 +69,27 @@ message RegisterRespMessage {
|
|||
message HeartbeatMessage {
|
||||
// the current Node unique id:0,1,2...
|
||||
string node_id = 1;
|
||||
bool is_node_finish = 2;
|
||||
}
|
||||
|
||||
enum NodeState {
|
||||
NODE_STARTING = 0;
|
||||
NODE_FINISH = 1;
|
||||
NODE_READY = 2;
|
||||
NODE_TIMEOUT = 3;
|
||||
}
|
||||
|
||||
enum ClusterState {
|
||||
ClUSTER_STARTING = 0;
|
||||
CLUSTER_READY = 1;
|
||||
CLUSTER_FINISH = 2;
|
||||
CLUSTER_TIMEOUT = 3;
|
||||
CLUSTER_SCALE_OUT = 4;
|
||||
CLUSTER_SCALE_IN = 5;
|
||||
CLUSTER_FAILURE = 6;
|
||||
}
|
||||
|
||||
message HeartbeatRespMessage {
|
||||
// Is the entire system ready to use.
|
||||
bool is_cluster_ready = 1;
|
||||
bool is_cluster_finish = 2;
|
||||
bool is_cluster_timeout = 3;
|
||||
bool is_node_timeout = 4;
|
||||
ClusterState cluster_state = 1;
|
||||
}
|
||||
|
||||
message FetchServersMessage {
|
||||
|
@ -92,6 +107,10 @@ message ServersMeta {
|
|||
|
||||
}
|
||||
|
||||
message SendMetadataMessage {
|
||||
repeated ServersMeta servers_meta = 1;
|
||||
}
|
||||
|
||||
message FinishMessage {
|
||||
// the current Node unique id:0,1,2...
|
||||
string node_id = 1;
|
||||
|
|
|
@ -47,21 +47,9 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
|
|||
|
||||
node_manager_.UpdateHeartbeat(heartbeat_message.node_id());
|
||||
|
||||
if (heartbeat_message.is_node_finish()) {
|
||||
node_manager_.UpdateNodeFinishState(heartbeat_message.node_id());
|
||||
}
|
||||
|
||||
if (heartbeat_message.is_node_finish() && node_manager_.CheckNodesFinishState()) {
|
||||
MS_LOG(INFO) << "The scheduler node receive all the finish cmd!";
|
||||
is_finish_ = true;
|
||||
wait_finish_cond_.notify_all();
|
||||
}
|
||||
|
||||
HeartbeatRespMessage heartbeat_resp_message;
|
||||
heartbeat_resp_message.set_is_cluster_ready(node_manager_.is_cluster_ready());
|
||||
heartbeat_resp_message.set_is_cluster_finish(node_manager_.is_cluster_finish());
|
||||
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
|
||||
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());
|
||||
|
||||
heartbeat_resp_message.set_cluster_state(node_manager_.GetClusterState());
|
||||
|
||||
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
|
||||
heartbeat_resp_message.ByteSizeLong());
|
||||
|
@ -81,11 +69,11 @@ void SchedulerNode::InitCommandHandler() {
|
|||
handlers_[NodeCommand::HEARTBEAT] = &SchedulerNode::ProcessHeartbeat;
|
||||
handlers_[NodeCommand::REGISTER] = &SchedulerNode::ProcessRegister;
|
||||
handlers_[NodeCommand::FINISH] = &SchedulerNode::ProcessFinish;
|
||||
handlers_[NodeCommand::FETCH_SERVER] = &SchedulerNode::ProcessFetchServers;
|
||||
handlers_[NodeCommand::FETCH_METADATA] = &SchedulerNode::ProcessFetchMetadata;
|
||||
}
|
||||
|
||||
void SchedulerNode::CreateTcpServer() {
|
||||
node_manager_.InitNodeNum();
|
||||
node_manager_.InitNode();
|
||||
|
||||
std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
|
||||
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
|
||||
|
@ -131,6 +119,17 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
|
|||
|
||||
server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
|
||||
register_resp_message.ByteSizeLong());
|
||||
|
||||
if (node_manager_.CheckRegisterNum()) {
|
||||
is_ready_ = true;
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
for (const auto &kvs : node_infos) {
|
||||
auto client = GetOrCreateClient(kvs.second);
|
||||
SendMetadata(client);
|
||||
}
|
||||
current_cluster_state_ = ClusterState::CLUSTER_READY;
|
||||
wait_start_cond_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
|
@ -143,10 +142,20 @@ void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared
|
|||
node_manager_.AddFinishNode(*finish_message);
|
||||
MS_LOG(INFO) << "Process finish message from node id:" << *finish_message;
|
||||
server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
|
||||
if (node_manager_.CheckFinishNum()) {
|
||||
auto node_infos = node_manager_.nodes_info();
|
||||
for (const auto &kvs : node_infos) {
|
||||
auto client = GetOrCreateClient(kvs.second);
|
||||
SendFinish(client);
|
||||
}
|
||||
is_finish_ = true;
|
||||
current_cluster_state_ = ClusterState::CLUSTER_FINISH;
|
||||
wait_finish_cond_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
void SchedulerNode::ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(server);
|
||||
MS_EXCEPTION_IF_NULL(conn);
|
||||
MS_EXCEPTION_IF_NULL(meta);
|
||||
|
@ -160,26 +169,59 @@ void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::
|
|||
fetch_servers_message.ByteSizeLong());
|
||||
}
|
||||
|
||||
void SchedulerNode::SendMetadata(const std::shared_ptr<TcpClient> &client) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::SEND_METADATA);
|
||||
|
||||
SendMetadataMessage send_metadata_message;
|
||||
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();
|
||||
|
||||
MS_LOG(ERROR) << "the list size:" << servers_meta_list.size();
|
||||
|
||||
*send_metadata_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
|
||||
|
||||
if (!SendMessageAsync(client, message_meta, Protos::PROTOBUF, send_metadata_message.SerializeAsString().data(),
|
||||
send_metadata_message.ByteSizeLong())) {
|
||||
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " send metadata timeout!";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << "is sending metadata to workers and servers!";
|
||||
}
|
||||
|
||||
void SchedulerNode::SendFinish(const std::shared_ptr<TcpClient> &client) {
|
||||
MS_EXCEPTION_IF_NULL(client);
|
||||
auto message_meta = std::make_shared<MessageMeta>();
|
||||
message_meta->set_cmd(NodeCommand::FINISH);
|
||||
|
||||
// The scheduler does not need to bring any data when sending the finish command
|
||||
std::string resp_data;
|
||||
|
||||
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, resp_data.data(), resp_data.size())) {
|
||||
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << " send finish timeout!";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " the node id:" << node_info_.node_id_ << "is sending finish to workers and servers!";
|
||||
}
|
||||
|
||||
void SchedulerNode::StartUpdateClusterStateTimer() {
|
||||
MS_LOG(WARNING) << "The scheduler start a heartbeat timer!";
|
||||
update_state_thread_ = std::make_unique<std::thread>([&]() {
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
while (!is_finish_.load()) {
|
||||
// 1. update cluster timeout
|
||||
if (!node_manager_.is_cluster_ready() &&
|
||||
(std::chrono::steady_clock::now() - start_time >
|
||||
std::chrono::seconds(PSContext::instance()->cluster_config().cluster_available_timeout))) {
|
||||
if (!is_ready_ && (std::chrono::steady_clock::now() - start_time >
|
||||
std::chrono::seconds(PSContext::instance()->cluster_config().cluster_available_timeout))) {
|
||||
node_manager_.CheckClusterTimeout();
|
||||
}
|
||||
|
||||
// 2. update cluster state
|
||||
std::this_thread::sleep_for(std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval));
|
||||
node_manager_.UpdateClusterState();
|
||||
if (node_manager_.is_cluster_ready()) {
|
||||
is_ready_ = true;
|
||||
wait_start_cond_.notify_all();
|
||||
}
|
||||
if (node_manager_.is_cluster_finish()) {
|
||||
node_manager_.UpdateCluster();
|
||||
|
||||
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_FINISH) {
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::seconds(PSContext::instance()->cluster_config().heartbeat_interval * 2));
|
||||
is_finish_ = true;
|
||||
|
@ -189,6 +231,29 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
|
|||
});
|
||||
}
|
||||
|
||||
const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInfo &node_info) {
|
||||
if (connected_nodes_.count(node_info.node_id_)) {
|
||||
return connected_nodes_[node_info.node_id_];
|
||||
} else {
|
||||
std::string ip = node_info.ip_;
|
||||
uint16_t port = node_info.port_;
|
||||
auto client = std::make_shared<TcpClient>(ip, port);
|
||||
MS_LOG(ERROR) << "the ip:" << node_info.ip_ << ", the port:" << node_info.port_;
|
||||
client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
|
||||
size_t size) { NotifyMessageArrival(meta); });
|
||||
client->Init();
|
||||
if (is_client_started_ == false) {
|
||||
is_client_started_ = true;
|
||||
client_thread_ = std::make_unique<std::thread>([&]() {
|
||||
MS_LOG(INFO) << "The node start a tcp client!";
|
||||
client->Start();
|
||||
});
|
||||
}
|
||||
connected_nodes_[node_info.node_id_] = client;
|
||||
return connected_nodes_[node_info.node_id_];
|
||||
}
|
||||
}
|
||||
|
||||
bool SchedulerNode::Stop() {
|
||||
MS_LOG(INFO) << "Stop scheduler node!";
|
||||
if (!is_already_stopped_) {
|
||||
|
@ -196,6 +261,12 @@ bool SchedulerNode::Stop() {
|
|||
update_state_thread_->join();
|
||||
server_->Stop();
|
||||
scheduler_thread_->join();
|
||||
if (!connected_nodes_.empty()) {
|
||||
for (auto &connected_node : connected_nodes_) {
|
||||
connected_node.second->Stop();
|
||||
}
|
||||
}
|
||||
client_thread_->join();
|
||||
is_ready_ = true;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -27,20 +27,31 @@
|
|||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/communicator/tcp_client.h"
|
||||
#include "ps/core/communicator/tcp_server.h"
|
||||
#include "ps/core/node_manager.h"
|
||||
#include "ps/core/node.h"
|
||||
#include "ps/core/communicator/request_process_result_code.h"
|
||||
#include "ps/core/communicator/http_message_handler.h"
|
||||
#include "ps/constants.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/communicator/http_server.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class SchedulerNode : public Node {
|
||||
public:
|
||||
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
|
||||
SchedulerNode()
|
||||
: server_(nullptr),
|
||||
scheduler_thread_(nullptr),
|
||||
update_state_thread_(nullptr),
|
||||
restful_thread_(nullptr),
|
||||
http_server_(nullptr),
|
||||
client_thread_(nullptr),
|
||||
is_client_started_(false) {}
|
||||
~SchedulerNode() override;
|
||||
|
||||
typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
|
@ -54,15 +65,22 @@ class SchedulerNode : public Node {
|
|||
void Initialize();
|
||||
void InitCommandHandler();
|
||||
void CreateTcpServer();
|
||||
void StartUpdateClusterStateTimer();
|
||||
const std::shared_ptr<TcpClient> &GetOrCreateClient(const NodeInfo &node_info);
|
||||
|
||||
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void StartUpdateClusterStateTimer();
|
||||
void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
void ProcessFetchMetadata(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
|
||||
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
|
||||
// After scheduler collects all registered message, it actively sends metadata to workers and servers.
|
||||
void SendMetadata(const std::shared_ptr<TcpClient> &client);
|
||||
// // After scheduler collects all finish message, it actively sends finish message to workers and servers.
|
||||
void SendFinish(const std::shared_ptr<TcpClient> &client);
|
||||
|
||||
std::shared_ptr<TcpServer> server_;
|
||||
std::unique_ptr<std::thread> scheduler_thread_;
|
||||
|
@ -70,6 +88,14 @@ class SchedulerNode : public Node {
|
|||
std::unordered_map<NodeCommand, ResponseHandler> handlers_;
|
||||
|
||||
NodeManager node_manager_;
|
||||
|
||||
std::unique_ptr<std::thread> restful_thread_;
|
||||
std::shared_ptr<HttpServer> http_server_;
|
||||
|
||||
std::unordered_map<std::string, std::shared_ptr<TcpClient>> connected_nodes_;
|
||||
|
||||
std::unique_ptr<std::thread> client_thread_;
|
||||
std::atomic<bool> is_client_started_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -32,11 +32,6 @@ bool ServerNode::Start(const uint32_t &timeout) {
|
|||
}
|
||||
MS_LOG(INFO) << "The cluster is ready to use!";
|
||||
|
||||
// If the cluster is ready to use, then Get the address of all the servers
|
||||
if (!is_timeout_.load()) {
|
||||
FetchServers(client_to_scheduler_);
|
||||
MS_LOG(INFO) << "Server node get all the servers address successful!";
|
||||
}
|
||||
MsException::Instance().CheckException();
|
||||
MS_LOG(INFO) << "Start the node is successful!";
|
||||
return true;
|
||||
|
@ -63,34 +58,32 @@ void ServerNode::CreateTcpServer() {
|
|||
server_ = std::make_shared<TcpServer>(server_ip, 0);
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
switch (meta->cmd()) {
|
||||
case NodeCommand::SEND_DATA:
|
||||
ProcessSendData(conn, meta, protos, data, size);
|
||||
break;
|
||||
case NodeCommand::COLLECTIVE_SEND_DATA:
|
||||
ProcessCollectiveSendData(conn, meta, data, size);
|
||||
RunReceiveCallback(meta, protos, data, size);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
if (server_handler_.count(meta->cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
}
|
||||
|
||||
if (meta->cmd() == NodeCommand::COLLECTIVE_SEND_DATA) {
|
||||
ProcessCollectiveSendData(conn, meta, data, size);
|
||||
RunReceiveCallback(meta, protos, data, size);
|
||||
} else if (meta->cmd() == NodeCommand::SEND_DATA) {
|
||||
ProcessSendData(conn, meta, protos, data, size);
|
||||
} else {
|
||||
const auto &handler_ptr = server_handler_[meta->cmd()];
|
||||
(this->*handler_ptr)(conn, meta, protos, data, size);
|
||||
}
|
||||
});
|
||||
server_->Init();
|
||||
server_thread_ = std::make_unique<std::thread>([&]() {
|
||||
server_thread_ = std::make_unique<std::thread>([this]() {
|
||||
MS_LOG(INFO) << "The server node start a tcp server!";
|
||||
server_->Start();
|
||||
this->server_->Start();
|
||||
});
|
||||
}
|
||||
|
||||
void ServerNode::Initialize() {
|
||||
InitServerHandler();
|
||||
CreateTcpServer();
|
||||
is_already_stopped_ = false;
|
||||
node_info_.node_id_ = CommUtil::GenerateUUID();
|
||||
node_info_.node_role_ = NodeRole::SERVER;
|
||||
node_info_.ip_ = server_->BoundIp();
|
||||
node_info_.port_ = server_->BoundPort();
|
||||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< " is generate uuid is:" << node_info_.node_id_;
|
||||
InitNode(NodeRole::SERVER);
|
||||
InitCommandHandler();
|
||||
if (!InitClientToScheduler()) {
|
||||
MS_LOG(EXCEPTION) << "Server node init client timeout!";
|
||||
|
|
|
@ -43,7 +43,7 @@ constexpr char kHttpCommunicator[] = "HTTP";
|
|||
|
||||
class ServerNode : public AbstractNode {
|
||||
public:
|
||||
ServerNode() : server_(nullptr), server_thread_(nullptr) {}
|
||||
ServerNode() = default;
|
||||
|
||||
~ServerNode() override = default;
|
||||
|
||||
|
@ -71,8 +71,6 @@ class ServerNode : public AbstractNode {
|
|||
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const void *data, size_t size);
|
||||
|
||||
std::shared_ptr<TcpServer> server_;
|
||||
std::unique_ptr<std::thread> server_thread_;
|
||||
RequestHandler request_handler_;
|
||||
std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_;
|
||||
std::mutex communicator_mutex_;
|
||||
|
|
|
@ -32,11 +32,6 @@ bool WorkerNode::Start(const uint32_t &timeout) {
|
|||
}
|
||||
MS_LOG(INFO) << "The node is ready to fetch servers!";
|
||||
|
||||
// If the cluster is ready to use, then Get the address of all the servers
|
||||
if (!is_timeout_.load()) {
|
||||
FetchServers(client_to_scheduler_);
|
||||
MS_LOG(INFO) << "Worker node get all the servers address successful!";
|
||||
}
|
||||
MsException::Instance().CheckException();
|
||||
MS_LOG(INFO) << "The Worker node has successfully started.";
|
||||
return true;
|
||||
|
@ -44,10 +39,9 @@ bool WorkerNode::Start(const uint32_t &timeout) {
|
|||
|
||||
void WorkerNode::Initialize() {
|
||||
is_already_stopped_ = false;
|
||||
node_info_.node_id_ = CommUtil::GenerateUUID();
|
||||
node_info_.node_role_ = NodeRole::WORKER;
|
||||
MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||
<< ", the node id is:" << node_info_.node_id_;
|
||||
InitServerHandler();
|
||||
CreateTcpServer();
|
||||
InitNode(NodeRole::WORKER);
|
||||
InitCommandHandler();
|
||||
if (!InitClientToScheduler()) {
|
||||
MS_LOG(EXCEPTION) << "Worker node init client timeout!";
|
||||
|
@ -55,11 +49,30 @@ void WorkerNode::Initialize() {
|
|||
MS_LOG(INFO) << "Worker node init client successful!";
|
||||
}
|
||||
|
||||
void WorkerNode::CreateTcpServer() {
|
||||
std::string interface;
|
||||
std::string server_ip;
|
||||
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
|
||||
server_ = std::make_shared<TcpServer>(server_ip, 0);
|
||||
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
|
||||
const Protos &protos, const void *data, size_t size) {
|
||||
if (server_handler_.count(meta->cmd()) == 0) {
|
||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||
}
|
||||
const auto &handler_ptr = server_handler_[meta->cmd()];
|
||||
(this->*handler_ptr)(conn, meta, protos, data, size);
|
||||
});
|
||||
server_->Init();
|
||||
server_thread_ = std::make_unique<std::thread>([&]() {
|
||||
MS_LOG(INFO) << "The worker node start a tcp server!";
|
||||
server_->Start();
|
||||
});
|
||||
}
|
||||
|
||||
bool WorkerNode::Stop() {
|
||||
if (!is_already_stopped_.load()) {
|
||||
MS_LOG(INFO) << "Stop worker node!";
|
||||
is_ready_ = true;
|
||||
is_timeout_ = true;
|
||||
is_finish_ = true;
|
||||
client_to_scheduler_->Stop();
|
||||
if (!connected_nodes_.empty()) {
|
||||
|
@ -67,6 +80,8 @@ bool WorkerNode::Stop() {
|
|||
connected_node.second->Stop();
|
||||
}
|
||||
}
|
||||
server_->Stop();
|
||||
server_thread_->join();
|
||||
is_already_stopped_ = true;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -24,9 +24,7 @@
|
|||
#include <utility>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/communicator/tcp_client.h"
|
||||
#include "ps/core/communicator/tcp_server.h"
|
||||
#include "ps/core/abstract_node.h"
|
||||
|
@ -45,6 +43,7 @@ class WorkerNode : public AbstractNode {
|
|||
|
||||
private:
|
||||
void Initialize();
|
||||
void CreateTcpServer();
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -52,7 +52,7 @@ void PSContext::SetPSEnable(bool enabled) {
|
|||
server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
|
||||
scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
|
||||
scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10);
|
||||
cluster_config_.Init(worker_num_, server_num_, scheduler_host_, scheduler_port_);
|
||||
cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
|
||||
} else {
|
||||
MS_LOG(INFO) << "PS mode is disabled.";
|
||||
is_worker_ = false;
|
||||
|
@ -312,6 +312,11 @@ bool PSContext::enable_ssl() const { return enable_ssl_; }
|
|||
|
||||
void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
|
||||
|
||||
core::ClusterConfig &PSContext::cluster_config() { return cluster_config_; }
|
||||
core::ClusterConfig &PSContext::cluster_config() {
|
||||
if (cluster_config_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The cluster config is empty.";
|
||||
}
|
||||
return *cluster_config_;
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -176,7 +176,8 @@ class PSContext {
|
|||
client_epoch_num_(25),
|
||||
client_batch_size_(32),
|
||||
client_learning_rate_(0.001),
|
||||
secure_aggregation_(false) {}
|
||||
secure_aggregation_(false),
|
||||
cluster_config_(nullptr) {}
|
||||
bool ps_enabled_;
|
||||
bool is_worker_;
|
||||
bool is_pserver_;
|
||||
|
@ -234,9 +235,8 @@ class PSContext {
|
|||
bool secure_aggregation_;
|
||||
|
||||
// The cluster config read through environment variables, the value does not change.
|
||||
core::ClusterConfig cluster_config_;
|
||||
std::unique_ptr<core::ClusterConfig> cluster_config_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_CONTEXT_H_
|
||||
|
|
|
@ -127,7 +127,7 @@ bool Server::InitCommunicatorWithServer() {
|
|||
auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
|
||||
MS_EXCEPTION_IF_NULL(tcp_comm);
|
||||
|
||||
tcp_comm->RegisterEventCallback(core::CLUSTER_TIMEOUT, [&]() {
|
||||
tcp_comm->RegisterEventCallback(core::NodeEvent::CLUSTER_TIMEOUT, [&]() {
|
||||
MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes(Scheduler/Server/Worker) are not "
|
||||
"started during network building phase.";
|
||||
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
|
@ -135,14 +135,14 @@ bool Server::InitCommunicatorWithServer() {
|
|||
communicator_with_server_->Stop();
|
||||
});
|
||||
|
||||
tcp_comm->RegisterEventCallback(core::SCHEDULER_TIMEOUT, [&]() {
|
||||
tcp_comm->RegisterEventCallback(core::NodeEvent::SCHEDULER_TIMEOUT, [&]() {
|
||||
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
|
||||
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
|
||||
communicator_with_server_->Stop();
|
||||
});
|
||||
|
||||
tcp_comm->RegisterEventCallback(core::NODE_TIMEOUT, [&]() {
|
||||
tcp_comm->RegisterEventCallback(core::NodeEvent::NODE_TIMEOUT, [&]() {
|
||||
MS_LOG(ERROR)
|
||||
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
|
||||
"network building phase.";
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "common/common_test.h"
|
||||
#include "ps/core/node.h"
|
||||
#include "ps/core/scheduler_node.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -31,7 +32,16 @@ class TestClusterAvailableTimeout : public UT::Common {
|
|||
};
|
||||
|
||||
TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) {
|
||||
PSContext::instance()->cluster_config().Init(1, 1, "127.0.0.1", 9999);
|
||||
std::string worker_num = "1";
|
||||
std::string server_num = "1";
|
||||
std::string host = "127.0.0.1";
|
||||
std::string port = "9999";
|
||||
common::SetEnv(kEnvWorkerNum, worker_num.c_str());
|
||||
common::SetEnv(kEnvPServerNum, server_num.c_str());
|
||||
common::SetEnv(kEnvSchedulerHost, host.c_str());
|
||||
common::SetEnv(kEnvSchedulerPort, port.c_str());
|
||||
PSContext::instance()->SetPSEnable(true);
|
||||
PSContext::instance()->cluster_config().cluster_available_timeout = 3;
|
||||
MS_LOG(INFO) << "The timeout is:" << PSContext::instance()->cluster_config().cluster_available_timeout;
|
||||
SchedulerNode node;
|
||||
}
|
||||
|
|
|
@ -18,9 +18,9 @@
|
|||
#include <string>
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -29,14 +29,21 @@ class TestClusterConfig : public UT::Common {
|
|||
public:
|
||||
TestClusterConfig() = default;
|
||||
virtual ~TestClusterConfig() = default;
|
||||
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
TEST_F(TestClusterConfig, HeartbeatInterval) {
|
||||
PSContext::instance()->cluster_config().Init(2, 2, "127.0.0.1", 8080);
|
||||
PSContext::instance()->cluster_config().heartbeat_interval = 100;
|
||||
std::string worker_num = "1";
|
||||
std::string server_num = "1";
|
||||
std::string host = "127.0.0.1";
|
||||
std::string port = "9999";
|
||||
common::SetEnv(kEnvWorkerNum, worker_num.c_str());
|
||||
common::SetEnv(kEnvPServerNum, server_num.c_str());
|
||||
common::SetEnv(kEnvSchedulerHost, host.c_str());
|
||||
common::SetEnv(kEnvSchedulerPort, port.c_str());
|
||||
PSContext::instance()->SetPSEnable(true);
|
||||
EXPECT_EQ(300, PSContext::instance()->cluster_config().cluster_available_timeout);
|
||||
}
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -52,14 +52,6 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
|
|||
EXPECT_TRUE(!ip.empty());
|
||||
}
|
||||
|
||||
TEST_F(TestCommUtil, ValidateRankId) {
|
||||
PSContext::instance()->cluster_config().Init(3, 2, "127.0.0.1", 9999);
|
||||
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2));
|
||||
EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3));
|
||||
EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1));
|
||||
EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2));
|
||||
}
|
||||
|
||||
TEST_F(TestCommUtil, Retry) {
|
||||
bool const ret = CommUtil::Retry([]() -> bool { return false; }, 5, 100);
|
||||
EXPECT_FALSE(ret);
|
||||
|
|
Loading…
Reference in New Issue