Sync enterprise bug fix
This commit is contained in:
parent
86e45f6106
commit
4dc9dc06bb
|
@ -35,7 +35,7 @@ bool CommunicatorBase::SendResponse(const void *rsp_data, size_t rsp_len, std::s
|
|||
}
|
||||
void CommunicatorBase::Join() {
|
||||
if (!running_thread_.joinable()) {
|
||||
MS_LOG(WARNING) << "The running thread of communicator is not joinable.";
|
||||
MS_LOG(INFO) << "The running thread of communicator is already joined.";
|
||||
return;
|
||||
}
|
||||
running_thread_.join();
|
||||
|
|
|
@ -32,8 +32,7 @@ namespace ps {
|
|||
namespace core {
|
||||
class HttpCommunicator : public CommunicatorBase {
|
||||
public:
|
||||
explicit HttpCommunicator(const std::string &ip, std::int16_t port,
|
||||
const std::shared_ptr<TaskExecutor> &task_executor)
|
||||
explicit HttpCommunicator(const std::string &ip, uint16_t port, const std::shared_ptr<TaskExecutor> &task_executor)
|
||||
: task_executor_(task_executor), http_server_(nullptr), ip_(ip), port_(port) {
|
||||
http_server_ = std::make_shared<HttpServer>(ip_, port_, kThreadNum);
|
||||
}
|
||||
|
@ -50,7 +49,7 @@ class HttpCommunicator : public CommunicatorBase {
|
|||
std::unordered_map<std::string, HttpMsgCallback> http_msg_callbacks_;
|
||||
|
||||
std::string ip_;
|
||||
std::int16_t port_;
|
||||
uint16_t port_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
|
|
|
@ -20,45 +20,68 @@
|
|||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
FollowerScaler::FollowerScaler(AbstractNode *node) : node_(node), scaling_state_(NodeScaleState::kNormal) {
|
||||
FollowerScaler::FollowerScaler(AbstractNode *node)
|
||||
: node_(node), scaling_state_(NodeScaleState::kNormal), running_(true) {
|
||||
process_before_scale_out_thread_ = std::thread([&]() {
|
||||
while (true) {
|
||||
while (running_.load()) {
|
||||
std::unique_lock<std::mutex> lock(scale_out_mtx_);
|
||||
scale_out_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kPreparing; });
|
||||
scale_out_cv_.wait(
|
||||
lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kPreparing; });
|
||||
if (!running_.load()) {
|
||||
break;
|
||||
}
|
||||
ProcessBeforeScaleOut();
|
||||
}
|
||||
});
|
||||
process_before_scale_out_thread_.detach();
|
||||
|
||||
process_before_scale_in_thread_ = std::thread([&]() {
|
||||
while (true) {
|
||||
while (running_.load()) {
|
||||
std::unique_lock<std::mutex> lock(scale_in_mtx_);
|
||||
scale_in_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kPreparing; });
|
||||
scale_in_cv_.wait(
|
||||
lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kPreparing; });
|
||||
// In scaling in scenario, abstract node will trigger CLUSTER_SCALE_IN_DONE event in the same thread if this node
|
||||
// is the one to be scaled in, so we need to release the lock here to avoid dead lock.
|
||||
lock.unlock();
|
||||
if (!running_.load()) {
|
||||
break;
|
||||
}
|
||||
ProcessBeforeScaleIn();
|
||||
}
|
||||
});
|
||||
process_before_scale_in_thread_.detach();
|
||||
|
||||
process_after_scale_out_thread_ = std::thread([&]() {
|
||||
while (true) {
|
||||
while (running_.load()) {
|
||||
std::unique_lock<std::mutex> lock(scale_out_mtx_);
|
||||
scale_out_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kScaling; });
|
||||
scale_out_cv_.wait(
|
||||
lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kScaling; });
|
||||
if (!running_.load()) {
|
||||
break;
|
||||
}
|
||||
ProcessAfterScaleOut();
|
||||
}
|
||||
});
|
||||
process_after_scale_out_thread_.detach();
|
||||
|
||||
process_after_scale_in_thread_ = std::thread([&]() {
|
||||
while (true) {
|
||||
while (running_.load()) {
|
||||
std::unique_lock<std::mutex> lock(scale_in_mtx_);
|
||||
scale_in_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kScaling; });
|
||||
scale_in_cv_.wait(
|
||||
lock, [&]() -> bool { return !running_.load() || scaling_state_.load() == NodeScaleState::kScaling; });
|
||||
if (!running_.load()) {
|
||||
break;
|
||||
}
|
||||
ProcessAfterScaleIn();
|
||||
}
|
||||
});
|
||||
process_after_scale_in_thread_.detach();
|
||||
}
|
||||
|
||||
FollowerScaler::~FollowerScaler() {
|
||||
running_ = false;
|
||||
scale_out_cv_.notify_all();
|
||||
scale_in_cv_.notify_all();
|
||||
process_before_scale_out_thread_.join();
|
||||
process_before_scale_in_thread_.join();
|
||||
process_after_scale_out_thread_.join();
|
||||
process_after_scale_in_thread_.join();
|
||||
}
|
||||
|
||||
void FollowerScaler::RegisterScaleEventCallbacks() {
|
||||
|
|
|
@ -56,7 +56,7 @@ enum class NodeScaleState {
|
|||
class FollowerScaler {
|
||||
public:
|
||||
explicit FollowerScaler(AbstractNode *node);
|
||||
~FollowerScaler() = default;
|
||||
~FollowerScaler();
|
||||
|
||||
// The methods called after the events READY_FOR_SCALE_OUT/READY_FOR_SCALE_IN are triggered.
|
||||
void ProcessBeforeScaleOut();
|
||||
|
@ -82,6 +82,7 @@ class FollowerScaler {
|
|||
// Callbacks for scaling events should not be blocked so we notify a thread to call
|
||||
// barriers(barriers_before_scale_out_/barriers_before_scale_in_) or
|
||||
// handlers(handlers_after_scale_out_/handlers_after_scale_in_).
|
||||
std::atomic_bool running_;
|
||||
std::thread process_before_scale_out_thread_;
|
||||
std::thread process_before_scale_in_thread_;
|
||||
std::thread process_after_scale_out_thread_;
|
||||
|
|
|
@ -135,7 +135,7 @@ void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn,
|
|||
server_->SendMessage(conn, meta, Protos::RAW, data, size);
|
||||
}
|
||||
|
||||
std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::string &ip, std::int16_t port,
|
||||
std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::string &ip, uint16_t port,
|
||||
const std::shared_ptr<TaskExecutor> &task_executor) {
|
||||
std::lock_guard<std::mutex> lock(communicator_mutex_);
|
||||
if (!communicators_.count(kHttpCommunicator)) {
|
||||
|
|
|
@ -57,7 +57,7 @@ class ServerNode : public AbstractNode {
|
|||
void set_handler(const RequestHandler &handler);
|
||||
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
|
||||
|
||||
std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, std::int16_t port,
|
||||
std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, uint16_t port,
|
||||
const std::shared_ptr<TaskExecutor> &task_executor);
|
||||
std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port,
|
||||
uint32_t worker_num, uint32_t server_num,
|
||||
|
|
|
@ -53,7 +53,7 @@ void OptimizerInfo::UpdateOptimInputValue(const std::string &optim_type, const s
|
|||
|
||||
size_t origin_index = origin_input_map.at(input_name);
|
||||
size_t ps_send_index = ps_send_index_map.at(input_name);
|
||||
if (ps_send_index > lens.size() || origin_index > inputs_.size()) {
|
||||
if (ps_send_index >= lens.size() || origin_index >= inputs_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index is out of bound for optimizer " << optim_type << ", origin_index:" << origin_index
|
||||
<< ", ps_send_index:" << ps_send_index;
|
||||
}
|
||||
|
@ -96,6 +96,7 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
|
|||
void DenseOptimInfo::ComputeMean(const std::vector<std::vector<size_t>> &, size_t n, size_t, size_t) {
|
||||
if (n > 1) {
|
||||
float *accum_grad_data = reinterpret_cast<float *>(gradient()->addr);
|
||||
MS_EXCEPTION_IF_NULL(accum_grad_data);
|
||||
size_t size = gradient()->size / sizeof(float);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
accum_grad_data[i] /= n;
|
||||
|
@ -136,7 +137,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
|
|||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
return;
|
||||
}
|
||||
grads_offset_ += lengths[grad_index];
|
||||
grads_offset_ += IntToSize(lengths[grad_index]);
|
||||
gradient()->size += incr_grad_size;
|
||||
|
||||
// Append indice data to the end
|
||||
|
|
|
@ -178,6 +178,7 @@ void PSContext::set_server_mode(const std::string &server_mode) {
|
|||
<< " or " << kServerModeHybrid;
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Server mode: " << server_mode_ << " is used for Server and Worker. Scheduler will ignore it.";
|
||||
server_mode_ = server_mode;
|
||||
}
|
||||
|
||||
|
@ -198,7 +199,11 @@ void PSContext::set_ms_role(const std::string &role) {
|
|||
void PSContext::set_worker_num(uint32_t worker_num) {
|
||||
// Hybrid training mode only supports one worker for now.
|
||||
if (server_mode_ == kServerModeHybrid && worker_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "The worker number should be set to 1 for now in hybrid training mode.";
|
||||
MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
|
||||
return;
|
||||
}
|
||||
if (server_mode_ == kServerModeFL && worker_num != 0) {
|
||||
MS_LOG(EXCEPTION) << "The worker number should be 0 in federated learning mode.";
|
||||
return;
|
||||
}
|
||||
worker_num_ = worker_num;
|
||||
|
|
|
@ -36,7 +36,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
|||
}
|
||||
|
||||
uint32_t rank_size = server_num_;
|
||||
uint32_t local_rank_ = server_node_->rank_id();
|
||||
size_t chunk_size = count / rank_size;
|
||||
size_t remainder_size = count % rank_size;
|
||||
std::vector<size_t> chunk_sizes(rank_size, chunk_size);
|
||||
|
@ -129,7 +128,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
|||
template <typename T>
|
||||
bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) {
|
||||
uint32_t rank_size = server_num_;
|
||||
uint32_t local_rank_ = server_node_->rank_id();
|
||||
MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_
|
||||
<< ", count:" << count;
|
||||
int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T));
|
||||
|
|
|
@ -29,7 +29,7 @@ class Server;
|
|||
void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
|
||||
MS_EXCEPTION_IF_NULL(communicator);
|
||||
communicator_ = communicator;
|
||||
communicator_->RegisterMsgCallBack("syncIteraion",
|
||||
communicator_->RegisterMsgCallBack("syncIteration",
|
||||
std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
|
||||
communicator_->RegisterMsgCallBack(
|
||||
"notifyLeaderToNextIter",
|
||||
|
@ -114,7 +114,6 @@ void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &
|
|||
|
||||
void Iteration::SetIterationRunning() {
|
||||
MS_LOG(INFO) << "Iteration " << iteration_num_ << " start running.";
|
||||
iteration_state_ = IterationState::kRunning;
|
||||
if (server_node_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Server node is empty.";
|
||||
return;
|
||||
|
@ -123,11 +122,11 @@ void Iteration::SetIterationRunning() {
|
|||
// This event helps worker/server to be consistent in iteration state.
|
||||
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationRunning));
|
||||
}
|
||||
iteration_state_ = IterationState::kRunning;
|
||||
}
|
||||
|
||||
void Iteration::SetIterationCompleted() {
|
||||
MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes.";
|
||||
iteration_state_ = IterationState::kCompleted;
|
||||
if (server_node_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Server node is empty.";
|
||||
return;
|
||||
|
@ -136,6 +135,7 @@ void Iteration::SetIterationCompleted() {
|
|||
// This event helps worker/server to be consistent in iteration state.
|
||||
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationCompleted));
|
||||
}
|
||||
iteration_state_ = IterationState::kCompleted;
|
||||
}
|
||||
|
||||
void Iteration::ScalingBarrier() {
|
||||
|
@ -147,18 +147,18 @@ void Iteration::ScalingBarrier() {
|
|||
}
|
||||
|
||||
bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) {
|
||||
if (server_rank != kLeaderServerRank) {
|
||||
if (!SyncIteration(server_rank)) {
|
||||
MS_LOG(ERROR) << "Synchronizing iteration failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (auto &round : rounds_) {
|
||||
if (!round->ReInitForScaling(server_num)) {
|
||||
MS_LOG(WARNING) << "Reinitializing round " << round->name() << " for scaling failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (server_rank != kLeaderServerRank) {
|
||||
if (!SyncIteration(server_rank)) {
|
||||
MS_LOG(ERROR) << "Synchronizing iteration failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -171,8 +171,8 @@ bool Iteration::SyncIteration(uint32_t rank) {
|
|||
sync_iter_req.set_rank(rank);
|
||||
|
||||
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
|
||||
if (communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, core::TcpUserCommand::kSyncIteration,
|
||||
&sync_iter_rsp_msg)) {
|
||||
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, core::TcpUserCommand::kSyncIteration,
|
||||
&sync_iter_rsp_msg)) {
|
||||
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -183,6 +183,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
|
|||
|
||||
SyncIterationResponse sync_iter_rsp;
|
||||
sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size());
|
||||
iteration_num_ = sync_iter_rsp.iteration();
|
||||
MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is "
|
||||
<< sync_iter_rsp.iteration();
|
||||
return true;
|
||||
|
@ -207,7 +208,7 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHa
|
|||
|
||||
bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
|
||||
std::unique_lock<std::mutex> lock(pinned_mtx_);
|
||||
if (pinned_iter_num_ >= iteration_num) {
|
||||
if (pinned_iter_num_ == iteration_num) {
|
||||
MS_LOG(WARNING) << "MoveToNextIteration is not reentrant. Ignore this call.";
|
||||
return true;
|
||||
}
|
||||
|
@ -241,7 +242,7 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<co
|
|||
notify_leader_to_next_iter_rsp.SerializeAsString().size(), message);
|
||||
|
||||
NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
|
||||
notify_leader_to_next_iter_req.ParseFromArray(message->data(), message->len());
|
||||
notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
const auto &rank = notify_leader_to_next_iter_req.rank();
|
||||
const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid();
|
||||
const auto &iter_num = notify_leader_to_next_iter_req.iter_num();
|
||||
|
@ -349,7 +350,7 @@ void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageH
|
|||
proceed_to_next_iter_rsp.SerializeAsString().size(), message);
|
||||
|
||||
MoveToNextIterRequest proceed_to_next_iter_req;
|
||||
proceed_to_next_iter_req.ParseFromArray(message->data(), message->len());
|
||||
proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid();
|
||||
const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num();
|
||||
const auto &reason = proceed_to_next_iter_req.reason();
|
||||
|
@ -403,7 +404,7 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<core::MessageHand
|
|||
}
|
||||
|
||||
EndLastIterRequest end_last_iter_req;
|
||||
end_last_iter_req.ParseFromArray(message->data(), message->len());
|
||||
end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
const auto &last_iter_num = end_last_iter_req.last_iter_num();
|
||||
// If the iteration number is not matched, return error.
|
||||
if (last_iter_num != iteration_num_) {
|
||||
|
|
|
@ -27,9 +27,9 @@ namespace mindspore {
|
|||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_("") {
|
||||
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) {
|
||||
release_thread_ = std::thread([&]() {
|
||||
while (true) {
|
||||
while (running_.load()) {
|
||||
std::unique_lock<std::mutex> release_lock(release_mtx_);
|
||||
// Detect whether there's any data needs to be released every 100 milliseconds.
|
||||
if (heap_data_to_release_.empty()) {
|
||||
|
@ -52,7 +52,13 @@ RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), e
|
|||
heap_data_.erase(heap_data_.find(addr_ptr));
|
||||
}
|
||||
});
|
||||
release_thread_.detach();
|
||||
}
|
||||
|
||||
RoundKernel::~RoundKernel() {
|
||||
running_ = false;
|
||||
if (release_thread_.joinable()) {
|
||||
release_thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; }
|
||||
|
|
|
@ -47,7 +47,7 @@ namespace kernel {
|
|||
class RoundKernel : virtual public CPUKernel {
|
||||
public:
|
||||
RoundKernel();
|
||||
virtual ~RoundKernel() = default;
|
||||
virtual ~RoundKernel();
|
||||
|
||||
// RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this
|
||||
// inherited method is empty.
|
||||
|
@ -112,6 +112,7 @@ class RoundKernel : virtual public CPUKernel {
|
|||
|
||||
// To ensure the performance, we use another thread to release data on the heap. So the operation on the data should
|
||||
// be threadsafe.
|
||||
std::atomic_bool running_;
|
||||
std::thread release_thread_;
|
||||
|
||||
// Data needs to be released and its mutex;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ps/server/model_store.h"
|
||||
#include "ps/server/iteration.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -171,7 +172,11 @@ void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const D
|
|||
return;
|
||||
}
|
||||
|
||||
std::map<std::string, AddressPtr> feature_maps = executor_->GetModel();
|
||||
size_t last_iteration = LocalMetaStore::GetInstance().curr_iter_num() - 1;
|
||||
auto feature_maps = ModelStore::GetInstance().GetModelByIterNum(last_iteration);
|
||||
if (feature_maps.empty()) {
|
||||
MS_LOG(WARNING) << "The feature map for startFLJob is empty.";
|
||||
}
|
||||
BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
|
||||
feature_maps);
|
||||
|
|
|
@ -116,7 +116,7 @@ class Server {
|
|||
// Which protocol should communicators use.
|
||||
bool use_tcp_;
|
||||
bool use_http_;
|
||||
uint64_t http_port_;
|
||||
uint16_t http_port_;
|
||||
|
||||
// The configure of all rounds.
|
||||
std::vector<RoundConfig> rounds_config_;
|
||||
|
|
Loading…
Reference in New Issue