Synchronize enter code

This commit is contained in:
chendongsheng 2021-06-26 16:58:53 +08:00
parent aab9502549
commit f9105a7616
15 changed files with 98 additions and 56 deletions

View File

@ -86,6 +86,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &mess
// Gets this node ready for a scale-out: logs the step, calls Register() with the scheduler client,
// then empties connected_nodes_ while client_mutex_ is held.
void AbstractNode::set_ready_for_scale_out() {
  MS_LOG(INFO) << "[Scale out]: begin to set ready for scale out.";
  Register(client_to_scheduler_);
  {
    // The guard lives only for the duration of the clear, which is the last action of the function.
    std::lock_guard<std::mutex> guard(client_mutex_);
    connected_nodes_.clear();
  }
}
@ -93,6 +94,7 @@ void AbstractNode::set_ready_for_scale_in() {
MS_LOG(INFO) << "[Scale in]: begin to set ready for scale in.";
if (!is_current_node_scale_in_) {
Register(client_to_scheduler_);
std::lock_guard<std::mutex> lock(client_mutex_);
connected_nodes_.clear();
}
}
@ -107,8 +109,9 @@ void AbstractNode::set_scale_out_done() {
if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
scale_out_done_message.SerializeAsString().data(), scale_out_done_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " scale_out_done timeout!";
MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " scale_out_done timeout!";
return;
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@ -125,8 +128,9 @@ void AbstractNode::set_scale_in_done() {
if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF,
scale_in_done_message.SerializeAsString().data(), scale_in_done_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " scale_in_done timeout!";
MS_LOG(WARNING) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " scale_in_done timeout!";
return;
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@ -325,10 +329,15 @@ std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum No
receive_callbacks_[std::make_pair(rank_id, rank_request_id)] = [=]() mutable {
receive_callbacks_mutex_.lock();
auto res = received_data_[std::make_pair(rank_id, rank_request_id)];
*output = res;
received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
if (*output != nullptr) {
MS_LOG(WARNING) << "The output is not empty.";
} else {
*output = res;
received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
}
receive_callbacks_mutex_.unlock();
};
}
@ -474,18 +483,22 @@ void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const
<< ", the node role:" << CommUtil::NodeRoleToString(info.node_role_) << " is alive:" << info.is_alive;
}
bool is_worker_or_server0 = heartbeat_resp_message.is_worker_or_server0();
if (current_cluster_state_ == ClusterState::CLUSTER_READY) {
is_ready_ = true;
wait_start_cond_.notify_all();
}
if (node_recovery_ == nullptr) {
MS_LOG(INFO) << "The recovery is disable.";
if (node_recovery_ == nullptr || is_worker_or_server0) {
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
MS_LOG(INFO) << "The recovery is disable.";
is_ready_ = true;
wait_start_cond_.notify_all();
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
}
} else {
MS_LOG(INFO) << "The node is support recovery, users can pull up this node to restore the cluster.";
}
}
@ -535,7 +548,8 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
node_info_.rank_id_ = send_meta_message.rank_id();
current_cluster_state_ = send_meta_message.cluster_state();
MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_
<< ", cluster state is:" << current_cluster_state_ << ", the rank id:" << node_info_.rank_id_;
<< ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_)
<< ", the rank id:" << node_info_.rank_id_;
nodes_address_.clear();
for (const auto &it : send_meta_message.servers_meta()) {
@ -556,7 +570,9 @@ void AbstractNode::ProcessSendMetadata(std::shared_ptr<TcpConnection> conn, std:
OnEventCallback(ClusterEvent::CLUSTER_SCALE_IN_DONE);
}
MS_LOG(INFO) << "The current cluster state:" << current_cluster_state_;
std::lock_guard<std::mutex> lock(client_mutex_);
connected_nodes_.clear();
MS_LOG(INFO) << "The current cluster state:" << CommUtil::ClusterStateToString(current_cluster_state_);
}
void AbstractNode::ProcessFinish(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,

View File

@ -51,7 +51,7 @@ class AbstractNode : public Node {
node_recovery_(nullptr),
scheduler_ip_(""),
scheduler_port_(0) {}
~AbstractNode() override = default;
~AbstractNode() override { is_finish_ = true; }
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
typedef void (AbstractNode::*ServerHandler)(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,

View File

@ -49,7 +49,7 @@ bool CommUtil::CheckIp(const std::string &ip) {
if (!CheckIpWithRegex(ip)) {
return false;
}
int64_t uAddr = inet_addr(ip.c_str());
uint32_t uAddr = inet_addr(ip.c_str());
if (INADDR_NONE == uAddr) {
return false;
}

View File

@ -135,7 +135,10 @@ bool HttpServer::Start(bool is_detach) {
MS_LOG(INFO) << "Start http server!";
for (size_t i = 0; i < thread_num_; i++) {
auto http_request_handler = std::make_shared<HttpRequestHandler>();
http_request_handler->Initialize(fd_, request_handlers_);
if (!http_request_handler->Initialize(fd_, request_handlers_)) {
MS_LOG(ERROR) << "Http initialize failed.";
return false;
}
http_request_handlers.push_back(http_request_handler);
auto thread = std::make_shared<std::thread>(&HttpRequestHandler::Run, http_request_handler);
if (is_detach) {
@ -147,7 +150,7 @@ bool HttpServer::Start(bool is_detach) {
}
bool HttpServer::Wait() {
for (size_t i = 0; i < thread_num_; i++) {
for (size_t i = 0; i < worker_threads_.size(); i++) {
worker_threads_[i]->join();
worker_threads_[i].reset();
}
@ -159,7 +162,7 @@ bool HttpServer::Stop() {
bool result = true;
if (!is_stop_.load()) {
for (size_t i = 0; i < thread_num_; i++) {
for (size_t i = 0; i < http_request_handlers.size(); i++) {
bool res = http_request_handlers[i]->Stop();
if (res == false) {
result = false;

View File

@ -84,7 +84,7 @@ SSL_CTX *SSLWrapper::GetSSLCtx(bool is_server) {
}
}
X509 *SSLWrapper::ReadCertFromFile(const std::string &certPath) {
// Reads one PEM-encoded X509 certificate from the file at `certPath`.
//
// Returns the parsed certificate, or nullptr when the file cannot be opened or the PEM data cannot
// be parsed. The returned object is not released here; releasing it is left to the caller.
X509 *SSLWrapper::ReadCertFromFile(const std::string &certPath) const {
  BIO *bio = BIO_new_file(certPath.c_str(), "r");
  if (bio == nullptr) {
    // The file could not be opened; do not hand a null BIO to the PEM reader.
    return nullptr;
  }
  X509 *cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
  // The BIO is only needed while parsing; free it so the file handle and buffer are not leaked.
  BIO_free(bio);
  return cert;
}
@ -94,30 +94,12 @@ X509 *SSLWrapper::ReadCertFromPerm(std::string cert) {
return PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
}
X509_CRL *SSLWrapper::ReadCrlFromFile(const std::string &crlPath) {
// Reads one PEM-encoded certificate revocation list from the file at `crlPath`.
//
// Returns the parsed CRL, or nullptr when the file cannot be opened or the PEM data cannot be
// parsed. The returned object is not released here; releasing it is left to the caller.
X509_CRL *SSLWrapper::ReadCrlFromFile(const std::string &crlPath) const {
  BIO *bio = BIO_new_file(crlPath.c_str(), "r");
  if (bio == nullptr) {
    // The file could not be opened; do not hand a null BIO to the PEM reader.
    return nullptr;
  }
  X509_CRL *crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
  // The BIO is only needed while parsing; free it so the file handle and buffer are not leaked.
  BIO_free(bio);
  return crl;
}
void SSLWrapper::InitRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath,
const std::string crlFirstFilePath, const std::string crlSecondFilePath) {
if (rootFirstCaFilePath.empty() || rootSecondCaFilePath.empty() || crlFirstFilePath.empty() ||
crlSecondFilePath.empty()) {
return;
}
rootFirstCA_ = SSLWrapper::ReadCertFromFile(rootFirstCaFilePath);
rootSecondCA_ = SSLWrapper::ReadCertFromFile(rootSecondCaFilePath);
MS_LOG(INFO) << "Root first ca serialNumber: " << X509_get_serialNumber(rootFirstCA_)->data;
MS_LOG(INFO) << "Root second ca serialNumber: " << X509_get_serialNumber(rootSecondCA_)->data;
rootFirstCrl_ = SSLWrapper::ReadCrlFromFile(crlFirstFilePath);
rootSecondCrl_ = SSLWrapper::ReadCrlFromFile(crlSecondFilePath);
MS_LOG(INFO) << "Root first crl version: " << X509_CRL_get_version(rootFirstCrl_);
MS_LOG(INFO) << "Root second crl version: " << X509_CRL_get_version(rootSecondCrl_);
}
bool SSLWrapper::VerifyCertTime(const X509 *cert) {
bool SSLWrapper::VerifyCertTime(const X509 *cert) const {
ASN1_TIME *start = X509_getm_notBefore(cert);
ASN1_TIME *end = X509_getm_notAfter(cert);

View File

@ -39,20 +39,17 @@ class SSLWrapper {
}
SSL_CTX *GetSSLCtx(bool is_server = true);
void InitRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath,
const std::string crlFirstFilePath, const std::string crlSecondFilePath);
// read certificate from file path
X509 *ReadCertFromFile(const std::string &certPath);
X509 *ReadCertFromFile(const std::string &certPath) const;
// read Certificate Revocation List from file absolute path
X509_CRL *ReadCrlFromFile(const std::string &crlPath);
X509_CRL *ReadCrlFromFile(const std::string &crlPath) const;
// read certificate from pem string
X509 *ReadCertFromPerm(std::string cert);
// verify the validity period of the certificate
bool VerifyCertTime(const X509 *cert);
bool VerifyCertTime(const X509 *cert) const;
// verify the validity of the certificate chain
bool VerifyCAChain(const std::string &keyAttestation, const std::string &equipCert, const std::string &equipCACert,

View File

@ -57,7 +57,7 @@ TaskExecutor::TaskExecutor(size_t thread_num, size_t max_task_num, size_t submit
}
notify_thread_ = std::thread([this]() {
// If there is no idle thread, wait until the working thread is available.
while (true) {
while (running_) {
{
std::unique_lock<std::mutex> lock(mtx_);
if (idle_thread_num_ > 0 && task_num_ > 0) {
@ -70,7 +70,6 @@ TaskExecutor::TaskExecutor(size_t thread_num, size_t max_task_num, size_t submit
std::this_thread::sleep_for(std::chrono::milliseconds(kSubmitTaskIntervalInMs));
}
});
notify_thread_.detach();
}
TaskExecutor::~TaskExecutor() {
@ -82,6 +81,7 @@ TaskExecutor::~TaskExecutor() {
for (auto &t : working_threads_) {
t.join();
}
notify_thread_.join();
}
} // namespace core
} // namespace ps

View File

@ -37,10 +37,19 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
--num;
if (header_index_ == kHeaderLen - 1) {
message_header_.message_proto_ = *reinterpret_cast<const Protos *>(header_);
if (message_header_.message_proto_ != Protos::RAW && message_header_.message_proto_ != Protos::FLATBUFFERS &&
message_header_.message_proto_ != Protos::PROTOBUF) {
MS_LOG(WARNING) << "The proto:" << message_header_.message_proto_ << " is illegal!";
return;
}
message_header_.message_meta_length_ =
*reinterpret_cast<const uint32_t *>(header_ + sizeof(message_header_.message_proto_));
message_header_.message_length_ = *reinterpret_cast<const size_t *>(
header_ + sizeof(message_header_.message_proto_) + sizeof(message_header_.message_meta_length_));
if (message_header_.message_length_ >= UINT32_MAX) {
MS_LOG(WARNING) << "The message len:" << message_header_.message_length_ << " is too long.";
return;
}
remaining_length_ = message_header_.message_length_;
message_buffer_ = std::make_unique<unsigned char[]>(remaining_length_);
buffer_data += (i + 1);

View File

@ -359,7 +359,7 @@ void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
struct timeval delay = {0, 0};
MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds.";
if (event_base_loopexit(base, &delay) == -1) {
MS_LOG(EXCEPTION) << "Event base loop exit failed.";
MS_LOG(ERROR) << "Event base loop exit failed.";
}
}

View File

@ -42,7 +42,7 @@ namespace mindspore {
namespace ps {
namespace core {
constexpr int kTimeoutInSeconds = 30;
constexpr int kCommTimeoutInSeconds = 3;
constexpr int kCommTimeoutInSeconds = 30;
class Node {
public:
Node()

View File

@ -244,6 +244,22 @@ void NodeManager::ResetMetadata() {
next_server_rank_id_ = -1;
}
// Reports whether a node that is not alive is a worker or the rank-0 server.
//
// Returns true when at least one entry of registered_nodes_info_ is marked not alive and is either
// a worker (any rank) or a server whose rank id is 0. Returns false otherwise, including when
// registered_nodes_info_ is empty.
bool NodeManager::IsWorkerOrServer0() {
  // Entries are taken by const reference so that inspecting them does not copy each one.
  return std::any_of(registered_nodes_info_.begin(), registered_nodes_info_.end(), [](const auto &item) {
    const auto &info = item.second;
    if (info.is_alive) {
      return false;
    }
    return info.node_role_ == NodeRole::WORKER || (info.node_role_ == NodeRole::SERVER && info.rank_id_ == 0);
  });
}
// Stores `node_num` as the manager's total node count.
void NodeManager::set_total_node_num(const int32_t &node_num) {
  total_node_num_ = node_num;
}
// Returns a const reference to the stored total node count.
const int32_t &NodeManager::total_node_num() {
  return total_node_num_;
}

View File

@ -108,6 +108,9 @@ class NodeManager {
// will re-register.
void ResetMetadata();
// Recovery currently does not support worker or server0 node downtime.
bool IsWorkerOrServer0();
private:
std::mutex node_mutex_;
std::mutex cluster_mutex_;

View File

@ -98,6 +98,7 @@ enum ClusterState {
// Response to a heartbeat. SchedulerNode::ProcessHeartbeat fills it in and sends it back;
// AbstractNode::ProcessHeartbeatResp reads it.
message HeartbeatRespMessage {
// Cluster state carried in the response.
ClusterState cluster_state = 1;
// Zero or more ServersMeta entries.
repeated ServersMeta servers_meta = 2;
// Set by the scheduler from NodeManager::IsWorkerOrServer0(): true when a registered node that is
// not alive is a worker or the server with rank id 0.
bool is_worker_or_server0 = 3;
}
message FetchServersMessage {

View File

@ -63,6 +63,8 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
*heartbeat_resp_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};
heartbeat_resp_message.set_is_worker_or_server0(node_manager_.IsWorkerOrServer0());
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
heartbeat_resp_message.ByteSizeLong());
}
@ -137,7 +139,8 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
if (node_manager_.IsAllNodesRegistered()) {
is_ready_ = true;
MS_LOG(INFO) << "All nodes is registered, scheduler send meta data to worker/server.";
MS_LOG(INFO) << "There are " << node_manager_.worker_num() << " workers and " << node_manager_.server_num()
<< " servers registered to scheduer, so the scheduler send meta data to worker/server.";
if (node_manager_.GetClusterState() == ClusterState::CLUSTER_SCALE_IN) {
auto nodes = node_manager_.nodes_info();
for (const auto &id : scale_in_node_ids_) {
@ -152,6 +155,7 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
auto client = GetOrCreateClient(kvs.second);
SendMetadata(client, kvs.second.rank_id_);
}
node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY);
wait_start_cond_.notify_all();
}
}
@ -253,7 +257,7 @@ void SchedulerNode::ProcessSendEvent(std::shared_ptr<TcpServer> server, std::sha
event_message.ParseFromArray(data, size);
std::string node_id = event_message.node_id();
uint32_t event = event_message.event();
MS_LOG(INFO) << "The scheduler process a event message from node id:" << node_id;
MS_LOG(DEBUG) << "The scheduler process a event message from node id:" << node_id;
server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
@ -354,8 +358,8 @@ void SchedulerNode::SendEvent(const std::shared_ptr<TcpClient> &client, const ui
return;
}
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is sending event resp to workers and servers!";
MS_LOG(DEBUG) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is sending event resp to workers and servers!";
}
void SchedulerNode::StartUpdateClusterStateTimer() {
@ -415,7 +419,9 @@ bool SchedulerNode::Stop() {
connected_node.second->Stop();
}
}
client_thread_->join();
if (client_thread_ != nullptr && client_thread_->joinable()) {
client_thread_->join();
}
is_ready_ = true;
}
if (PSContext::instance()->scheduler_manage_port() != 0) {
@ -426,7 +432,7 @@ bool SchedulerNode::Stop() {
return true;
}
bool SchedulerNode::Finish(const uint32_t &timeout) {
bool SchedulerNode::Finish(const uint32_t &) {
MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to finish scheduler node!";
std::unique_lock<std::mutex> lock(wait_finish_mutex_);
wait_finish_cond_.wait(lock, [&] {
@ -469,6 +475,9 @@ void SchedulerNode::ProcessScaleOut(std::shared_ptr<HttpMessageHandler> resp) {
int32_t total_worker_num = scale_worker_num + node_manager_.worker_num();
int32_t total_server_num = scale_server_num + node_manager_.server_num();
MS_LOG(INFO) << "After scale out, the total worker num:" << total_worker_num
<< ", the total server num:" << total_server_num;
node_manager_.set_worker_num(total_worker_num);
node_manager_.set_server_num(total_server_num);
node_manager_.set_total_node_num(total_worker_num + total_server_num);
@ -704,7 +713,7 @@ void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t
void SchedulerNode::StopRestfulServer() {
MS_LOG(INFO) << "Scheduler stop https server.";
http_server_->Stop();
if (restful_thread_->joinable()) {
if (restful_thread_ != nullptr && restful_thread_->joinable()) {
restful_thread_->join();
}
}

View File

@ -41,7 +41,6 @@ bool WorkerNode::Start(const uint32_t &timeout) {
}
void WorkerNode::Initialize() {
is_already_stopped_ = false;
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty, then init node by context.";
@ -61,6 +60,7 @@ void WorkerNode::Initialize() {
if (!InitClientToScheduler()) {
MS_LOG(EXCEPTION) << "Worker node connect to scheduler timeout!";
}
is_already_stopped_ = false;
MS_LOG(INFO) << "[Worker start]: 3. Worker node crete tcp client to scheduler successful!";
}
@ -109,6 +109,12 @@ bool WorkerNode::Finish(const uint32_t &timeout) {
}
MS_LOG(INFO) << "[Worker finish]: 1. Begin to finish worker node!";
is_already_finished_ = true;
if (is_already_stopped_) {
MS_LOG(INFO) << "The node is already stop.";
return true;
}
bool res = Disconnect(client_to_scheduler_, timeout);
if (res) {
MS_LOG(INFO) << "[Worker finish]: 2. Successfully finish worker node!";