forked from mindspore-Ecosystem/mindspore
!17504 Add follower scaler
From: @zpac Reviewed-by: @cristoval Signed-off-by:
commit d6658ac241
@@ -32,6 +32,7 @@ if(NOT ENABLE_CPU OR WIN32)
    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_request_handler.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/ssl_wrapper.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "core/leader_scaler.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "core/follower_scaler.cc")
    list(REMOVE_ITEM _PS_SRC_FILES "core/file_configuration.cc")
endif()
@@ -88,6 +88,16 @@ constexpr int64_t kRetryIntervalInMs = 10;

constexpr int64_t kThreadNum = 32;

// The barrier function which should be called before doing scaling out/in operations.
// It's easier to scale nodes out/in and keep consistency after one iteration is completed.
using BarrierBeforeScaleOut = std::function<void(void)>;
using BarrierBeforeScaleIn = std::function<void(void)>;

// These handlers help the worker/server node reinitialize or recover data after the scheduler's scaling out/in
// operation is done.
using HandlerAfterScaleOut = std::function<void(void)>;
using HandlerAfterScaleIn = std::function<void(void)>;

using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = uint64_t;
@@ -326,6 +326,37 @@ bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, cons
  return res;
}

bool AbstractNode::InitFollowerScaler() {
  follower_scaler_ = std::make_unique<FollowerScaler>(this);
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterScaleEventCallbacks();
  return true;
}

void AbstractNode::RegisterFollowerScalerBarrierBeforeScaleOut(const std::string &module,
                                                               const BarrierBeforeScaleOut &barrier) {
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterBarrierBeforeScaleOut(module, barrier);
}

void AbstractNode::RegisterFollowerScalerBarrierBeforeScaleIn(const std::string &module,
                                                              const BarrierBeforeScaleIn &barrier) {
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterBarrierBeforeScaleIn(module, barrier);
}

void AbstractNode::RegisterFollowerScalerHandlerAfterScaleOut(const std::string &module,
                                                              const HandlerAfterScaleOut &handler) {
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterHandlerAfterScaleOut(module, handler);
}

void AbstractNode::RegisterFollowerScalerHandlerAfterScaleIn(const std::string &module,
                                                             const HandlerAfterScaleIn &handler) {
  MS_EXCEPTION_IF_NULL(follower_scaler_);
  follower_scaler_->RegisterHandlerAfterScaleIn(module, handler);
}

int32_t AbstractNode::worker_num() const { return worker_num_; }

int32_t AbstractNode::server_num() const { return server_num_; }
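For context, the registration surface added above only expects plain void() callables. The following is a minimal sketch of how a follower module could be wired up through AbstractNode; the module name and lambda bodies are illustrative only, while the real server registers its "ServerPipeline" module exactly this way in the server.cc hunk further down.

// Hedged sketch: assumes an already-constructed AbstractNode* owned by this server process.
#include "ps/core/abstract_node.h"

void WireUpScalingHooks(mindspore::ps::core::AbstractNode *node) {
  // Creates the FollowerScaler and hooks the four cluster scaling events.
  node->InitFollowerScaler();
  // Barrier: block the upcoming scale-out until this module reaches a safe point.
  node->RegisterFollowerScalerBarrierBeforeScaleOut("MyModule", []() {
    // e.g. wait until the current iteration is finished (compare Iteration::ScalingBarrier below).
  });
  // Handler: recover module state once the scheduler reports that scaling out is done.
  node->RegisterFollowerScalerHandlerAfterScaleOut("MyModule", []() {
    // e.g. re-register counters or rebuild communication groups.
  });
}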
@@ -26,12 +26,14 @@

#include "ps/core/node.h"
#include "ps/core/communicator/message.h"
#include "ps/core/follower_scaler.h"
#include "utils/ms_exception.h"
#include "ps/constants.h"

namespace mindspore {
namespace ps {
namespace core {
class FollowerScaler;
class AbstractNode : public Node {
 public:
  AbstractNode()

@@ -84,6 +86,17 @@ class AbstractNode : public Node {
                      VectorPtr *output);
  bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);

  // Initialize the scaler for server to process before/after scaling operations.
  bool InitFollowerScaler();

  // Register barriers before scaling operations for server.
  void RegisterFollowerScalerBarrierBeforeScaleOut(const std::string &module, const BarrierBeforeScaleOut &barrier);
  void RegisterFollowerScalerBarrierBeforeScaleIn(const std::string &module, const BarrierBeforeScaleIn &barrier);

  // Register handlers after scaling operations for server.
  void RegisterFollowerScalerHandlerAfterScaleOut(const std::string &module, const HandlerAfterScaleOut &handler);
  void RegisterFollowerScalerHandlerAfterScaleIn(const std::string &module, const HandlerAfterScaleIn &handler);

  int32_t worker_num() const;
  int32_t server_num() const;

@@ -187,6 +200,9 @@ class AbstractNode : public Node {

  // Each ClusterEvent corresponds to an EventCallback to process the event.
  std::map<ClusterEvent, EventCallback> event_to_callback_;

  // Scaler for worker/server node.
  std::unique_ptr<FollowerScaler> follower_scaler_;
};
}  // namespace core
}  // namespace ps
@@ -0,0 +1,150 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/core/follower_scaler.h"
#include "ps/core/communicator/tcp_communicator.h"

namespace mindspore {
namespace ps {
namespace core {
FollowerScaler::FollowerScaler(AbstractNode *node) : node_(node), scaling_state_(NodeScaleState::kNormal) {
  process_before_scale_out_thread_ = std::thread([&]() {
    while (true) {
      std::unique_lock<std::mutex> lock(scale_out_mtx_);
      scale_out_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kPreparing; });
      ProcessBeforeScaleOut();
    }
  });
  process_before_scale_in_thread_ = std::thread([&]() {
    while (true) {
      std::unique_lock<std::mutex> lock(scale_in_mtx_);
      scale_in_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kPreparing; });
      // In the scale-in scenario, the abstract node triggers the CLUSTER_SCALE_IN_DONE event in the same thread if
      // this node is the one to be scaled in, so we need to release the lock here to avoid deadlock.
      lock.unlock();
      ProcessBeforeScaleIn();
    }
  });

  process_after_scale_out_thread_ = std::thread([&]() {
    while (true) {
      std::unique_lock<std::mutex> lock(scale_out_mtx_);
      scale_out_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kScaling; });
      ProcessAfterScaleOut();
    }
  });
  process_after_scale_in_thread_ = std::thread([&]() {
    while (true) {
      std::unique_lock<std::mutex> lock(scale_in_mtx_);
      scale_in_cv_.wait(lock, [&]() -> bool { return scaling_state_.load() == NodeScaleState::kScaling; });
      ProcessAfterScaleIn();
    }
  });
}

void FollowerScaler::RegisterScaleEventCallbacks() {
  ready_for_scale_out_event_callback_ = [&]() -> void {
    // Notify the thread which will call the barriers.
    std::unique_lock<std::mutex> lock(scale_out_mtx_);
    scaling_state_ = NodeScaleState::kPreparing;
    scale_out_cv_.notify_all();
  };

  ready_for_scale_in_event_callback_ = [&]() -> void {
    std::unique_lock<std::mutex> lock(scale_in_mtx_);
    scaling_state_ = NodeScaleState::kPreparing;
    scale_in_cv_.notify_all();
  };

  scale_out_done_event_callback_ = [&]() -> void {
    std::unique_lock<std::mutex> lock(scale_out_mtx_);
    scaling_state_ = NodeScaleState::kScaling;
    scale_out_cv_.notify_all();
  };

  scale_in_done_event_callback_ = [&]() -> void {
    std::unique_lock<std::mutex> lock(scale_in_mtx_);
    scaling_state_ = NodeScaleState::kScaling;
    scale_in_cv_.notify_all();
  };

  MS_EXCEPTION_IF_NULL(node_);
  node_->RegisterEventCallback(core::ClusterEvent::READY_FOR_SCALE_OUT, ready_for_scale_out_event_callback_);
  node_->RegisterEventCallback(core::ClusterEvent::READY_FOR_SCALE_IN, ready_for_scale_in_event_callback_);
  node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_SCALE_OUT_DONE, scale_out_done_event_callback_);
  node_->RegisterEventCallback(core::ClusterEvent::CLUSTER_SCALE_IN_DONE, scale_in_done_event_callback_);
}

void FollowerScaler::ProcessBeforeScaleOut() {
  for (auto &barrier : barriers_before_scale_out_) {
    MS_LOG(INFO) << "Calling barrier before scaling out for " << barrier.first;
    barrier.second();
  }
  scaling_state_ = NodeScaleState::kWaiting;
  // Notify scheduler that this node is ready for elastic scaling out.
  node_->set_ready_for_scale_out();
}

void FollowerScaler::ProcessBeforeScaleIn() {
  for (auto &barrier : barriers_before_scale_in_) {
    MS_LOG(INFO) << "Calling barrier before scaling in for " << barrier.first;
    barrier.second();
  }
  scaling_state_ = NodeScaleState::kWaiting;
  // Notify scheduler that this node is ready for elastic scaling in.
  node_->set_ready_for_scale_in();
}

void FollowerScaler::ProcessAfterScaleOut() {
  MS_LOG(INFO) << "Scaling out operation in scheduler is done. Do scaling out for this node.";
  for (auto &handler : handlers_after_scale_out_) {
    MS_LOG(INFO) << "Calling scaling out handler for " << handler.first;
    handler.second();
  }
  scaling_state_ = NodeScaleState::kNormal;
  // Notify scheduler that scaling out of this node is done.
  node_->set_scale_out_done();
}

void FollowerScaler::ProcessAfterScaleIn() {
  MS_LOG(INFO) << "Scaling in operation in scheduler is done. Do scaling in for this node.";
  for (auto &handler : handlers_after_scale_in_) {
    MS_LOG(INFO) << "Calling scaling in handler for " << handler.first;
    handler.second();
  }
  scaling_state_ = NodeScaleState::kNormal;
  // Notify scheduler that scaling in of this node is done.
  node_->set_scale_in_done();
}

void FollowerScaler::RegisterBarrierBeforeScaleOut(const std::string &module, const BarrierBeforeScaleOut &barrier) {
  barriers_before_scale_out_.try_emplace(module, barrier);
}

void FollowerScaler::RegisterBarrierBeforeScaleIn(const std::string &module, const BarrierBeforeScaleIn &barrier) {
  barriers_before_scale_in_.try_emplace(module, barrier);
}

void FollowerScaler::RegisterHandlerAfterScaleOut(const std::string &module, const HandlerAfterScaleOut &handler) {
  handlers_after_scale_out_.try_emplace(module, handler);
}

void FollowerScaler::RegisterHandlerAfterScaleIn(const std::string &module, const HandlerAfterScaleIn &handler) {
  handlers_after_scale_in_.try_emplace(module, handler);
}
}  // namespace core
}  // namespace ps
}  // namespace mindspore
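The constructor above deliberately keeps the scheduler-triggered callbacks non-blocking: they only flip scaling_state_ and notify a condition variable, while the dedicated threads run the potentially slow barriers and handlers. The following stripped-down, self-contained sketch shows that handoff with a single state transition; the names are shortened and purely illustrative.

#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

enum class State { kNormal, kPreparing };

int main() {
  std::mutex mtx;
  std::condition_variable cv;
  std::atomic<State> state{State::kNormal};

  // Worker thread: blocks until the state machine says "prepare", then runs the slow barrier work.
  std::thread worker([&]() {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]() { return state.load() == State::kPreparing; });
    std::cout << "running barriers before scale out" << std::endl;
  });

  // Event callback: must return quickly, so it only changes the state and notifies.
  {
    std::unique_lock<std::mutex> lock(mtx);
    state = State::kPreparing;
    cv.notify_all();
  }

  worker.join();
  return 0;
}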
@@ -0,0 +1,110 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_CORE_FOLLOWER_SCALER_H_
#define MINDSPORE_CCSRC_PS_CORE_FOLLOWER_SCALER_H_

#include <map>
#include <mutex>
#include <atomic>
#include <memory>
#include <string>
#include <thread>
#include <functional>
#include <condition_variable>
#include "utils/log_adapter.h"
#include "ps/core/abstract_node.h"

namespace mindspore {
namespace ps {
namespace core {
class AbstractNode;
// Scaling state machine: kNormal->kPreparing->kWaiting->kScaling->kNormal
enum class NodeScaleState {
  // This state means the server/worker node is not involved with scaling operations.
  kNormal,
  // This state means the server/worker node is preparing for scaling. The barriers will be called when the
  // server/worker node is in this state.
  kPreparing,
  // After barriers complete, the server/worker node switches into this state. This means this node is ready for
  // scaling. When in this state, the server/worker node is in safemode.
  kWaiting,
  // The server/worker node switches to this state after the scheduler's scaling out/in operation is done.
  // When in this state, the server/worker node can't send/receive messages.
  kScaling
};

// This class helps the worker/server node elastically scale while running a training job. The scaling events
// are triggered by the scheduler and caught by the worker/server.

// Modules which are involved with elastic scaling should register handlers to this class. After the scheduler receives
// elastic scaling messages from the user or cluster manager, it triggers events and the handlers will be called so
// that every module's consistency is guaranteed.
class FollowerScaler {
 public:
  explicit FollowerScaler(AbstractNode *node);
  ~FollowerScaler() = default;

  // The methods called after the events READY_FOR_SCALE_OUT/READY_FOR_SCALE_IN are triggered.
  void ProcessBeforeScaleOut();
  void ProcessBeforeScaleIn();

  // The methods called after the events CLUSTER_SCALE_OUT_DONE/CLUSTER_SCALE_IN_DONE are triggered.
  void ProcessAfterScaleOut();
  void ProcessAfterScaleIn();

  void RegisterBarrierBeforeScaleOut(const std::string &module, const BarrierBeforeScaleOut &barrier);
  void RegisterBarrierBeforeScaleIn(const std::string &module, const BarrierBeforeScaleIn &barrier);
  void RegisterHandlerAfterScaleOut(const std::string &module, const HandlerAfterScaleOut &handler);
  void RegisterHandlerAfterScaleIn(const std::string &module, const HandlerAfterScaleIn &handler);

  // Register the scaling event callbacks to the node.
  void RegisterScaleEventCallbacks();

 private:
  AbstractNode *node_;

  std::atomic<NodeScaleState> scaling_state_;

  // Callbacks for scaling events should not be blocked so we notify a thread to call
  // barriers(barriers_before_scale_out_/barriers_before_scale_in_) or
  // handlers(handlers_after_scale_out_/handlers_after_scale_in_).
  std::thread process_before_scale_out_thread_;
  std::thread process_before_scale_in_thread_;
  std::thread process_after_scale_out_thread_;
  std::thread process_after_scale_in_thread_;

  // Variables for signals of scaling out/in operations.
  std::mutex scale_out_mtx_;
  std::mutex scale_in_mtx_;
  std::condition_variable scale_out_cv_;
  std::condition_variable scale_in_cv_;

  // Barriers and handlers for scale out/in events.
  std::map<std::string, BarrierBeforeScaleOut> barriers_before_scale_out_;
  std::map<std::string, BarrierBeforeScaleIn> barriers_before_scale_in_;
  std::map<std::string, HandlerAfterScaleOut> handlers_after_scale_out_;
  std::map<std::string, HandlerAfterScaleIn> handlers_after_scale_in_;

  std::function<void(void)> ready_for_scale_out_event_callback_;
  std::function<void(void)> ready_for_scale_in_event_callback_;
  std::function<void(void)> scale_out_done_event_callback_;
  std::function<void(void)> scale_in_done_event_callback_;
};
}  // namespace core
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_CORE_FOLLOWER_SCALER_H_
@@ -207,6 +207,20 @@ bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t c
  }
}

bool CollectiveOpsImpl::ReInitForScaling() {
  // If CollectiveOpsImpl is not initialized yet but the scaling event is triggered, do not throw exception.
  if (server_node_ == nullptr) {
    return true;
  }

  MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize ring for collective communication.";
  local_rank_ = server_node_->rank_id();
  server_num_ = server_node_->server_num();
  MS_LOG(INFO) << "After scheduler scaling out, this server's rank is " << local_rank_ << ", server number is "
               << server_num_;
  return true;
}

template bool CollectiveOpsImpl::RingAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::RingAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::RingAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);

@@ -44,6 +44,9 @@ class CollectiveOpsImpl {
  template <typename T>
  bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);

  // Reinitialize the ring for collective communication after scaling operations are done.
  bool ReInitForScaling();

 private:
  CollectiveOpsImpl() = default;
  ~CollectiveOpsImpl() = default;
@@ -26,18 +26,24 @@ void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode>
                                         uint32_t counting_server_rank) {
  server_node_ = server_node;
  MS_EXCEPTION_IF_NULL(server_node_);

  communicator_ =
    std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr));
  MS_EXCEPTION_IF_NULL(communicator_);

  local_rank_ = server_node_->rank_id();
  server_num_ = PSContext::instance()->initial_server_num();
  counting_server_rank_ = counting_server_rank;
  RegisterCallback();
  return;
}

void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
  communicator_ = communicator;
  MS_EXCEPTION_IF_NULL(communicator_);
  communicator_->RegisterMsgCallBack(
    "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack(
    "countReachThreshold",
    std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack(
    "counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1));
}

void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count,
                                              const CounterHandlers &counter_handlers) {
  if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) {

@@ -136,18 +142,23 @@ void DistributedCountService::ResetCounter(const std::string &name) {
  return;
}

void DistributedCountService::RegisterCallback() {
  if (local_rank_ == counting_server_rank_) {
    communicator_->RegisterMsgCallBack(
      "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1));
    communicator_->RegisterMsgCallBack(
      "countReachThreshold",
      std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1));
bool DistributedCountService::ReInitForScaling() {
  // If DistributedCountService is not initialized yet but the scaling event is triggered, do not throw exception.
  if (server_node_ == nullptr) {
    return true;
  }

  // The callback of first/last event must be set in both leader server and follower servers.
    communicator_->RegisterMsgCallBack(
      "counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1));
  MS_LOG(INFO) << "Cluster scaling completed. Reinitialize for distributed count service.";
  local_rank_ = server_node_->rank_id();
  server_num_ = server_node_->server_num();
  MS_LOG(INFO) << "After scheduler scaling, this server's rank is " << local_rank_ << ", server number is "
               << server_num_;

  // Clear old counter data of this server.
  global_current_count_.clear();
  global_threshold_count_.clear();
  counter_handlers_.clear();
  return true;
}

void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) {
@@ -29,6 +29,8 @@
namespace mindspore {
namespace ps {
namespace server {
constexpr uint32_t kDefaultCountingServerRank = 0;
constexpr auto kModuleDistributedCountService = "DistributedCountService";
// The callbacks for the first count and last count event.
typedef struct {
  MessageCallback first_count_handler;

@@ -54,6 +56,9 @@ class DistributedCountService {
  // Initialize counter service with the server node because communication is needed.
  void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank);

  // Register message callbacks of the counting server to handle messages sent by the other servers.
  void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);

  // Register counter to the counting server for the name with its threshold count in server cluster dimension and
  // first/last count event callbacks.
  void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers);

@@ -68,6 +73,9 @@ class DistributedCountService {
  // Reset the count of the name to 0.
  void ResetCounter(const std::string &name);

  // Reinitialize counting service after scaling operations are done.
  bool ReInitForScaling();

  // Returns the server rank because in some cases the callers use this rank as the 'id' for method
  // Count.
  uint32_t local_rank() { return local_rank_; }

@@ -78,9 +86,6 @@ class DistributedCountService {
  DistributedCountService(const DistributedCountService &) = delete;
  DistributedCountService &operator=(const DistributedCountService &) = delete;

  // Register callbacks of the counting server to handle messages sent by the other servers.
  void RegisterCallback();

  // Callback for the reporting count message from other servers. Only counting server will call this method.
  void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message);
@@ -25,16 +25,19 @@ namespace server {
void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
  server_node_ = server_node;
  MS_EXCEPTION_IF_NULL(server_node);

  communicator_ =
    std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr));
  MS_EXCEPTION_IF_NULL(communicator_);

  local_rank_ = server_node_->rank_id();
  server_num_ = PSContext::instance()->initial_server_num();

  InitHashRing();
  RegisterCallback();
  return;
}

void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
  communicator_ = communicator;
  MS_EXCEPTION_IF_NULL(communicator_);
  communicator_->RegisterMsgCallBack(
    "updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack(
    "getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1));
  return;
}

@@ -139,6 +142,24 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
  }
}

bool DistributedMetadataStore::ReInitForScaling() {
  // If DistributedMetadataStore is not initialized yet but the scaling event is triggered, do not throw exception.
  if (server_node_ == nullptr) {
    return true;
  }

  MS_LOG(INFO) << "Cluster scaling completed. Reinitialize for distributed metadata store.";
  local_rank_ = server_node_->rank_id();
  server_num_ = server_node_->server_num();
  MS_LOG(INFO) << "After scheduler scaling, this server's rank is " << local_rank_ << ", server number is "
               << server_num_;
  InitHashRing();

  // Clear old metadata.
  metadata_.clear();
  return true;
}

void DistributedMetadataStore::InitHashRing() {
  router_ = std::make_shared<ConsistentHashRing>(32);
  MS_EXCEPTION_IF_NULL(router_);

@@ -152,14 +173,6 @@ void DistributedMetadataStore::InitHashRing() {
  return;
}

void DistributedMetadataStore::RegisterCallback() {
  communicator_->RegisterMsgCallBack(
    "updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1));
  communicator_->RegisterMsgCallBack(
    "getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1));
  return;
}

void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
  if (message == nullptr) {
    MS_LOG(ERROR) << "Message is nullptr.";
@@ -29,6 +29,7 @@
namespace mindspore {
namespace ps {
namespace server {
constexpr auto kModuleDistributedMetadataStore = "DistributedMetadataStore";
// This class is used for distributed metadata storage using consistent hash. All metadata is distributed across all
// servers. The caller doesn't need to know which server stores the metadata. It only needs to know what kind
// of operations should be done to the metadata.

@@ -45,6 +46,9 @@ class DistributedMetadataStore {
  // Initialize metadata storage with the server node because communication is needed.
  void Initialize(const std::shared_ptr<core::ServerNode> &server_node);

  // Register callbacks for the server to handle update/get metadata messages from other servers.
  void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);

  // Register metadata for the name with the initial value. This method should be only called once for each name.
  void RegisterMetadata(const std::string &name, const PBMetadata &meta);

@@ -57,6 +61,9 @@ class DistributedMetadataStore {
  // Get the metadata for the name.
  PBMetadata GetMetadata(const std::string &name);

  // Reinitialize the consistent hash ring and clear metadata after scaling operations are done.
  bool ReInitForScaling();

 private:
  DistributedMetadataStore() = default;
  ~DistributedMetadataStore() = default;

@@ -66,9 +73,6 @@ class DistributedMetadataStore {
  // Initialize the consistent hash ring for distributed storage.
  void InitHashRing();

  // Register callbacks for the server to handle update/get metadata messages from other servers.
  void RegisterCallback();

  // Callback for updating metadata request sent to the server.
  void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
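The store above resolves which server owns a piece of metadata through ConsistentHashRing, which ReInitForScaling rebuilds once server_num_ changes. The real ring class lives elsewhere in ps/server; the standalone sketch below only illustrates the lookup idea. The virtual-node count of 32 mirrors the argument passed to ConsistentHashRing above, and the hash function choice is an assumption for the example.

// Hedged sketch of consistent-hash ownership lookup; not the ConsistentHashRing interface itself.
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>

class TinyHashRing {
 public:
  // Insert several virtual nodes per server so keys spread evenly around the ring.
  void AddServer(uint32_t rank, uint32_t virtual_node_num = 32) {
    for (uint32_t i = 0; i < virtual_node_num; ++i) {
      ring_[std::hash<std::string>{}("server-" + std::to_string(rank) + "-" + std::to_string(i))] = rank;
    }
  }
  // The first ring point at or after the key's hash owns the metadata; wrap around if needed.
  uint32_t FindServer(const std::string &name) const {
    auto it = ring_.lower_bound(std::hash<std::string>{}(name));
    return (it == ring_.end() ? ring_.begin() : it)->second;
  }

 private:
  std::map<size_t, uint32_t> ring_;
};

int main() {
  TinyHashRing ring;
  for (uint32_t rank = 0; rank < 3; ++rank) {
    ring.AddServer(rank);
  }
  // After scaling, rebuilding the ring with the new server count may move names to new owners.
  std::cout << "metadata owner for 'startFLJob': server " << ring.FindServer("startFLJob") << std::endl;
  return 0;
}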
@@ -41,6 +41,16 @@ void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_cou
  return;
}

bool Executor::ReInitForScaling() {
  auto result = std::find_if(param_aggrs_.begin(), param_aggrs_.end(),
                             [](auto param_aggr) { return !param_aggr.second->ReInitForScaling(); });
  if (result != param_aggrs_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregator of " << result->first << " for scaling failed.";
    return false;
  }
  return true;
}

bool Executor::initialized() const { return initialized_; }

bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {

@@ -45,6 +45,9 @@ class Executor {
  // optimizer cnode's input. So we need to initialize server executor using func_graph.
  void Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count);

  // Reinitialize parameter aggregators after scaling operations are done.
  bool ReInitForScaling();

  // Called in parameter server training mode to do Push operation.
  // For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered
  // as completed.
@@ -82,10 +82,34 @@ void Iteration::ProceedToNextIter(bool is_iteration_valid) {
  }

  is_last_iteration_valid_ = is_iteration_valid;
  iteration_state_ = IterationState::kEnd;
  LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
  MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
}

void Iteration::SetIterationRunning() {
  MS_LOG(INFO) << "Iteration " << iteration_num_ << " start running.";
  iteration_state_ = IterationState::kRunning;
}

void Iteration::ScalingBarrier() {
  MS_LOG(INFO) << "Starting Iteration scaling barrier.";
  while (iteration_state_.load() != IterationState::kEnd) {
    std::this_thread::yield();
  }
  MS_LOG(INFO) << "Ending Iteration scaling barrier.";
}

bool Iteration::ReInitForScaling() {
  for (auto &round : rounds_) {
    if (!round->ReInitForScaling()) {
      MS_LOG(ERROR) << "Reinitializing round " << round->name() << " for scaling failed.";
      return false;
    }
  }
  return true;
}

const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }

bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }

@@ -27,6 +27,13 @@
namespace mindspore {
namespace ps {
namespace server {
enum class IterationState {
  // This iteration is still in process.
  kRunning,
  // This iteration is completed and the next iteration is not started yet.
  kEnd
};

// In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of
// Rounds. Only after all the rounds are finished is this iteration considered completed.
class Iteration {

@@ -47,12 +54,22 @@ class Iteration {
  // If the timer expires, we consider this iteration as invalid.
  void ProceedToNextIter(bool is_iteration_valid);

  // Set current iteration state to running.
  void SetIterationRunning();

  // The barrier function for elastic scaling. The scaling out/in operation should be done only after this iteration is
  // completed.
  void ScalingBarrier();

  // Reinitialize rounds after scaling operations are done.
  bool ReInitForScaling();

  const std::vector<std::shared_ptr<Round>> &rounds();

  bool is_last_iteration_valid() const;

 private:
  Iteration() : iteration_num_(1), is_last_iteration_valid_(true) {
  Iteration() : iteration_state_(IterationState::kEnd), iteration_num_(1), is_last_iteration_valid_(true) {
    LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
  }
  ~Iteration() = default;

@@ -61,6 +78,9 @@ class Iteration {

  std::vector<std::shared_ptr<Round>> rounds_;

  // The iteration is either running or completed at any time.
  std::atomic<IterationState> iteration_state_;

  // Server's current iteration number.
  size_t iteration_num_;
@@ -64,6 +64,9 @@ class AggregationKernel : public CPUKernel {
    return;
  }

  // Reinitialize aggregation kernel after scaling operations are done.
  virtual bool ReInitForScaling() { return true; }

  // Setter and getter of kernels parameters information.
  void set_params_info(const ParamsInfo &params_info) { params_info_ = params_info; }
  const std::vector<std::string> &input_names() { return params_info_.inputs_names(); }
@@ -69,13 +69,13 @@ class FedAvgKernel : public AggregationKernel {
    MS_EXCEPTION_IF_NULL(weight_node);
    name_ = cnode_name + "." + weight_node->fullname_with_scope();
    MS_LOG(INFO) << "Register counter for " << name_;
    auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
    first_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) {
      std::unique_lock<std::mutex> lock(weight_mutex_);
      if (!participated_) {
        ClearWeightAndDataSize();
      }
    };
    auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
    last_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) {
      T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
      size_t weight_size = weight_addr_->size;
      S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);

@@ -94,8 +94,8 @@ class FedAvgKernel : public AggregationKernel {
      done_ = true;
      return;
    };
    DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
    GenerateReuseKernelNodeInfo();
    DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_});
    return;
  }

@@ -147,6 +147,11 @@ class FedAvgKernel : public AggregationKernel {
    return;
  }

  bool ReInitForScaling() override {
    DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_});
    return true;
  }

 private:
  void GenerateReuseKernelNodeInfo() override {
    MS_LOG(INFO) << "FedAvg reuse 'weight' of the kernel node.";

@@ -170,6 +175,9 @@ class FedAvgKernel : public AggregationKernel {
    return;
  }

  MessageCallback first_cnt_handler_;
  MessageCallback last_cnt_handler_;

  // The trainable parameter index of the kernel node which is parsed from the frontend func_graph.
  size_t cnode_weight_idx_;
@@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include "ps/server/iteration.h"

namespace mindspore {
namespace ps {

@@ -91,6 +92,8 @@ bool StartFLJobKernel::Reset() {
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
  iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
  LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
  // The first startFLJob request means a new iteration starts running.
  Iteration::GetInstance().SetIterationRunning();
}

bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) {
@@ -47,6 +47,16 @@ bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
  return true;
}

bool ParameterAggregator::ReInitForScaling() {
  auto result = std::find_if(aggregation_kernel_parameters_.begin(), aggregation_kernel_parameters_.end(),
                             [](auto aggregation_kernel) { return !aggregation_kernel.first->ReInitForScaling(); });
  if (result != aggregation_kernel_parameters_.end()) {
    MS_LOG(ERROR) << "Reinitializing aggregation kernel after scaling failed";
    return false;
  }
  return true;
}

bool ParameterAggregator::UpdateData(const std::map<std::string, Address> &new_data) {
  std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
  for (const auto &data : new_data) {

@@ -64,6 +64,9 @@ class ParameterAggregator {
  // The parameter threshold_count helps ParameterAggregator to judge the current status if it's stateful.
  bool Init(const CNodePtr &cnode, size_t threshold_count = 0);

  // Reinitialize the parameter aggregator after scaling operations are done.
  bool ReInitForScaling();

  // Update old data stored in ParameterAggregator with new data.
  // The data could have many meanings: weights, gradients, learning_rate, momentum, etc.
  bool UpdateData(const std::map<std::string, Address> &new_data);
@@ -17,10 +17,12 @@
#include "ps/server/round.h"
#include <memory>
#include <string>
#include "ps/server/server.h"

namespace mindspore {
namespace ps {
namespace server {
class Server;
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count)
    : name_(name),
      check_timeout_(check_timeout),

@@ -85,6 +87,14 @@ void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &messa
    return;
  }

  // If the server is still in the process of scaling, refuse the request.
  if (Server::GetInstance().IsSafeMode()) {
    MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
    std::string reason = "The cluster is in safemode.";
    communicator_->SendResponse(reason.c_str(), reason.size(), message);
    return;
  }

  AddressPtr input = std::make_shared<Address>();
  AddressPtr output = std::make_shared<Address>();
  input->addr = message->data();

@@ -40,6 +40,23 @@ class Round {
  void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
                  FinishIterCb finish_iteration_cb);

  // Reinitialize count service and round kernel of this round after scaling operations are done.
  bool ReInitForScaling() {
    if (check_count_) {
      auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
      auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
      DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
                                                             {first_count_handler, last_count_handler});
    }

    if (kernel_ == nullptr) {
      MS_LOG(ERROR) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
      return false;
    }
    kernel_->InitKernel(threshold_count_);
    return true;
  }

  // Bind a round kernel to this Round. This method should be called after Initialize.
  void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel);
@@ -76,10 +76,12 @@ void Server::Run() {
  InitServerContext();
  InitCluster();
  InitIteration();
  RegisterCommCallbacks();
  StartCommunicator();
  InitExecutor();
  RegisterRoundKernel();
  MS_LOG(INFO) << "Server started successfully.";
  safemode_ = false;

  // Wait communicators to stop so the main thread is blocked.
  std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),

@@ -89,6 +91,8 @@ void Server::Run() {
  return;
}

bool Server::IsSafeMode() { return safemode_.load(); }

void Server::InitServerContext() {
  PSContext::instance()->GenerateResetterRound();
  scheduler_ip_ = PSContext::instance()->scheduler_host();
@@ -122,34 +126,6 @@ bool Server::InitCommunicatorWithServer() {
  communicator_with_server_ =
    server_node_->GetOrCreateTcpComm(scheduler_ip_, scheduler_port_, worker_num_, server_num_, task_executor_);
  MS_EXCEPTION_IF_NULL(communicator_with_server_);

  // Set exception event callbacks for server.
  auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
  MS_EXCEPTION_IF_NULL(tcp_comm);

  tcp_comm->RegisterEventCallback(core::ClusterEvent::CLUSTER_TIMEOUT, [&]() {
    MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes(Scheduler/Server/Worker) are not "
                     "started during network building phase.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });

  tcp_comm->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
    MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });

  tcp_comm->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [&]() {
    MS_LOG(ERROR)
      << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
         "network building phase.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });
  return true;
}
@@ -194,13 +170,70 @@ void Server::InitIteration() {
  return;
}

void Server::RegisterCommCallbacks() {
  // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
  // rounds' callbacks.

  auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
  MS_EXCEPTION_IF_NULL(tcp_comm);

  // Set message callbacks for server-to-server communication.
  DistributedMetadataStore::GetInstance().RegisterMessageCallback(tcp_comm);
  DistributedCountService::GetInstance().RegisterMessageCallback(tcp_comm);

  // Set exception event callbacks for server.
  RegisterExceptionEventCallback(tcp_comm);

  if (!server_node_->InitFollowerScaler()) {
    MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
    return;
  }
  // Set scaling barriers before scaling.
  server_node_->RegisterFollowerScalerBarrierBeforeScaleOut("ServerPipeline",
                                                            std::bind(&Server::ProcessBeforeScalingOut, this));
  server_node_->RegisterFollowerScalerBarrierBeforeScaleIn("ServerPipeline",
                                                           std::bind(&Server::ProcessBeforeScalingIn, this));
  // Set handlers after scheduler scaling operations are done.
  server_node_->RegisterFollowerScalerHandlerAfterScaleOut("ServerPipeline",
                                                           std::bind(&Server::ProcessAfterScalingOut, this));
  server_node_->RegisterFollowerScalerHandlerAfterScaleIn("ServerPipeline",
                                                          std::bind(&Server::ProcessAfterScalingIn, this));
}

void Server::RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
  MS_EXCEPTION_IF_NULL(communicator);
  communicator->RegisterEventCallback(core::ClusterEvent::CLUSTER_TIMEOUT, [&]() {
    MS_LOG(ERROR) << "Event CLUSTER_TIMEOUT is captured. This is because some nodes(Scheduler/Server/Worker) are not "
                     "started during network building phase.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });

  communicator->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
    MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });

  communicator->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [&]() {
    MS_LOG(ERROR)
      << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
         "network building phase.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });
}

void Server::InitExecutor() {
  if (executor_threshold_ == 0) {
    MS_LOG(EXCEPTION) << "The executor's threshold should be greater than 0.";
    return;
  }
  // The train engine instance is used in both push-type and pull-type kernels,
  // so the required_cnt of these kernels must be the same as update_model_threshold_.
  // so the required_cnt of these kernels must be the same as executor_threshold_.
  MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
  Executor::GetInstance().Initialize(func_graph_, executor_threshold_);
  ModelStore::GetInstance().Initialize();
@@ -247,6 +280,79 @@ void Server::StartCommunicator() {
  std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Start(); });
}

void Server::ProcessBeforeScalingOut() {
  iteration_->ScalingBarrier();
  safemode_ = true;
}

void Server::ProcessBeforeScalingIn() {
  iteration_->ScalingBarrier();
  safemode_ = true;
}

void Server::ProcessAfterScalingOut() {
  if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
    MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
    return;
  }
  if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
    MS_LOG(ERROR) << "CollectiveOpsImpl reinitializing failed.";
    return;
  }
  if (!DistributedCountService::GetInstance().ReInitForScaling()) {
    MS_LOG(ERROR) << "DistributedCountService reinitializing failed.";
    return;
  }
  if (!iteration_->ReInitForScaling()) {
    MS_LOG(ERROR) << "Iteration reinitializing failed.";
    return;
  }
  if (!Executor::GetInstance().ReInitForScaling()) {
    MS_LOG(ERROR) << "Executor reinitializing failed.";
    return;
  }
  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  safemode_ = false;
}

void Server::ProcessAfterScalingIn() {
  if (server_node_ == nullptr) {
    return;
  }

  if (server_node_->rank_id() == UINT32_MAX) {
    MS_LOG(WARNING) << "This server is the one to be scaled in. Server exiting.";
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
    return;
  }

  // If the server is not the one to be scaled in, reinitialize modules and recover service.
  if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
    MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
    return;
  }
  if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
    MS_LOG(ERROR) << "CollectiveOpsImpl reinitializing failed.";
    return;
  }
  if (!DistributedCountService::GetInstance().ReInitForScaling()) {
    MS_LOG(ERROR) << "DistributedCountService reinitializing failed.";
    return;
  }
  if (!iteration_->ReInitForScaling()) {
    MS_LOG(ERROR) << "Iteration reinitializing failed.";
    return;
  }
  if (!Executor::GetInstance().ReInitForScaling()) {
    MS_LOG(ERROR) << "Executor reinitializing failed.";
    return;
  }
  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  safemode_ = false;
}
}  // namespace server
}  // namespace ps
}  // namespace mindspore
@@ -46,6 +46,8 @@ class Server {
  // func_graph is the frontend graph which will be parsed in server's executor and aggregator.
  void Run();

  bool IsSafeMode();

 private:
  Server()
      : server_node_(nullptr),

@@ -58,6 +60,7 @@ class Server {
        communicator_with_server_(nullptr),
        communicators_with_worker_({}),
        iteration_(nullptr),
        safemode_(true),
        scheduler_ip_(""),
        scheduler_port_(0),
        server_num_(0),

@@ -77,6 +80,13 @@ class Server {
  // Initialize iteration with rounds. Which rounds to use could be set by ps_context as well.
  void InitIteration();

  // Register all message and event callbacks for communicators(TCP and HTTP). This method must be called before
  // communicators are started.
  void RegisterCommCallbacks();

  // Register cluster exception callbacks. This method is called in RegisterCommCallbacks.
  void RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);

  // Initialize executor according to the server mode.
  void InitExecutor();

@@ -86,6 +96,14 @@ class Server {
  // The communicators should be started after all initializations are completed.
  void StartCommunicator();

  // The barriers before scaling operations.
  void ProcessBeforeScalingOut();
  void ProcessBeforeScalingIn();

  // The handlers after scheduler's scaling operations are done.
  void ProcessAfterScalingOut();
  void ProcessAfterScalingIn();

  // The server node is initialized in Server.
  std::shared_ptr<core::ServerNode> server_node_;

@@ -119,6 +137,10 @@ class Server {
  // Iteration consists of multiple kinds of rounds.
  Iteration *iteration_;

  // The flag that represents whether the server is in safemode.
  // If true, the server is not available to workers and clients.
  std::atomic_bool safemode_;

  // Variables set by ps context.
  std::string scheduler_ip_;
  uint16_t scheduler_port_;
@@ -156,10 +156,13 @@ np.random.seed(0)
while True:
    url1 = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
    print("start url is ", url1)
    x = requests.post(url1, data=build_start_fl_job(current_iteration))
    x = session.post(url1, data=build_start_fl_job(current_iteration))
    while x.text == "The cluster is in safemode.":
        x = session.post(url1, data=build_start_fl_job(current_iteration))

    rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
        x = requests.post(url1, data=build_start_fl_job(current_iteration))
        x = session.post(url1, data=build_start_fl_job(current_iteration))
        rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
    print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
    print("iteration is", rsp_fl_job.Iteration())

@@ -176,6 +179,10 @@ while True:
    url3 = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
    print("req get model iteration:", current_iteration, ", id:", args.pid)
    x = session.post(url3, data=build_get_model(current_iteration))
    while x.text == "The cluster is in safemode.":
        time.sleep(0.2)
        x = session.post(url3, data=build_get_model(current_iteration))

    rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
    print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
    sys.stdout.flush()

@@ -190,6 +197,10 @@ while True:
    while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
        time.sleep(0.2)
        x = session.post(url3, data=build_get_model(current_iteration))
        while x.text == "The cluster is in safemode.":
            time.sleep(0.2)
            x = session.post(url3, data=build_get_model(current_iteration))

        rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
    if rsp_get_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
        next_req_timestamp = int(rsp_get_model.Timestamp().decode('utf-8'))