Sync enterprise bug fix

This commit is contained in:
ZPaC 2021-06-28 16:14:29 +08:00
parent 86e45f6106
commit 4dc9dc06bb
14 changed files with 86 additions and 46 deletions

View File

@ -35,7 +35,7 @@ bool CommunicatorBase::SendResponse(const void *rsp_data, size_t rsp_len, std::s
}
void CommunicatorBase::Join() {
if (!running_thread_.joinable()) {
MS_LOG(WARNING) << "The running thread of communicator is not joinable.";
MS_LOG(INFO) << "The running thread of communicator is already joined.";
return;
}
running_thread_.join();

View File

@ -32,8 +32,7 @@ namespace ps {
namespace core {
class HttpCommunicator : public CommunicatorBase {
public:
explicit HttpCommunicator(const std::string &ip, std::int16_t port,
const std::shared_ptr<TaskExecutor> &task_executor)
explicit HttpCommunicator(const std::string &ip, uint16_t port, const std::shared_ptr<TaskExecutor> &task_executor)
: task_executor_(task_executor), http_server_(nullptr), ip_(ip), port_(port) {
http_server_ = std::make_shared<HttpServer>(ip_, port_, kThreadNum);
}
@ -50,7 +49,7 @@ class HttpCommunicator : public CommunicatorBase {
std::unordered_map<std::string, HttpMsgCallback> http_msg_callbacks_;
std::string ip_;
std::int16_t port_;
uint16_t port_;
};
} // namespace core
} // namespace ps

View File

@ -20,45 +20,68 @@
namespace mindspore {
namespace ps {
namespace core {
FollowerScaler::FollowerScaler(AbstractNode *node) : node_(node), scaling_state_(NodeScaleState::kNormal) {
FollowerScaler::FollowerScaler(AbstractNode *node)
: node_(node), scaling_state_(NodeScaleState::kNormal), running_(true) {
process_before_scale_out_thread_ = std::thread([&]() {
while (true) {
while (running_.load()) {
std::unique_lock<std::mutex> lock(scale_out_mtx_);
scale_out_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kPreparing; });
scale_out_cv_.wait(
lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kPreparing; });
if (!running_.load()) {
break;
}
ProcessBeforeScaleOut();
}
});
process_before_scale_out_thread_.detach();
process_before_scale_in_thread_ = std::thread([&]() {
while (true) {
while (running_.load()) {
std::unique_lock<std::mutex> lock(scale_in_mtx_);
scale_in_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kPreparing; });
scale_in_cv_.wait(
lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kPreparing; });
// In the scale-in scenario, the abstract node triggers the CLUSTER_SCALE_IN_DONE event in the same thread if this
// node is the one being scaled in, so the lock must be released here to avoid a deadlock.
lock.unlock();
if (!running_.load()) {
break;
}
ProcessBeforeScaleIn();
}
});
process_before_scale_in_thread_.detach();
process_after_scale_out_thread_ = std::thread([&]() {
while (true) {
while (running_.load()) {
std::unique_lock<std::mutex> lock(scale_out_mtx_);
scale_out_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kScaling; });
scale_out_cv_.wait(
lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kScaling; });
if (!running_.load()) {
break;
}
ProcessAfterScaleOut();
}
});
process_after_scale_out_thread_.detach();
process_after_scale_in_thread_ = std::thread([&]() {
while (true) {
while (running_.load()) {
std::unique_lock<std::mutex> lock(scale_in_mtx_);
scale_in_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kScaling; });
scale_in_cv_.wait(
lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kScaling; });
if (!running_.load()) {
break;
}
ProcessAfterScaleIn();
}
});
process_after_scale_in_thread_.detach();
}
// Stops the four scaling worker threads started by the constructor and waits for them to exit.
FollowerScaler::~FollowerScaler() {
  running_ = false;
  // The workers evaluate their wait predicate while holding scale_out_mtx_ / scale_in_mtx_. Passing through each
  // mutex after clearing running_ and before notifying ensures that a worker which has already read running_ == true
  // is blocked inside wait() when the notification is sent. Without this, the wakeup can be lost and join() below
  // would block forever.
  { std::unique_lock<std::mutex> lock(scale_out_mtx_); }
  scale_out_cv_.notify_all();
  { std::unique_lock<std::mutex> lock(scale_in_mtx_); }
  scale_in_cv_.notify_all();

  // Only join threads that are still joinable.
  if (process_before_scale_out_thread_.joinable()) {
    process_before_scale_out_thread_.join();
  }
  if (process_before_scale_in_thread_.joinable()) {
    process_before_scale_in_thread_.join();
  }
  if (process_after_scale_out_thread_.joinable()) {
    process_after_scale_out_thread_.join();
  }
  if (process_after_scale_in_thread_.joinable()) {
    process_after_scale_in_thread_.join();
  }
}
void FollowerScaler::RegisterScaleEventCallbacks() {

View File

@ -56,7 +56,7 @@ enum class NodeScaleState {
class FollowerScaler {
public:
explicit FollowerScaler(AbstractNode *node);
~FollowerScaler() = default;
~FollowerScaler();
// The methods called after the events READY_FOR_SCALE_OUT/READY_FOR_SCALE_IN are triggered.
void ProcessBeforeScaleOut();
@ -82,6 +82,7 @@ class FollowerScaler {
// Callbacks for scaling events should not be blocked so we notify a thread to call
// barriers(barriers_before_scale_out_/barriers_before_scale_in_) or
// handlers(handlers_after_scale_out_/handlers_after_scale_in_).
std::atomic_bool running_;
std::thread process_before_scale_out_thread_;
std::thread process_before_scale_in_thread_;
std::thread process_after_scale_out_thread_;

View File

@ -135,7 +135,7 @@ void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn,
server_->SendMessage(conn, meta, Protos::RAW, data, size);
}
std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::string &ip, std::int16_t port,
std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::string &ip, uint16_t port,
const std::shared_ptr<TaskExecutor> &task_executor) {
std::lock_guard<std::mutex> lock(communicator_mutex_);
if (!communicators_.count(kHttpCommunicator)) {

View File

@ -57,7 +57,7 @@ class ServerNode : public AbstractNode {
void set_handler(const RequestHandler &handler);
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, std::int16_t port,
std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, uint16_t port,
const std::shared_ptr<TaskExecutor> &task_executor);
std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port,
uint32_t worker_num, uint32_t server_num,

View File

@ -53,7 +53,7 @@ void OptimizerInfo::UpdateOptimInputValue(const std::string &optim_type, const s
size_t origin_index = origin_input_map.at(input_name);
size_t ps_send_index = ps_send_index_map.at(input_name);
if (ps_send_index > lens.size() || origin_index > inputs_.size()) {
if (ps_send_index >= lens.size() || origin_index >= inputs_.size()) {
MS_LOG(EXCEPTION) << "Index is out of bound for optimizer " << optim_type << ", origin_index:" << origin_index
<< ", ps_send_index:" << ps_send_index;
}
@ -96,6 +96,7 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
void DenseOptimInfo::ComputeMean(const std::vector<std::vector<size_t>> &, size_t n, size_t, size_t) {
if (n > 1) {
float *accum_grad_data = reinterpret_cast<float *>(gradient()->addr);
MS_EXCEPTION_IF_NULL(accum_grad_data);
size_t size = gradient()->size / sizeof(float);
for (size_t i = 0; i < size; i++) {
accum_grad_data[i] /= n;
@ -136,7 +137,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
grads_offset_ += lengths[grad_index];
grads_offset_ += IntToSize(lengths[grad_index]);
gradient()->size += incr_grad_size;
// Append indice data to the end

View File

@ -178,6 +178,7 @@ void PSContext::set_server_mode(const std::string &server_mode) {
<< " or " << kServerModeHybrid;
return;
}
MS_LOG(INFO) << "Server mode: " << server_mode_ << " is used for Server and Worker. Scheduler will ignore it.";
server_mode_ = server_mode;
}
@ -198,7 +199,11 @@ void PSContext::set_ms_role(const std::string &role) {
void PSContext::set_worker_num(uint32_t worker_num) {
// Hybrid training mode only supports one worker for now.
if (server_mode_ == kServerModeHybrid && worker_num != 1) {
MS_LOG(EXCEPTION) << "The worker number should be set to 1 for now in hybrid training mode.";
MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
return;
}
if (server_mode_ == kServerModeFL && worker_num != 0) {
MS_LOG(EXCEPTION) << "The worker number should be 0 in federated learning mode.";
return;
}
worker_num_ = worker_num;

View File

@ -36,7 +36,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
}
uint32_t rank_size = server_num_;
uint32_t local_rank_ = server_node_->rank_id();
size_t chunk_size = count / rank_size;
size_t remainder_size = count % rank_size;
std::vector<size_t> chunk_sizes(rank_size, chunk_size);
@ -129,7 +128,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
template <typename T>
bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
uint32_t rank_size = server_num_;
uint32_t local_rank_ = server_node_->rank_id();
MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_
<< ", count:" << count;
int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));

View File

@ -29,7 +29,7 @@ class Server;
void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
communicator_->RegisterMsgCallBack("syncIteraion",
communicator_->RegisterMsgCallBack("syncIteration",
std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack(
"notifyLeaderToNextIter",
@ -114,7 +114,6 @@ void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &
void Iteration::SetIterationRunning() {
MS_LOG(INFO) << "Iteration " << iteration_num_ << " start running.";
iteration_state_ = IterationState::kRunning;
if (server_node_ == nullptr) {
MS_LOG(ERROR) << "Server node is empty.";
return;
@ -123,11 +122,11 @@ void Iteration::SetIterationRunning() {
// This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationRunning));
}
iteration_state_ = IterationState::kRunning;
}
void Iteration::SetIterationCompleted() {
MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes.";
iteration_state_ = IterationState::kCompleted;
if (server_node_ == nullptr) {
MS_LOG(ERROR) << "Server node is empty.";
return;
@ -136,6 +135,7 @@ void Iteration::SetIterationCompleted() {
// This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationCompleted));
}
iteration_state_ = IterationState::kCompleted;
}
void Iteration::ScalingBarrier() {
@ -147,18 +147,18 @@ void Iteration::ScalingBarrier() {
}
bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) {
if (server_rank != kLeaderServerRank) {
if (!SyncIteration(server_rank)) {
MS_LOG(ERROR) << "Synchronizing iteration failed.";
return false;
}
}
for (auto &round : rounds_) {
if (!round->ReInitForScaling(server_num)) {
MS_LOG(WARNING) << "Reinitializing round " << round->name() << " for scaling failed.";
return false;
}
}
if (server_rank != kLeaderServerRank) {
if (!SyncIteration(server_rank)) {
MS_LOG(ERROR) << "Synchronizing iteration failed.";
return false;
}
}
return true;
}
@ -171,8 +171,8 @@ bool Iteration::SyncIteration(uint32_t rank) {
sync_iter_req.set_rank(rank);
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
if (communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, core::TcpUserCommand::kSyncIteration,
&sync_iter_rsp_msg)) {
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, core::TcpUserCommand::kSyncIteration,
&sync_iter_rsp_msg)) {
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
return false;
}
@ -183,6 +183,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
SyncIterationResponse sync_iter_rsp;
sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size());
iteration_num_ = sync_iter_rsp.iteration();
MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is "
<< sync_iter_rsp.iteration();
return true;
@ -207,7 +208,7 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHa
bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
std::unique_lock<std::mutex> lock(pinned_mtx_);
if (pinned_iter_num_ >= iteration_num) {
if (pinned_iter_num_ == iteration_num) {
MS_LOG(WARNING) << "MoveToNextIteration is not reentrant. Ignore this call.";
return true;
}
@ -241,7 +242,7 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<co
notify_leader_to_next_iter_rsp.SerializeAsString().size(), message);
NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
notify_leader_to_next_iter_req.ParseFromArray(message->data(), message->len());
notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &rank = notify_leader_to_next_iter_req.rank();
const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid();
const auto &iter_num = notify_leader_to_next_iter_req.iter_num();
@ -349,7 +350,7 @@ void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageH
proceed_to_next_iter_rsp.SerializeAsString().size(), message);
MoveToNextIterRequest proceed_to_next_iter_req;
proceed_to_next_iter_req.ParseFromArray(message->data(), message->len());
proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid();
const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num();
const auto &reason = proceed_to_next_iter_req.reason();
@ -403,7 +404,7 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<core::MessageHand
}
EndLastIterRequest end_last_iter_req;
end_last_iter_req.ParseFromArray(message->data(), message->len());
end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &last_iter_num = end_last_iter_req.last_iter_num();
// If the iteration number is not matched, return error.
if (last_iter_num != iteration_num_) {

View File

@ -27,9 +27,9 @@ namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_("") {
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) {
release_thread_ = std::thread([&]() {
while (true) {
while (running_.load()) {
std::unique_lock<std::mutex> release_lock(release_mtx_);
// Detect whether there's any data needs to be released every 100 milliseconds.
if (heap_data_to_release_.empty()) {
@ -52,7 +52,13 @@ RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), e
heap_data_.erase(heap_data_.find(addr_ptr));
}
});
release_thread_.detach();
}
// Signals the release thread to leave its loop, then waits for it to finish if it is still joinable.
RoundKernel::~RoundKernel() {
  running_.store(false);
  if (!release_thread_.joinable()) {
    return;
  }
  release_thread_.join();
}
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; }

View File

@ -47,7 +47,7 @@ namespace kernel {
class RoundKernel : virtual public CPUKernel {
public:
RoundKernel();
virtual ~RoundKernel() = default;
virtual ~RoundKernel();
// RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this
// inherited method is empty.
@ -112,6 +112,7 @@ class RoundKernel : virtual public CPUKernel {
// To ensure the performance, we use another thread to release data on the heap. So the operation on the data should
// be threadsafe.
std::atomic_bool running_;
std::thread release_thread_;
// The data that needs to be released, and its mutex.

View File

@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include "ps/server/model_store.h"
#include "ps/server/iteration.h"
namespace mindspore {
@ -171,7 +172,11 @@ void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const D
return;
}
std::map<std::string, AddressPtr> feature_maps = executor_->GetModel();
size_t last_iteration = LocalMetaStore::GetInstance().curr_iter_num() - 1;
auto feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration);
if (feature_maps.empty()) {
MS_LOG(WARNING) << "The feature map for startFLJob is empty.";
}
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
feature_maps);

View File

@ -116,7 +116,7 @@ class Server {
// Which protocol should communicators use.
bool use_tcp_;
bool use_http_;
uint64_t http_port_;
uint16_t http_port_;
// The configure of all rounds.
std::vector<RoundConfig> rounds_config_;