Synchronize bug fix from enterprise branch

Merge pull request  from ZPaC/master-fix-follower-scaler-issue
This commit is contained in:
i-robot 2021-06-11 09:27:47 +08:00 committed by Gitee
commit 3f7e42e80c
25 changed files with 400 additions and 85 deletions

View File

@ -30,6 +30,8 @@
namespace mindspore {
namespace kernel {
// The duration between two downloading requests when return code is ResponseCode_SucNotReady.
constexpr int kRetryDurationOfPullWeights = 200;
template <typename T>
class FusedPullWeightKernel : public CPUKernel {
public:
@ -79,7 +81,7 @@ class FusedPullWeightKernel : public CPUKernel {
pull_weight_rsp = flatbuffers::GetRoot<schema::ResponsePullWeight>(pull_weight_rsp_msg->data());
retcode = pull_weight_rsp->retcode();
if (retcode == schema::ResponseCode_SucNotReady) {
std::this_thread::sleep_for(std::chrono::milliseconds(200));
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPullWeights));
continue;
} else if (retcode != schema::ResponseCode_SUCCEED) {
MS_LOG(EXCEPTION) << "FusedPullWeight failed. Server return code: " << pull_weight_rsp->retcode()
@ -104,8 +106,7 @@ class FusedPullWeightKernel : public CPUKernel {
return false;
}
}
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed.";
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
return true;
}

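A minimal standalone sketch of the poll-until-ready pattern used above, with the magic 200 replaced by a named constant as in the patch. All names below are illustrative, not the MindSpore API:

#include <chrono>
#include <functional>
#include <thread>

// Illustrative return codes mirroring the schema::ResponseCode semantics.
enum class RetCode { kSucNotReady, kSucceed, kError };

// A named constant instead of a magic number, as with kRetryDurationOfPullWeights.
constexpr int kRetryDurationMs = 200;

// Repeatedly issue the request until the server reports a terminal code.
bool PollUntilReady(const std::function<RetCode()> &send_request) {
  while (true) {
    RetCode code = send_request();
    if (code == RetCode::kSucNotReady) {
      // The server has not finished aggregating yet; back off briefly and retry.
      std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationMs));
      continue;
    }
    return code == RetCode::kSucceed;
  }
}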
View File

@ -63,7 +63,8 @@ class FusedPushWeightKernel : public CPUKernel {
return false;
}
for (uint32_t i = 0; i < ps::PSContext::instance()->server_num(); i++) {
// The server number may change after scaling in/out.
for (uint32_t i = 0; i < ps::worker::FLWorker::GetInstance().server_num(); i++) {
std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr;
if (!ps::worker::FLWorker::GetInstance().SendToServer(
i, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPushWeight, &push_weight_rsp_msg)) {
@ -81,8 +82,7 @@ class FusedPushWeightKernel : public CPUKernel {
return false;
}
}
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed.";
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
return true;
}

View File

@ -658,7 +658,7 @@ bool StartServerAction(const ResourcePtr &res) {
{"updateModel", true, update_model_time_window, true, update_model_threshold},
{"getModel"},
{"pullWeight"},
{"pushWeight", false, 3000, true, server_num}};
{"pushWeight", false, 3000, true, server_num, true}};
size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {

View File

@ -95,16 +95,6 @@ constexpr uint32_t kCheckRegisteredRetryCount = 30;
// The timeout interval for judging whether all nodes are successfully registered.
constexpr uint32_t kCheckRegisteredIntervalInMs = 1000;
// The barrier functions which should be called before scaling out/in operations.
// Scaling nodes out/in only after one iteration completes makes it easy to keep the cluster state consistent.
using BarrierBeforeScaleOut = std::function<void(void)>;
using BarrierBeforeScaleIn = std::function<void(void)>;
// These handlers help worker/server nodes reinitialize or recover data after the scheduler's scaling out/in
// operation is done.
using HandlerAfterScaleOut = std::function<void(void)>;
using HandlerAfterScaleIn = std::function<void(void)>;
using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = uint64_t;
@ -158,6 +148,20 @@ const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum
{kSparseLazyAdam, kSparseAdamPSSendIdx},
{kSparseFtrl, kSparseFtrlPSSendIdx}};
// The barrier functions which should be called before scaling out/in operations.
// Scaling nodes out/in only after one iteration completes makes it easy to keep the cluster state consistent.
using BarrierBeforeScaleOut = std::function<void(void)>;
using BarrierBeforeScaleIn = std::function<void(void)>;
// These handlers help worker/server nodes reinitialize or recover data after the scheduler's scaling out/in
// operation is done.
using HandlerAfterScaleOut = std::function<void(void)>;
using HandlerAfterScaleIn = std::function<void(void)>;
constexpr char kClusterSafeMode[] = "The cluster is in safemode.";
enum class CustomEvent { kIterationRunning = 0, kIterationCompleted };
#define EXC_IF_VEC_IDX_OOB(vec, idx) \
{ \
size_t vec_size = vec.size(); \

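The std::function aliases moved above define the hook points a follower node exposes around elastic scaling. A hypothetical registration sequence, with placeholder callback bodies standing in for the real worker logic:

#include <functional>
#include <iostream>

using BarrierBeforeScaleOut = std::function<void(void)>;
using HandlerAfterScaleOut = std::function<void(void)>;

int main() {
  // Barrier: block until the current iteration finishes so scaling starts from a clean state.
  BarrierBeforeScaleOut barrier = []() { std::cout << "waiting for the iteration to complete\n"; };
  // Handler: reinitialize rank- and server-number-dependent state once scaling is done.
  HandlerAfterScaleOut handler = []() { std::cout << "reinitializing after scale out\n"; };
  barrier();  // invoked before the scheduler performs the scaling operation
  handler();  // invoked after the scaling operation completes
  return 0;
}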
View File

@ -46,7 +46,8 @@ enum class TcpUserCommand {
kUpdateMetadata,
kCounterEvent,
kPullWeight,
kPushWeight
kPushWeight,
kSyncIteration
};
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
@ -58,8 +59,9 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
{TcpUserCommand::kGetMetadata, "getMetadata"},
{TcpUserCommand::kUpdateMetadata, "updateMetadata"},
{TcpUserCommand::kCounterEvent, "counterEvent"},
{TcpUserCommand::kPullWeight, "PullWeight"},
{TcpUserCommand::kPushWeight, "PushWeight"}};
{TcpUserCommand::kPullWeight, "pullWeight"},
{TcpUserCommand::kPushWeight, "pushWeight"},
{TcpUserCommand::kSyncIteration, "syncIteration"}};
class TcpCommunicator : public CommunicatorBase {
public:

View File

@ -28,6 +28,8 @@ FollowerScaler::FollowerScaler(AbstractNode *node) : node_(node), scaling_state_
ProcessBeforeScaleOut();
}
});
process_before_scale_out_thread_.detach();
process_before_scale_in_thread_ = std::thread([&]() {
while (true) {
std::unique_lock<std::mutex> lock(scale_in_mtx_);
@ -38,6 +40,7 @@ FollowerScaler::FollowerScaler(AbstractNode *node) : node_(node), scaling_state_
ProcessBeforeScaleIn();
}
});
process_before_scale_in_thread_.detach();
process_after_scale_out_thread_ = std::thread([&]() {
while (true) {
@ -46,6 +49,8 @@ FollowerScaler::FollowerScaler(AbstractNode *node) : node_(node), scaling_state_
ProcessAfterScaleOut();
}
});
process_after_scale_out_thread_.detach();
process_after_scale_in_thread_ = std::thread([&]() {
while (true) {
std::unique_lock<std::mutex> lock(scale_in_mtx_);
@ -53,6 +58,7 @@ FollowerScaler::FollowerScaler(AbstractNode *node) : node_(node), scaling_state_
ProcessAfterScaleIn();
}
});
process_after_scale_in_thread_.detach();
}
void FollowerScaler::RegisterScaleEventCallbacks() {

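The fix detaches each event-processing thread right after it is created, so destroying a FollowerScaler never aborts on an unjoined std::thread. A reduced sketch of one such detached waiter (assumed shapes, not the real class):

#include <condition_variable>
#include <mutex>
#include <thread>

class ScaleWaiter {
 public:
  ScaleWaiter() {
    worker_ = std::thread([this]() {
      while (true) {
        std::unique_lock<std::mutex> lock(mtx_);
        // Sleep until a scaling event is signaled.
        cv_.wait(lock, [this]() { return ready_; });
        ready_ = false;
        // ... process the scaling event here ...
      }
    });
    // Mirrors process_before_scale_out_thread_.detach(): the thread runs for the
    // process lifetime, so it must not keep the object joinable.
    worker_.detach();
  }
  void Signal() {
    {
      std::lock_guard<std::mutex> lock(mtx_);
      ready_ = true;
    }
    cv_.notify_one();
  }

 private:
  std::mutex mtx_;
  std::condition_variable cv_;
  bool ready_ = false;
  std::thread worker_;
};

A detached thread that captures this is only safe when the object outlives it, which the patch relies on: the scaler lives as long as the node process.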
View File

@ -153,3 +153,13 @@ message PBMetadataWithName {
string name = 1;
PBMetadata metadata = 2;
}
message SyncIterationRequest {
// The rank of the server which sends this synchronizing iteration request to the leader server.
uint32 rank = 1;
}
message SyncIterationResponse {
// The current iteration number.
uint64 iteration = 1;
}

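The new messages carry only a rank one way and an iteration number the other. A sketch of the round trip against the protoc-generated C++ classes (the header name below is hypothetical):

#include <cstdint>
#include <string>

#include "fl_iteration.pb.h"  // hypothetical header for the generated classes

// Follower side: serialize the request that names the asking server.
std::string BuildSyncRequest(uint32_t rank) {
  SyncIterationRequest req;
  req.set_rank(rank);
  return req.SerializeAsString();
}

// Follower side: recover the leader's current iteration number from the reply.
uint64_t ParseSyncResponse(const std::string &bytes) {
  SyncIterationResponse rsp;
  rsp.ParseFromString(bytes);
  return rsp.iteration();
}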
View File

@ -215,7 +215,7 @@ bool CollectiveOpsImpl::ReInitForScaling() {
MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize ring for collective communication.";
local_rank_ = server_node_->rank_id();
server_num_ = server_node_->server_num();
server_num_ = IntToUint(server_node_->server_num());
MS_LOG(INFO) << "After scheduler scaling out, this server's rank is " << local_rank_ << ", server number is "
<< server_num_;
return true;

View File

@ -45,11 +45,18 @@ enum CommType { HTTP = 0, TCP };
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
struct RoundConfig {
// The name of the round. Please refer to the round kernel *.cc files.
std::string name;
// Whether this round has a time window limit.
bool check_timeout = false;
// The length of the time window. Only used when check_timeout is set to true.
size_t time_window = 3000;
// Whether this round has to check that the request count has reached the threshold.
bool check_count = false;
// This round's request threshold count. Only used when check_count is set to true.
size_t threshold_count = 0;
// Whether this round uses the server number as its threshold count. This is vital for some rounds in the elastic
// scaling scenario.
bool server_num_as_threshold = false;
};
using mindspore::kernel::Address;

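With the new server_num_as_threshold flag, a round table like the one in StartServerAction can be built as below; a sketch with illustrative values:

#include <cstddef>
#include <string>
#include <vector>

struct RoundConfig {  // mirrors the struct above
  std::string name;
  bool check_timeout = false;
  size_t time_window = 3000;
  bool check_count = false;
  size_t threshold_count = 0;
  bool server_num_as_threshold = false;
};

std::vector<RoundConfig> MakeRounds(size_t server_num) {
  return {
      {"pullWeight"},  // no timeout, no count check
      // pushWeight must hear from every server, so its threshold tracks the live
      // server number and is re-derived after each scaling operation.
      {"pushWeight", false, 3000, true, server_num, true},
  };
}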
View File

@ -20,14 +20,13 @@ namespace mindspore {
namespace ps {
namespace server {
bool ConsistentHashRing::Insert(uint32_t rank) {
std::string physical_node_hash_key = std::to_string(rank);
for (uint32_t i = 0; i < virtual_node_num_; i++) {
physical_node_hash_key += "#" + std::to_string(i);
MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank;
std::string physical_node_hash_key = std::to_string(rank) + "#" + std::to_string(i);
size_t hash_value = std::hash<std::string>()(physical_node_hash_key);
MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank << ", hash value is "
<< hash_value;
if (ring_.count(hash_value) != 0) {
MS_LOG(WARNING) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring.";
MS_LOG(INFO) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring.";
continue;
}
ring_[hash_value] = rank;

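The fix hashes each virtual node under its own key ("rank#i") instead of appending every suffix to one growing string, so the virtual nodes spread independently over the ring. A compact model of the resulting ring, with std::map standing in for the ordered ring:

#include <cstdint>
#include <functional>
#include <map>
#include <string>

std::map<size_t, uint32_t> ring;  // hash value -> physical server rank

void Insert(uint32_t rank, uint32_t virtual_node_num) {
  for (uint32_t i = 0; i < virtual_node_num; i++) {
    // Each virtual node gets an independent key such as "3#0", "3#1", ...
    std::string key = std::to_string(rank) + "#" + std::to_string(i);
    size_t hash_value = std::hash<std::string>()(key);
    ring.emplace(hash_value, rank);  // keep the first mapping on a collision
  }
}

uint32_t Locate(const std::string &name) {
  size_t hash_value = std::hash<std::string>()(name);
  auto it = ring.lower_bound(hash_value);    // first virtual node at or after the hash
  if (it == ring.end()) it = ring.begin();   // wrap around the ring
  return it->second;
}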
View File

@ -150,7 +150,7 @@ bool DistributedCountService::ReInitForScaling() {
MS_LOG(INFO) << "Cluster scaling completed. Reinitialize for distributed count service.";
local_rank_ = server_node_->rank_id();
server_num_ = server_node_->server_num();
server_num_ = IntToUint(server_node_->server_num());
MS_LOG(INFO) << "After scheduler scaling, this server's rank is " << local_rank_ << ", server number is "
<< server_num_;

View File

@ -150,7 +150,7 @@ bool DistributedMetadataStore::ReInitForScaling() {
MS_LOG(INFO) << "Cluster scaling completed. Reinitialize for distributed metadata store.";
local_rank_ = server_node_->rank_id();
server_num_ = server_node_->server_num();
server_num_ = IntToUint(server_node_->server_num());
MS_LOG(INFO) << "After scheduler scaling, this server's rank is " << local_rank_ << ", server number is "
<< server_num_;
InitHashRing();

View File

@ -16,6 +16,7 @@
#include "ps/server/iteration.h"
#include <memory>
#include <string>
#include <vector>
#include <numeric>
#include "ps/server/model_store.h"
@ -23,6 +24,22 @@
namespace mindspore {
namespace ps {
namespace server {
void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
communicator_->RegisterMsgCallBack("syncIteraion",
std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
}
void Iteration::RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node) {
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning),
std::bind(&Iteration::HandleIterationRunningEvent, this));
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted),
std::bind(&Iteration::HandleIterationCompletedEvent, this));
}
void Iteration::AddRound(const std::shared_ptr<Round> &round) {
MS_EXCEPTION_IF_NULL(round);
rounds_.push_back(round);
@ -57,6 +74,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
void Iteration::ProceedToNextIter(bool is_iteration_valid) {
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
is_last_iteration_valid_ = is_iteration_valid;
if (is_iteration_valid) {
// Store the model which is successfully aggregated for this iteration.
const auto &model = Executor::GetInstance().GetModel();
@ -81,8 +99,7 @@ void Iteration::ProceedToNextIter(bool is_iteration_valid) {
ModelStore::GetInstance().Reset();
}
is_last_iteration_valid_ = is_iteration_valid;
iteration_state_ = IterationState::kEnd;
SetIterationCompleted();
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n";
}
@ -90,19 +107,46 @@ void Iteration::ProceedToNextIter(bool is_iteration_valid) {
void Iteration::SetIterationRunning() {
MS_LOG(INFO) << "Iteration " << iteration_num_ << " start running.";
iteration_state_ = IterationState::kRunning;
if (server_node_ == nullptr) {
MS_LOG(ERROR) << "Server node is empty.";
return;
}
if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps workers/servers stay consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationRunning));
}
}
void Iteration::SetIterationCompleted() {
MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes.";
iteration_state_ = IterationState::kCompleted;
if (server_node_ == nullptr) {
MS_LOG(ERROR) << "Server node is empty.";
return;
}
if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps workers/servers stay consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationCompleted));
}
}
void Iteration::ScalingBarrier() {
MS_LOG(INFO) << "Starting Iteration scaling barrier.";
while (iteration_state_.load() != IterationState::kEnd) {
while (iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
MS_LOG(INFO) << "Ending Iteration scaling barrier.";
}
bool Iteration::ReInitForScaling() {
bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) {
if (server_rank != kLeaderServerRank) {
if (!SyncIteration(server_rank)) {
MS_LOG(ERROR) << "Synchronizing iteration failed.";
return false;
}
}
for (auto &round : rounds_) {
if (!round->ReInitForScaling()) {
if (!round->ReInitForScaling(server_num)) {
MS_LOG(ERROR) << "Reinitializing round " << round->name() << " for scaling failed.";
return false;
}
@ -113,6 +157,41 @@ bool Iteration::ReInitForScaling() {
const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }
bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }
bool Iteration::SyncIteration(uint32_t rank) {
SyncIterationRequest sync_iter_req;
sync_iter_req.set_rank(rank);
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, core::TcpUserCommand::kSyncIteration,
&sync_iter_rsp_msg)) {
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
return false;
}
SyncIterationResponse sync_iter_rsp;
sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size());
MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is "
<< sync_iter_rsp.iteration();
return true;
}
void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}
SyncIterationRequest sync_iter_req;
sync_iter_req.ParseFromArray(message->data(), message->len());
uint32_t rank = sync_iter_req.rank();
MS_LOG(INFO) << "Synchronizing iteration request from rank " << rank;
SyncIterationResponse sync_iter_rsp;
sync_iter_rsp.set_iteration(iteration_num_);
std::string sync_iter_rsp_msg = sync_iter_rsp.SerializeAsString();
communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message);
}
} // namespace server
} // namespace ps
} // namespace mindspore

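ScalingBarrier spins on the atomic iteration state until the current iteration completes, which is also why kEnd was renamed to kCompleted to match the new worker-side state. The pattern in isolation (a sketch):

#include <atomic>
#include <thread>

enum class IterationState { kRunning, kCompleted };

std::atomic<IterationState> iteration_state{IterationState::kCompleted};

// Block the calling thread until the current iteration finishes.
// yield() keeps the spin cooperative instead of monopolizing a core.
void ScalingBarrier() {
  while (iteration_state.load() != IterationState::kCompleted) {
    std::this_thread::yield();
  }
}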
View File

@ -31,7 +31,7 @@ enum class IterationState {
// This iteration is still in process.
kRunning,
// This iteration is completed and the next iteration is not started yet.
kEnd
kCompleted
};
// In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of
@ -43,6 +43,12 @@ class Iteration {
return instance;
}
// Register callbacks for other servers to synchronize iteration information from leader server.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
// Register event callbacks for iteration state synchronization.
void RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node);
// Add a round for the iteration. This method will be called multiple times for each round.
void AddRound(const std::shared_ptr<Round> &round);
@ -54,28 +60,49 @@ class Iteration {
// If the timer expires, we consider this iteration as invalid.
void ProceedToNextIter(bool is_iteration_valid);
// Set current iteration state to running.
// Set current iteration state to running and trigger events about kIterationRunning.
void SetIterationRunning();
// Set current iteration state to completed and trigger the event about kIterationCompleted.
void SetIterationCompleted();
// The barrier function for elastic scaling. The scaling out/in operation should be done only after this iteration is
// completed.
void ScalingBarrier();
// Reinitialize rounds after scaling operations are done.
bool ReInitForScaling();
// The server number after scaling is required in some rounds.
bool ReInitForScaling(uint32_t server_num, uint32_t server_rank);
const std::vector<std::shared_ptr<Round>> &rounds();
bool is_last_iteration_valid() const;
private:
Iteration() : iteration_state_(IterationState::kEnd), iteration_num_(1), is_last_iteration_valid_(true) {
Iteration()
: server_node_(nullptr),
communicator_(nullptr),
iteration_state_(IterationState::kCompleted),
iteration_num_(1),
is_last_iteration_valid_(true) {
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
}
~Iteration() = default;
Iteration(const Iteration &) = delete;
Iteration &operator=(const Iteration &) = delete;
// The server does not need to handle the iteration events for now.
void HandleIterationRunningEvent() {}
void HandleIterationCompletedEvent() {}
// Synchronize the iteration number from the leader server (rank 0).
bool SyncIteration(uint32_t rank);
void HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message);
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
// All the rounds in the server.
std::vector<std::shared_ptr<Round>> rounds_;
// The iteration is either running or completed at any time.

View File

@ -48,7 +48,8 @@ bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
const schema::RequestPullWeight *pull_weight_req = flatbuffers::GetRoot<schema::RequestPullWeight>(req_data);
if (pull_weight_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestPullWeight";
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, {});
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num(),
{});
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
@ -68,7 +69,7 @@ void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::
if (pull_weight_iter != current_iter) {
std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) +
" is invalid. Server current iteration: " + std::to_string(current_iter);
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, feature_maps);
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps);
MS_LOG(WARNING) << reason;
return;
}
@ -81,7 +82,7 @@ void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::
if (!executor_->IsWeightAggrDone(weight_names)) {
retry_count_++;
std::string reason = "The aggregation for the weights is not done yet.";
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, feature_maps);
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps);
if (retry_count_ % 10 == 0) {
MS_LOG(WARNING) << reason << " Retry count is " << retry_count_;
}
@ -91,19 +92,20 @@ void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::
feature_maps = executor_->HandlePullWeight(weight_names);
if (feature_maps.empty()) {
std::string reason = "The feature_map is empty for the given weight names.";
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, feature_maps);
BuildPullWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter, feature_maps);
MS_LOG(WARNING) << reason;
return;
}
MS_LOG(INFO) << "Pulling weight for iteration " << current_iter << " succeeds.";
BuildPullWeightRsp(fbb, schema::ResponseCode_SUCCEED,
"Pull weights by weight names for iteration " + std::to_string(pull_weight_iter) + " success.",
feature_maps);
"Pulling weight by weight names for iteration " + std::to_string(pull_weight_iter) + " success.",
current_iter, feature_maps);
return;
}
void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason,
const std::string &reason, size_t iteration,
const std::map<std::string, AddressPtr> &feature_maps) {
auto fbs_reason = fbb->CreateString(reason);
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
@ -119,6 +121,7 @@ void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const
schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get()));
rsp_pull_weight_builder.add_retcode(retcode);
rsp_pull_weight_builder.add_reason(fbs_reason);
rsp_pull_weight_builder.add_iteration(iteration);
rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector);
auto rsp_pull_weight = rsp_pull_weight_builder.Finish();
fbb->Finish(rsp_pull_weight);

View File

@ -43,7 +43,7 @@ class PullWeightKernel : public RoundKernel {
private:
void PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req);
void BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
const std::map<std::string, AddressPtr> &feature_maps);
size_t iteration, const std::map<std::string, AddressPtr> &feature_maps);
Executor *executor_;

View File

@ -43,7 +43,7 @@ bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
const schema::RequestPushWeight *push_weight_req = flatbuffers::GetRoot<schema::RequestPushWeight>(req_data);
if (push_weight_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestPushWeight";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason);
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, LocalMetaStore::GetInstance().curr_iter_num());
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
@ -72,31 +72,33 @@ void PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::
return;
}
size_t iteration = static_cast<size_t>(push_weight_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
if (iteration != current_iter) {
std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num());
BuildPushWeightRsp(fbb, schema::ResponseCode_OutOfTime, reason);
", current iteration:" + std::to_string(current_iter);
BuildPushWeightRsp(fbb, schema::ResponseCode_OutOfTime, reason, current_iter);
MS_LOG(ERROR) << reason;
return;
}
std::map<std::string, Address> upload_feature_map = ParseFeatureMap(push_weight_req);
if (upload_feature_map.empty()) {
std::string reason = "PushWeight overwrite feature_map is empty.";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason);
std::string reason = "PushWeight feature_map is empty.";
BuildPushWeightRsp(fbb, schema::ResponseCode_RequestError, reason, current_iter);
MS_LOG(ERROR) << reason;
return;
}
if (!executor_->HandlePushWeight(upload_feature_map)) {
std::string reason = "OverwriteWeights failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason);
std::string reason = "Pushing weight failed.";
BuildPushWeightRsp(fbb, schema::ResponseCode_SystemError, reason, current_iter);
MS_LOG(ERROR) << reason;
return;
}
MS_LOG(INFO) << "Pushing weight for iteration " << current_iter << " succeeds.";
DistributedCountService::GetInstance().Count(name_, std::to_string(local_rank_));
BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.");
BuildPushWeightRsp(fbb, schema::ResponseCode_SUCCEED, "PushWeight succeed.", current_iter);
return;
}
@ -114,11 +116,12 @@ std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::R
}
void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason) {
const std::string &reason, size_t iteration) {
auto fbs_reason = fbb->CreateString(reason);
schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
rsp_push_weight_builder.add_retcode(retcode);
rsp_push_weight_builder.add_reason(fbs_reason);
rsp_push_weight_builder.add_iteration(iteration);
auto rsp_push_weight = rsp_push_weight_builder.Finish();
fbb->Finish(rsp_push_weight);
return;

View File

@ -44,8 +44,8 @@ class PushWeightKernel : public RoundKernel {
private:
void PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req);
void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason);
void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
size_t iteration);
Executor *executor_;
uint32_t local_rank_;

View File

@ -23,12 +23,14 @@ namespace mindspore {
namespace ps {
namespace server {
class Server;
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count)
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
bool server_num_as_threshold)
: name_(name),
check_timeout_(check_timeout),
time_window_(time_window),
check_count_(check_count),
threshold_count_(threshold_count) {}
threshold_count_(threshold_count),
server_num_as_threshold_(server_num_as_threshold) {}
void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb) {
@ -73,6 +75,27 @@ void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicat
}
}
bool Round::ReInitForScaling(uint32_t server_num) {
// If this round requires up-to-date server number as threshold count, update threshold_count_.
if (server_num_as_threshold_) {
MS_LOG(INFO) << "Round " << name_ << " uses up-to-date server number " << server_num << " as its threshold count.";
threshold_count_ = server_num;
}
if (check_count_) {
auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
{first_count_handler, last_count_handler});
}
if (kernel_ == nullptr) {
MS_LOG(ERROR) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
return false;
}
kernel_->InitKernel(threshold_count_);
return true;
}
void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) {
MS_EXCEPTION_IF_NULL(kernel);
kernel_ = kernel;

View File

@ -34,28 +34,14 @@ namespace server {
class Round {
public:
explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000,
bool check_count = false, size_t threshold_count = 8);
bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
~Round() = default;
void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb);
// Reinitialize count service and round kernel of this round after scaling operations are done.
bool ReInitForScaling() {
if (check_count_) {
auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
{first_count_handler, last_count_handler});
}
if (kernel_ == nullptr) {
MS_LOG(ERROR) << "Reinitializing for round " << name_ << " failed: round kernel is nullptr.";
return false;
}
kernel_->InitKernel(threshold_count_);
return true;
}
bool ReInitForScaling(uint32_t server_num);
// Bind a round kernel to this Round. This method should be called after Initialize.
void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel);
@ -94,6 +80,9 @@ class Round {
// the round message count has reached threshold_count_.
size_t threshold_count_;
// Whether this round uses the server number as its threshold count.
bool server_num_as_threshold_;
std::shared_ptr<core::CommunicatorBase> communicator_;
// The round kernel for this Round.

View File

@ -155,11 +155,13 @@ void Server::InitIteration() {
// 1.Add rounds to the iteration according to the server mode.
for (const RoundConfig &config : rounds_config_) {
std::shared_ptr<Round> round = std::make_shared<Round>(config.name, config.check_timeout, config.time_window,
config.check_count, config.threshold_count);
std::shared_ptr<Round> round =
std::make_shared<Round>(config.name, config.check_timeout, config.time_window, config.check_count,
config.threshold_count, config.server_num_as_threshold);
MS_LOG(INFO) << "Add round " << config.name << ", check_timeout: " << config.check_timeout
<< ", time window: " << config.time_window << ", check_count: " << config.check_count
<< ", threshold: " << config.threshold_count;
<< ", threshold: " << config.threshold_count
<< ", server_num_as_threshold: " << config.server_num_as_threshold;
iteration_->AddRound(round);
}
@ -180,6 +182,8 @@ void Server::RegisterCommCallbacks() {
// Set message callbacks for server-to-server communication.
DistributedMetadataStore::GetInstance().RegisterMessageCallback(tcp_comm);
DistributedCountService::GetInstance().RegisterMessageCallback(tcp_comm);
iteration_->RegisterMessageCallback(tcp_comm);
iteration_->RegisterEventCallback(server_node_);
// Set exception event callbacks for server.
RegisterExceptionEventCallback(tcp_comm);
@ -284,6 +288,10 @@ void Server::ProcessBeforeScalingIn() {
}
void Server::ProcessAfterScalingOut() {
if (server_node_ == nullptr) {
return;
}
if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
MS_LOG(ERROR) << "DistributedMetadataStore reinitializing failed.";
return;
@ -296,7 +304,7 @@ void Server::ProcessAfterScalingOut() {
MS_LOG(ERROR) << "DistributedCountService reinitializing failed.";
return;
}
if (!iteration_->ReInitForScaling()) {
if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
MS_LOG(ERROR) << "Iteration reinitializing failed.";
return;
}
@ -334,7 +342,7 @@ void Server::ProcessAfterScalingIn() {
MS_LOG(ERROR) << "DistributedCountService reinitializing failed.";
return;
}
if (!iteration_->ReInitForScaling()) {
if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
MS_LOG(ERROR) << "Iteration reinitializing failed.";
return;
}

View File

@ -35,7 +35,11 @@ void FLWorker::Run() {
MS_LOG(INFO) << "Initialize cluster config for worker. Worker number:" << worker_num_
<< ", Server number:" << server_num_ << ", Scheduler ip:" << scheduler_ip_
<< ", Scheduler port:" << scheduler_port_;
worker_node_ = std::make_shared<core::WorkerNode>();
MS_EXCEPTION_IF_NULL(worker_node_);
InitializeFollowerScaler();
worker_node_->Start();
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
return;
@ -43,6 +47,11 @@ void FLWorker::Run() {
bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core::TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output) {
// If the worker is in safemode, wait until safemode exits before communicating with servers.
while (safemode_.load()) {
std::this_thread::yield();
}
std::shared_ptr<unsigned char[]> message;
std::unique_ptr<unsigned char[]> message_addr = std::make_unique<unsigned char[]>(size);
MS_EXCEPTION_IF_NULL(message_addr);
@ -58,10 +67,12 @@ bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core:
}
if (output != nullptr) {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output)) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false;
}
do {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output)) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false;
}
} while (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode);
} else {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
@ -70,6 +81,100 @@ bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core:
}
return true;
}
uint32_t FLWorker::server_num() const { return server_num_; }
uint32_t FLWorker::worker_num() const { return worker_num_; }
uint64_t FLWorker::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
void FLWorker::InitializeFollowerScaler() {
if (!worker_node_->InitFollowerScaler()) {
MS_LOG(EXCEPTION) << "Initializing follower elastic scaler failed.";
return;
}
// Set scaling barriers before scaling.
worker_node_->RegisterFollowerScalerBarrierBeforeScaleOut("WorkerPipeline",
std::bind(&FLWorker::ProcessBeforeScalingOut, this));
worker_node_->RegisterFollowerScalerBarrierBeforeScaleIn("WorkerPipeline",
std::bind(&FLWorker::ProcessBeforeScalingIn, this));
// Set handlers after scheduler scaling operations are done.
worker_node_->RegisterFollowerScalerHandlerAfterScaleOut("WorkerPipeline",
std::bind(&FLWorker::ProcessAfterScalingOut, this));
worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
std::bind(&FLWorker::ProcessAfterScalingIn, this));
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning),
std::bind(&FLWorker::HandleIterationRunningEvent, this));
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted),
std::bind(&FLWorker::HandleIterationCompletedEvent, this));
}
void FLWorker::HandleIterationRunningEvent() {
MS_LOG(INFO) << "Worker iteration starts, safemode is " << safemode_.load();
iteration_state_ = IterationState::kRunning;
if (safemode_.load() == true) {
safemode_ = false;
}
}
void FLWorker::HandleIterationCompletedEvent() {
MS_LOG(INFO) << "Worker iteration completes";
iteration_state_ = IterationState::kCompleted;
}
void FLWorker::ProcessBeforeScalingOut() {
MS_LOG(INFO) << "Starting Worker scaling out barrier.";
while (iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
MS_LOG(INFO) << "Ending Worker scaling out barrier. Switch to safemode.";
safemode_ = true;
}
void FLWorker::ProcessBeforeScalingIn() {
MS_LOG(INFO) << "Starting Worker scaling in barrier.";
while (iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
MS_LOG(INFO) << "Ending Worker scaling in barrier. Switch to safemode.";
safemode_ = true;
}
void FLWorker::ProcessAfterScalingOut() {
if (worker_node_ == nullptr) {
return;
}
MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
while (iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
server_num_ = worker_node_->server_num();
worker_num_ = worker_node_->worker_num();
MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is "
<< server_num_ << ". Exit safemode.";
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
safemode_ = false;
}
void FLWorker::ProcessAfterScalingIn() {
if (worker_node_ == nullptr) {
return;
}
MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
while (iteration_state_.load() != IterationState::kCompleted) {
std::this_thread::yield();
}
server_num_ = worker_node_->server_num();
worker_num_ = worker_node_->worker_num();
MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_
<< ". Exit safemode.";
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
safemode_ = false;
}
} // namespace worker
} // namespace ps
} // namespace mindspore

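SendToServer now (1) refuses to start a request while the worker is in safemode and (2) resends as long as the reply is the safemode marker, so a request transparently rides out a scaling window. The control flow reduced to a skeleton, with send_fn standing in for worker_node_->Send:

#include <atomic>
#include <functional>
#include <string>
#include <thread>

constexpr char kClusterSafeModeMsg[] = "The cluster is in safemode.";

std::atomic_bool safemode{false};

// send_fn fills *reply on success and returns false on a transport failure.
bool SendToServer(const std::function<bool(std::string *)> &send_fn, std::string *reply) {
  // Never start a request while the worker itself is in safemode.
  while (safemode.load()) {
    std::this_thread::yield();
  }
  do {
    if (!send_fn(reply)) {
      return false;  // hard transport failure, give up
    }
    // A safemode reply means the server is mid-scaling: resend until it is not.
  } while (*reply == kClusterSafeModeMsg);
  return true;
}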
View File

@ -40,6 +40,13 @@ constexpr uint32_t kTrainEndStepNum = 0;
// The worker has to sleep for a while until the networking is completed.
constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;
enum class IterationState {
// This iteration is still in process.
kRunning,
// This iteration is completed and the next iteration is not started yet.
kCompleted
};
namespace worker {
// This class is used for hybrid training mode for now. In a later version, parameter server mode will also use this
// class as the worker.
@ -53,17 +60,54 @@ class FLWorker {
bool SendToServer(uint32_t server_rank, void *data, size_t size, core::TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output = nullptr);
uint32_t server_num() const;
uint32_t worker_num() const;
uint64_t worker_step_num_per_iteration() const;
private:
FLWorker() = default;
FLWorker()
: server_num_(0),
worker_num_(0),
scheduler_ip_(""),
scheduler_port_(0),
worker_node_(nullptr),
worker_step_num_per_iteration_(1),
iteration_state_(IterationState::kCompleted),
safemode_(false) {}
~FLWorker() = default;
FLWorker(const FLWorker &) = delete;
FLWorker &operator=(const FLWorker &) = delete;
// Initialize the scaler for worker
void InitializeFollowerScaler();
// The handlers for the iteration state events.
void HandleIterationRunningEvent();
void HandleIterationCompletedEvent();
// The barriers before scaling operations.
void ProcessBeforeScalingOut();
void ProcessBeforeScalingIn();
// The handlers after scheduler's scaling operations are done.
void ProcessAfterScalingOut();
void ProcessAfterScalingIn();
uint32_t server_num_;
uint32_t worker_num_;
std::string scheduler_ip_;
uint16_t scheduler_port_;
std::shared_ptr<core::WorkerNode> worker_node_;
// The number of standalone training steps the worker runs before communicating with the server. This is used in
// hybrid training mode for now.
uint64_t worker_step_num_per_iteration_;
// The iteration state is either running or completed.
std::atomic<IterationState> iteration_state_;
// The flag that represents whether worker is in safemode.
std::atomic_bool safemode_;
};
} // namespace worker
} // namespace ps

View File

@ -118,6 +118,7 @@ table RequestPushWeight{
table ResponsePushWeight{
retcode:int;
reason:string;
iteration:int;
}
table RequestGetModel{
@ -151,6 +152,7 @@ table RequestPullWeight{
table ResponsePullWeight{
retcode:int;
reason:string;
iteration:int;
feature_map:[FeatureMap];
}

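Both responses now carry the server's current iteration, so a client can resynchronize its local counter from any reply, including errors. A sketch against the generated flatbuffers API (schema_generated.h is a hypothetical header name):

#include <cstdint>

#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // hypothetical generated header for these tables

// Read the server's iteration number out of a raw ResponsePullWeight buffer.
int32_t ServerIteration(const uint8_t *buf) {
  auto rsp = flatbuffers::GetRoot<schema::ResponsePullWeight>(buf);
  // With the new field, even an error reply tells the client where the server is.
  return rsp->iteration();
}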
View File

@ -24,6 +24,7 @@ parser.add_argument("--server_num", type=int, default=2)
parser.add_argument("--scheduler_ip", type=str, default="127.0.0.1")
parser.add_argument("--scheduler_port", type=int, default=8113)
parser.add_argument("--fl_server_port", type=int, default=6666)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
if __name__ == "__main__":
args, _ = parser.parse_known_args()
@ -34,6 +35,7 @@ if __name__ == "__main__":
scheduler_ip = args.scheduler_ip
scheduler_port = args.scheduler_port
fl_server_port = args.fl_server_port
scheduler_manage_port = args.scheduler_manage_port
cmd_sched = "execute_path=$(pwd) && self_path=$(dirname \"${script_self}\") && rm -rf ${execute_path}/scheduler/ &&"
cmd_sched += "mkdir ${execute_path}/scheduler/ &&"
@ -47,6 +49,7 @@ if __name__ == "__main__":
cmd_sched += " --scheduler_ip=" + scheduler_ip
cmd_sched += " --scheduler_port=" + str(scheduler_port)
cmd_sched += " --fl_server_port=" + str(fl_server_port)
cmd_sched += " --scheduler_manage_port=" + str(scheduler_manage_port)
cmd_sched += " > scheduler.log 2>&1 &"
subprocess.call(['bash', '-c', cmd_sched])
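With the added argument, the launcher gains --scheduler_manage_port (default 11202) alongside the existing flags, and the value is forwarded verbatim into cmd_sched when the scheduler process is started. Judging by the flag name, this is the port on which the scheduler accepts management (scaling) requests; the script itself only forwards the value, so that interpretation is an assumption.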