Fix master static check

ZPaC 2021-07-15 16:24:21 +08:00
parent 633e1e49d6
commit a9a0f590e6
48 changed files with 207 additions and 179 deletions

View File

@ -20,8 +20,8 @@ namespace mindspore {
namespace kernel {
namespace ps {
void PServerKernel::Shard(std::vector<size_t> *shape, int axis) {
(*shape)[axis] =
LongToSize(Util::LocalShard(SizeToLong((*shape)[axis]), SizeToLong(rank_id_), SizeToLong(pserver_num_)));
(*shape)[IntToSize(axis)] =
LongToSize(Util::LocalShard(SizeToLong((*shape)[IntToSize(axis)]), SizeToLong(rank_id_), SizeToLong(pserver_num_)));
}
} // namespace ps
} // namespace kernel
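
Many hunks in this commit replace implicit signed/unsigned conversions with explicit helpers such as IntToSize, SizeToLong, and LongToSize, as in the Shard function above. A minimal sketch of what such helpers can look like (hypothetical definitions for illustration; the real MindSpore helpers may validate and report errors differently):

#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical conversion helpers: they make the sign/width change explicit
// and can validate the value before converting.
inline size_t IntToSize(int i) {
  if (i < 0) {
    throw std::out_of_range("negative int cannot be converted to size_t");
  }
  return static_cast<size_t>(i);
}

inline int64_t SizeToLong(size_t s) {
  if (s > static_cast<size_t>(std::numeric_limits<int64_t>::max())) {
    throw std::out_of_range("size_t value does not fit in int64_t");
  }
  return static_cast<int64_t>(s);
}

inline size_t LongToSize(int64_t l) {
  if (l < 0) {
    throw std::out_of_range("negative int64_t cannot be converted to size_t");
  }
  return static_cast<size_t>(l);
}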

View File

@ -350,7 +350,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
}
if (AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
const std::string &param_name = input_node->fullname_with_scope();
if (ps::ps_cache_instance.IsHashTable(param_name)) {
continue;
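
The most frequent change in this commit rewrites the preprocessor guard from #if (ENABLE_CPU && !_WIN32) to #if ((defined ENABLE_CPU) && (!defined _WIN32)). The old form evaluates the macros' values: an undefined identifier silently becomes 0, and a macro defined with no value makes the expression ill-formed. The defined operator only tests whether the macro exists, which is what the guard actually means. A minimal illustration:

// With -DENABLE_CPU= (defined but empty) the old guard would expand to
// "#if ( && !_WIN32)" and fail to preprocess; "defined" tests existence
// instead of evaluating a value, so the new guard is always well-formed.
#if ((defined ENABLE_CPU) && (!defined _WIN32))
// code for CPU builds on non-Windows platforms
#endif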

View File

@ -34,7 +34,7 @@
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"
#include "debug/data_dump/dump_json_parser.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/util.h"
#include "ps/ps_context.h"
#endif
@ -77,7 +77,7 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos
void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) {
@ -193,7 +193,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap
MS_LOG(INFO) << "Bind input output address";
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
InitPSParamAndOptim(kernel_graph, inputs);
#endif
}

View File

@ -22,7 +22,7 @@
#include "utils/comm_manager.h"
#include "utils/scoped_long_running.h"
#include "pybind_api/ir/tensor_py.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#endif

View File

@ -44,7 +44,7 @@
#include "debug/common.h"
#include "utils/trace_base.h"
#include "frontend/parallel/context.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/constants.h"
#include "ps/util.h"
@ -2483,7 +2483,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
void SessionBasic::UnifyMindIR(const KernelGraphPtr &graph) { opt::CommonUnifyMindIROptimization(graph); }
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
if (!ps::PSContext::instance()->is_worker()) {
return;

View File

@ -157,10 +157,10 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
}
// reconstruct secrets
bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req,
std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
const std::vector<std::string> &client_list) {
bool CipherReconStruct::ReconstructSecrets(
const int cur_iterator, const std::string &next_req_time, const schema::SendReconstructSecret *reconstruct_secret_req,
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
const std::vector<std::string> &client_list) {
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
clock_t start_time = clock();
if (reconstruct_secret_req == nullptr || reconstruct_secret_resp_builder == nullptr) {
@ -285,7 +285,7 @@ void CipherReconStruct::ClearReconstructSecrets() {
MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success";
}
void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb,
void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb,
const schema::ResponseCode retcode, const std::string &reason,
const int iteration, const std::string &next_req_time) {
auto fbs_reason = fbb->CreateString(reason);
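
Another recurring change replaces by-value std::shared_ptr parameters with const references, as in the two signatures above. A by-value shared_ptr copies the smart pointer and bumps its atomic reference count on every call; a const reference avoids that cost when the callee does not need to share ownership. A small self-contained sketch:

#include <cstdio>
#include <memory>

void ByValue(std::shared_ptr<int> p) { std::printf("use_count=%ld\n", p.use_count()); }
void ByConstRef(const std::shared_ptr<int> &p) { std::printf("use_count=%ld\n", p.use_count()); }

int main() {
  auto p = std::make_shared<int>(42);
  ByValue(p);     // prints use_count=2: the copy bumped the atomic refcount
  ByConstRef(p);  // prints use_count=1: no copy, no refcount traffic
  return 0;
}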

View File

@ -44,11 +44,11 @@ class CipherReconStruct {
// reconstruct secret mask
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req,
std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
const std::vector<std::string> &client_list);
// build response code of reconstruct secret.
void BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
void BuildReconstructSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const int iteration, const std::string &next_req_time);
// clear the shared memory.

View File

@ -21,7 +21,7 @@
namespace mindspore {
namespace armour {
bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder,
const string next_req_time) {
MS_LOG(INFO) << "CipherShares::ShareSecrets START";
if (share_secrets_req == nullptr) {
@ -95,7 +95,7 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
}
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder,
const std::string &next_req_time) {
MS_LOG(INFO) << "CipherShares::GetSecrets START";
clock_t start_time = clock();
@ -180,7 +180,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
}
void CipherShares::BuildGetSecretsRsp(
std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) {
int rsp_retcode = retcode;
int rsp_iteration = iteration;
@ -199,7 +199,7 @@ void CipherShares::BuildGetSecretsRsp(
return;
}
void CipherShares::BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
void CipherShares::BuildShareSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder,
const schema::ResponseCode retcode, const string &reason,
const string &next_req_time, const int iteration) {
auto rsp_reason = share_secrets_resp_builder->CreateString(reason);

View File

@ -43,17 +43,19 @@ class CipherShares {
// handle the client's request of share secrets.
bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, const string next_req_time);
const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder,
const string next_req_time);
// handle the client's request of get secrets.
bool GetSecrets(const schema::GetShareSecrets *get_secrets_req,
std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time);
const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder,
const std::string &next_req_time);
// build response code of share secrets.
void BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
void BuildShareSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder,
const schema::ResponseCode retcode, const string &reason, const string &next_req_time,
const int iteration);
// build response code of get secrets.
void BuildGetSecretsRsp(std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
void BuildGetSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder,
const schema::ResponseCode retcode, const int iteration, std::string next_req_time,
std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares);
// clear the shared memory.

View File

@ -26,7 +26,7 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
clock_t start_time = clock();
std::vector<float> noise;
cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
(void)cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
if (noise.size() != cipher_init_->featuremap_) {
MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
return false;
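
The (void) casts added here and throughout the commit explicitly discard a return value, telling the static checker that ignoring the result is deliberate rather than an oversight. The pattern in miniature, with a hypothetical function standing in for calls like ParseFromArray or SendResponse:

#include <cstdio>

// Hypothetical function whose ignored result a checker would warn about.
[[nodiscard]] bool TryParse() { return true; }

int main() {
  (void)TryParse();  // explicit discard: the ignored result is intentional
  std::puts("done");
  return 0;
}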

View File

@ -114,7 +114,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;

View File

@ -24,6 +24,7 @@
namespace mindspore {
namespace fl {
namespace server {
constexpr uint32_t kDefaultVirtualNodeNum = 32;
// To support distributed storage and make it easy for servers to scale out and scale in under a large
// metadata load, we use class ConsistentHashRing to help servers find out which metadata is stored in
// which server node.
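
The named constant kDefaultVirtualNodeNum replaces the magic number 32 used when the ring is constructed (see the DistributedMetadataStore hunk below). For intuition, here is a minimal consistent-hash-ring sketch under assumed semantics, not the MindSpore implementation, showing how virtual nodes spread metadata names across server ranks:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <string>

class MiniHashRing {
 public:
  explicit MiniHashRing(uint32_t virtual_node_num) : virtual_node_num_(virtual_node_num) {}

  // Place virtual_node_num_ virtual nodes for one server on the ring.
  void Insert(uint32_t server_rank) {
    for (uint32_t i = 0; i < virtual_node_num_; ++i) {
      size_t pos = std::hash<std::string>()(std::to_string(server_rank) + "#" + std::to_string(i));
      ring_[pos] = server_rank;
    }
  }

  // Find the server that stores the metadata identified by name.
  uint32_t Find(const std::string &name) const {
    size_t pos = std::hash<std::string>()(name);
    auto it = ring_.lower_bound(pos);
    if (it == ring_.end()) {
      it = ring_.begin();  // wrap around the ring
    }
    return it->second;
  }

 private:
  uint32_t virtual_node_num_;
  std::map<size_t, uint32_t> ring_;  // ring position -> server rank
};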

View File

@ -104,7 +104,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
}
CountResponse count_rsp;
count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
(void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
if (!count_rsp.result()) {
MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) {
@ -138,8 +138,8 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
}
CountReachThresholdResponse count_reach_threshold_rsp;
count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
SizeToInt(query_cnt_enough_rsp_msg->size()));
(void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
SizeToInt(query_cnt_enough_rsp_msg->size()));
return count_reach_threshold_rsp.is_enough();
}
}
@ -178,7 +178,7 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
}
CountRequest report_count_req;
report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
(void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const std::string &name = report_count_req.name();
const std::string &id = report_count_req.id();
@ -228,7 +228,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(
}
CountReachThresholdRequest count_reach_threshold_req;
count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
(void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const std::string &name = count_reach_threshold_req.name();
std::unique_lock<std::mutex> lock(mutex_[name]);
@ -256,7 +256,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message);
CounterEvent counter_event;
counter_event.ParseFromArray(message->data(), SizeToInt(message->len()));
(void)counter_event.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &type = counter_event.type();
const auto &name = counter_event.name();

View File

@ -141,7 +141,7 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
return get_metadata_rsp;
}
get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
(void)get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
return get_metadata_rsp;
}
}
@ -165,7 +165,7 @@ bool DistributedMetadataStore::ReInitForScaling() {
}
void DistributedMetadataStore::InitHashRing() {
router_ = std::make_shared<ConsistentHashRing>(32);
router_ = std::make_shared<ConsistentHashRing>(kDefaultVirtualNodeNum);
MS_EXCEPTION_IF_NULL(router_);
for (uint32_t i = 0; i < server_num_; i++) {
bool ret = router_->Insert(i);
@ -184,7 +184,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
}
PBMetadataWithName meta_with_name;
meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
(void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
const std::string &name = meta_with_name.name();
MS_LOG(INFO) << "Update metadata for " << name;
@ -195,7 +195,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
} else {
update_meta_rsp_msg = "Success";
}
communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message);
(void)communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message);
return;
}
@ -206,14 +206,14 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps
}
GetMetadataRequest get_metadata_req;
get_metadata_req.ParseFromArray(message->data(), message->len());
(void)get_metadata_req.ParseFromArray(message->data(), message->len());
const std::string &name = get_metadata_req.name();
MS_LOG(INFO) << "Getting metadata for " << name;
std::unique_lock<std::mutex> lock(mutex_[name]);
PBMetadata stored_meta = metadata_[name];
std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message);
(void)communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message);
return;
}

View File

@ -67,7 +67,7 @@ bool Executor::HandlePush(const std::string &param_name, const UploadData &uploa
// Push operation needs to wait until the pulling process is done.
while (!param_aggr->IsPullingDone()) {
lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(5));
std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
lock.lock();
}
@ -192,7 +192,7 @@ AddressPtr Executor::HandlePull(const std::string &param_name) {
// Pulling must wait until the optimizing process is done.
while (!param_aggr->IsOptimizingDone()) {
lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(5));
std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
lock.lock();
}
AddressPtr addr = param_aggr->Pull();
@ -314,7 +314,10 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
param_names_.push_back(param_name);
param_aggrs_[param_name] = param_aggr;
parameter_mutex_[param_name];
param_aggr->Init(cnode, aggregation_count_);
if (!param_aggr->Init(cnode, aggregation_count_)) {
MS_LOG(EXCEPTION) << "Initializing parameter aggregator failed for " << param_name;
return false;
}
MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
}
return true;

View File

@ -33,6 +33,8 @@
namespace mindspore {
namespace fl {
namespace server {
constexpr int kThreadSleepTime = 5;
// Executor is the entrance for the server to handle aggregation, optimizing, model querying, etc. It handles
// logic relevant to kernel launching.
class Executor {

View File

@ -74,10 +74,10 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
});
// The time window for one iteration, which will be used in some round kernels.
size_t iteration_time_window =
std::accumulate(rounds_.begin(), rounds_.end(), 0, [](size_t total, const std::shared_ptr<Round> &round) {
return round->check_timeout() ? total + round->time_window() : total;
});
size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0),
[](size_t total, const std::shared_ptr<Round> &round) {
return round->check_timeout() ? total + round->time_window() : total;
});
LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;
return;
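
The IntToSize(0) initial value in the rewritten accumulate call is not cosmetic: std::accumulate deduces its accumulator type from the initial-value argument, so a plain 0 makes the accumulator an int and the size_t time windows would be narrowed at every step. A small demonstration:

#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<size_t> windows = {3000, 3000};
  auto as_int = std::accumulate(windows.begin(), windows.end(), 0);           // accumulator deduced as int
  auto as_size = std::accumulate(windows.begin(), windows.end(), size_t{0});  // accumulator stays size_t
  std::printf("%d %zu\n", as_int, as_size);
  return 0;
}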
@ -162,7 +162,7 @@ bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) {
return true;
}
const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }
const std::vector<std::shared_ptr<Round>> &Iteration::rounds() const { return rounds_; }
bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }
@ -182,7 +182,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
}
SyncIterationResponse sync_iter_rsp;
sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size());
(void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
iteration_num_ = sync_iter_rsp.iteration();
MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is "
<< sync_iter_rsp.iteration();
@ -196,14 +196,14 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::Messa
}
SyncIterationRequest sync_iter_req;
sync_iter_req.ParseFromArray(message->data(), message->len());
(void)sync_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
uint32_t rank = sync_iter_req.rank();
MS_LOG(INFO) << "Synchronizing iteration request from rank " << rank;
SyncIterationResponse sync_iter_rsp;
sync_iter_rsp.set_iteration(iteration_num_);
std::string sync_iter_rsp_msg = sync_iter_rsp.SerializeAsString();
communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message);
(void)communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message);
}
bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
@ -238,11 +238,11 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps
NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp;
notify_leader_to_next_iter_rsp.set_result("success");
communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(),
notify_leader_to_next_iter_rsp.SerializeAsString().size(), message);
(void)communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(),
notify_leader_to_next_iter_rsp.SerializeAsString().size(), message);
NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
(void)notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &rank = notify_leader_to_next_iter_req.rank();
const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid();
const auto &iter_num = notify_leader_to_next_iter_req.iter_num();
@ -296,7 +296,7 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
}
MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success.";
});
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
return true;
}
@ -306,15 +306,15 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::
}
PrepareForNextIterRequest prepare_next_iter_req;
prepare_next_iter_req.ParseFromArray(message->data(), message->len());
(void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &reason = prepare_next_iter_req.reason();
MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason;
PrepareForNextIter();
PrepareForNextIterResponse prepare_next_iter_rsp;
prepare_next_iter_rsp.set_result("success");
communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(),
prepare_next_iter_rsp.SerializeAsString().size(), message);
(void)communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(),
prepare_next_iter_rsp.SerializeAsString().size(), message);
}
void Iteration::PrepareForNextIter() {
@ -347,11 +347,11 @@ void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::Mess
MoveToNextIterResponse proceed_to_next_iter_rsp;
proceed_to_next_iter_rsp.set_result("success");
communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
proceed_to_next_iter_rsp.SerializeAsString().size(), message);
(void)communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
proceed_to_next_iter_rsp.SerializeAsString().size(), message);
MoveToNextIterRequest proceed_to_next_iter_req;
proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
(void)proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid();
const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num();
const auto &reason = proceed_to_next_iter_req.reason();
@ -370,12 +370,12 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
if (is_iteration_valid) {
// Store the model which is successfully aggregated for this iteration.
const auto &model = Executor::GetInstance().GetModel();
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
(void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
} else {
// Store last iteration's model because this iteration is considered as invalid.
const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
(void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
}
@ -405,7 +405,7 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message
}
EndLastIterRequest end_last_iter_req;
end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
(void)end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
const auto &last_iter_num = end_last_iter_req.last_iter_num();
// If the iteration number is not matched, return error.
if (last_iter_num != iteration_num_) {
@ -413,8 +413,8 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message
std::to_string(iteration_num_) + ", iteration to be ended is " + std::to_string(last_iter_num);
EndLastIterResponse end_last_iter_rsp;
end_last_iter_rsp.set_result(reason);
communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
end_last_iter_rsp.SerializeAsString().size(), message);
(void)communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
end_last_iter_rsp.SerializeAsString().size(), message);
return;
}
@ -422,8 +422,8 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message
EndLastIterResponse end_last_iter_rsp;
end_last_iter_rsp.set_result("success");
communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
end_last_iter_rsp.SerializeAsString().size(), message);
(void)communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
end_last_iter_rsp.SerializeAsString().size(), message);
}
void Iteration::EndLastIter() {

View File

@ -79,7 +79,7 @@ class Iteration {
// The server number after scaling is required in some rounds.
bool ReInitForScaling(uint32_t server_num, uint32_t server_rank);
const std::vector<std::shared_ptr<Round>> &rounds();
const std::vector<std::shared_ptr<Round>> &rounds() const;
bool is_last_iteration_valid() const;
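
Qualifying rounds() as const (declared here, defined in the iteration.cc hunk above) is a routine const-correctness fix: an accessor that does not mutate state should be callable on const objects, and static checkers flag getters that are not. In miniature:

// A non-mutating accessor should be const-qualified so it can be called
// through const references and pointers.
class Counter {
 public:
  int value() const { return value_; }
 private:
  int value_ = 0;
};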

View File

@ -70,10 +70,9 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
std::map<std::string, AddressPtr> feature_maps;
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
size_t get_model_iter = IntToSize(get_model_req->iteration());
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
size_t latest_iter_num = iter_to_model.rbegin()->first;
// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +

View File

@ -63,10 +63,11 @@ bool PullWeightKernel::Reset() {
return true;
}
void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req) {
void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestPullWeight *pull_weight_req) {
std::map<std::string, AddressPtr> feature_maps = {};
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t pull_weight_iter = static_cast<size_t>(pull_weight_req->iteration());
size_t pull_weight_iter = IntToSize(pull_weight_req->iteration());
// The iteration from the worker should be the same as the server's; otherwise return SucNotReady so that the worker can retry.
if (pull_weight_iter != current_iter) {
std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) +
@ -110,7 +111,7 @@ void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::
return;
}
void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
void PullWeightKernel::BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, size_t iteration,
const std::map<std::string, AddressPtr> &feature_maps) {
auto fbs_reason = fbb->CreateString(reason);
@ -127,7 +128,7 @@ void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const
schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get()));
rsp_pull_weight_builder.add_retcode(retcode);
rsp_pull_weight_builder.add_reason(fbs_reason);
rsp_pull_weight_builder.add_iteration(iteration);
rsp_pull_weight_builder.add_iteration(SizeToInt(iteration));
rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector);
auto rsp_pull_weight = rsp_pull_weight_builder.Finish();
fbb->Finish(rsp_pull_weight);

View File

@ -42,9 +42,10 @@ class PullWeightKernel : public RoundKernel {
bool Reset() override;
private:
void PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req);
void BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
size_t iteration, const std::map<std::string, AddressPtr> &feature_maps);
void PullWeight(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestPullWeight *pull_weight_req);
void BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, size_t iteration,
const std::map<std::string, AddressPtr> &feature_maps);
Executor *executor_;

View File

@ -67,12 +67,12 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageH
return;
}
ResultCode PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb,
ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestPushWeight *push_weight_req) {
if (fbb == nullptr || push_weight_req == nullptr) {
return ResultCode::kSuccessAndReturn;
}
size_t iteration = static_cast<size_t>(push_weight_req->iteration());
size_t iteration = IntToSize(push_weight_req->iteration());
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
if (iteration != current_iter) {
std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) +
@ -123,13 +123,13 @@ std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::R
return upload_feature_map;
}
void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, size_t iteration) {
auto fbs_reason = fbb->CreateString(reason);
schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
rsp_push_weight_builder.add_retcode(retcode);
rsp_push_weight_builder.add_reason(fbs_reason);
rsp_push_weight_builder.add_iteration(iteration);
rsp_push_weight_builder.add_iteration(SizeToInt(iteration));
auto rsp_push_weight = rsp_push_weight_builder.Finish();
fbb->Finish(rsp_push_weight);
return;

View File

@ -42,10 +42,10 @@ class PushWeightKernel : public RoundKernel {
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private:
ResultCode PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
ResultCode PushWeight(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestPushWeight *push_weight_req);
std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req);
void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
size_t iteration);
void BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, size_t iteration);
Executor *executor_;
uint32_t local_rank_;

View File

@ -34,7 +34,7 @@ RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), e
// Every 100 milliseconds, detect whether there's any data that needs to be released.
if (heap_data_to_release_.empty()) {
release_lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::this_thread::sleep_for(std::chrono::milliseconds(kReleaseDuration));
continue;
}
@ -61,9 +61,9 @@ RoundKernel::~RoundKernel() {
}
}
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { return; }
void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { return; }
void RoundKernel::StopTimer() const {
if (stop_timer_cb_) {

View File

@ -38,6 +38,7 @@ namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
constexpr uint64_t kReleaseDuration = 100;
// RoundKernel contains the main logic for the server to handle messages from workers. One iteration consists of
// multiple round kernels that represent the process. They receive and parse messages from the server communication
// module. After handling these messages, round kernels allocate response data and send it back.
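
As a rough, self-contained sketch of the flow this comment describes (placeholder types for illustration, not the real declarations), a round kernel parses a request, runs the round logic, and fills a response; the unnamed parameters in the .cc hunk above mark the count-event arguments as intentionally unused:

#include <string>

struct Request { std::string payload; };   // placeholder for the worker message
struct Response { std::string payload; };  // placeholder for the reply buffer

class SketchRoundKernel {
 public:
  virtual ~SketchRoundKernel() = default;
  // Parse the worker's request, run the round logic, and fill the response.
  virtual bool Launch(const Request &req, Response *rsp) = 0;
  // Hooks fired by the distributed counting service.
  virtual void OnFirstCountEvent() {}
  virtual void OnLastCountEvent() {}
};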

View File

@ -118,7 +118,7 @@ bool StartFLJobKernel::Reset() {
}
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
iter_next_req_timestamp_ = LongToSize(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
// The first startFLJob request means a new iteration starts running.
Iteration::GetInstance().SetIterationRunning();
@ -220,9 +220,9 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
fl_plan_builder.add_fl_name(fbs_fl_name);
fl_plan_builder.add_server_mode(fbs_server_mode);
fl_plan_builder.add_iterations(ps::PSContext::instance()->fl_iteration_num());
fl_plan_builder.add_epochs(ps::PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(ps::PSContext::instance()->client_batch_size());
fl_plan_builder.add_iterations(SizeToInt(ps::PSContext::instance()->fl_iteration_num()));
fl_plan_builder.add_epochs(SizeToInt(ps::PSContext::instance()->client_epoch_num()));
fl_plan_builder.add_mini_batch(SizeToInt(ps::PSContext::instance()->client_batch_size()));
fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate());
#ifdef ENABLE_ARMOUR
fl_plan_builder.add_cipher(cipher_public_params);

View File

@ -90,7 +90,7 @@ bool UpdateModelKernel::Reset() {
return true;
}
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) {
while (!executor_->IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
@ -120,7 +120,7 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr
ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb) {
RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn);
size_t iteration = static_cast<size_t>(update_model_req->iteration());
size_t iteration = IntToSize(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) +

View File

@ -281,16 +281,16 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<
KernelParams aggr_params = {};
const std::vector<std::string> &input_names = aggr_kernel->input_names();
std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
(void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names();
std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
[&](const std::string &name) { return memory_register->addresses()[name]; });
(void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
[&](const std::string &name) { return memory_register->addresses()[name]; });
const std::vector<std::string> &output_names = aggr_kernel->output_names();
std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
(void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs);
aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params));
@ -304,16 +304,16 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
KernelParams optimizer_params = {};
const std::vector<std::string> &input_names = optimizer_kernel->input_names();
std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
(void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names();
std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
[&](const std::string &name) { return memory_register->addresses()[name]; });
(void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
[&](const std::string &name) { return memory_register->addresses()[name]; });
const std::vector<std::string> &output_names = optimizer_kernel->output_names();
std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
(void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params));
return true;

View File

@ -34,8 +34,8 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo
threshold_count_(threshold_count),
server_num_as_threshold_(server_num_as_threshold) {}
void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb) {
void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
const FinishIterCb &finish_iteration_cb) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
@ -50,7 +50,7 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
};
// Callback for finalizing the server. This can only be called once.
finalize_cb_ = [&](void) -> void { communicator_->Stop(); };
finalize_cb_ = [&](void) -> void { (void)communicator_->Stop(); };
if (check_timeout_) {
iter_timer_ = std::make_shared<IterationTimer>();
@ -116,7 +116,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
if (Server::GetInstance().IsSafeMode()) {
MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
std::string reason = "The cluster is in safemode.";
communicator_->SendResponse(reason.c_str(), reason.size(), message);
(void)communicator_->SendResponse(reason.c_str(), reason.size(), message);
return;
}
@ -128,10 +128,10 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
if (output->size == 0) {
std::string reason = "The output of the round " + name_ + " is empty.";
MS_LOG(WARNING) << reason;
communicator_->SendResponse(reason.c_str(), reason.size(), message);
(void)communicator_->SendResponse(reason.c_str(), reason.size(), message);
return;
}
communicator_->SendResponse(output->addr, output->size, message);
(void)communicator_->SendResponse(output->addr, output->size, message);
kernel_->Release(output);
// Must send response back no matter what value Launch method returns.
@ -142,7 +142,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
return;
}
void Round::Reset() { kernel_->Reset(); }
void Round::Reset() { (void)kernel_->Reset(); }
const std::string &Round::name() const { return name_; }

View File

@ -37,8 +37,8 @@ class Round {
bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
~Round() = default;
void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb);
void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
const FinishIterCb &finish_iteration_cb);
// Reinitialize count service and round kernel of this round after scaling operations are done.
bool ReInitForScaling(uint32_t server_num);

View File

@ -102,7 +102,7 @@ void Server::CancelSafeMode() {
safemode_ = false;
}
bool Server::IsSafeMode() { return safemode_.load(); }
bool Server::IsSafeMode() const { return safemode_.load(); }
void Server::InitServerContext() {
ps::PSContext::instance()->GenerateResetterRound();
@ -121,7 +121,7 @@ void Server::InitServerContext() {
void Server::InitCluster() {
server_node_ = std::make_shared<ps::core::ServerNode>();
MS_EXCEPTION_IF_NULL(server_node_);
task_executor_ = std::make_shared<ps::core::TaskExecutor>(32);
task_executor_ = std::make_shared<ps::core::TaskExecutor>(kExecutorThreadPoolSize);
MS_EXCEPTION_IF_NULL(task_executor_);
if (!InitCommunicatorWithServer()) {
MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
@ -235,9 +235,9 @@ void Server::InitCipher() {
#ifdef ENABLE_ARMOUR
cipher_init_ = &armour::CipherInit::GetInstance();
int cipher_t = cipher_reconstruct_secrets_down_cnt_;
int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
int cipher_g = 1;
const int cipher_g = 1;
unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
float dp_eps = ps::PSContext::instance()->dp_eps();
float dp_delta = ps::PSContext::instance()->dp_delta();
@ -304,8 +304,8 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
safemode_ = true;
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
(void)communicator_with_server_->Stop();
});
communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
@ -314,8 +314,8 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC
"network building phase.";
safemode_ = true;
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
(void)communicator_with_server_->Stop();
});
}
@ -363,7 +363,10 @@ void Server::StartCommunicator() {
}
MS_LOG(INFO) << "Start communicator with server.";
communicator_with_server_->Start();
if (!communicator_with_server_->Start()) {
MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
return;
}
DistributedMetadataStore::GetInstance().Initialize(server_node_);
CollectiveOpsImpl::GetInstance().Initialize(server_node_);
DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
@ -371,7 +374,11 @@ void Server::StartCommunicator() {
MS_LOG(INFO) << "Start communicator with worker.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Start(); });
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
if (!communicator->Start()) {
MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
}
});
}
void Server::ProcessBeforeScalingOut() {
@ -405,7 +412,7 @@ void Server::ProcessAfterScalingOut() {
if (!Executor::GetInstance().ReInitForScaling()) {
MS_LOG(WARNING) << "Executor reinitializing failed.";
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
safemode_ = false;
}
@ -418,7 +425,7 @@ void Server::ProcessAfterScalingIn() {
if (server_node_->rank_id() == UINT32_MAX) {
MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
communicator_with_server_->Stop();
return;
}
@ -439,7 +446,7 @@ void Server::ProcessAfterScalingIn() {
if (!Executor::GetInstance().ReInitForScaling()) {
MS_LOG(WARNING) << "Executor reinitializing failed.";
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
safemode_ = false;
}
} // namespace server

View File

@ -33,6 +33,9 @@
namespace mindspore {
namespace fl {
namespace server {
// The sleeping time of the server thread before the networking is completed.
constexpr uint32_t kServerSleepTimeForNetworking = 1000;
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
class Server {
public:
@ -51,7 +54,7 @@ class Server {
void SwitchToSafeMode();
void CancelSafeMode();
bool IsSafeMode();
bool IsSafeMode() const;
private:
Server()
@ -162,8 +165,6 @@ class Server {
uint32_t server_num_;
uint32_t worker_num_;
uint16_t fl_server_port_;
size_t start_fl_job_cnt_;
size_t update_model_cnt_;
size_t cipher_initial_client_cnt_;
size_t cipher_exchange_secrets_cnt_;
size_t cipher_share_secrets_cnt_;
@ -171,9 +172,6 @@ class Server {
size_t cipher_reconstruct_secrets_up_cnt_;
size_t cipher_reconstruct_secrets_down_cnt_;
uint64_t cipher_time_window_;
float percent_for_update_model_;
float percent_for_get_model_;
};
} // namespace server
} // namespace fl

View File

@ -70,8 +70,14 @@ void FLWorker::Run() {
void FLWorker::Finalize() {
MS_EXCEPTION_IF_NULL(worker_node_);
worker_node_->Finish();
worker_node_->Stop();
if (!worker_node_->Finish()) {
MS_LOG(ERROR) << "Worker node finishing failed.";
return;
}
if (!worker_node_->Stop()) {
MS_LOG(ERROR) << "Worker node stopping failed.";
return;
}
}
bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
@ -201,8 +207,8 @@ void FLWorker::ProcessAfterScalingOut() {
}
MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
server_num_ = worker_node_->server_num();
worker_num_ = worker_node_->worker_num();
server_num_ = IntToUint(worker_node_->server_num());
worker_num_ = IntToUint(worker_node_->worker_num());
MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is "
<< server_num_ << ". Exit safemode.";
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
@ -215,8 +221,8 @@ void FLWorker::ProcessAfterScalingIn() {
}
MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
server_num_ = worker_node_->server_num();
worker_num_ = worker_node_->worker_num();
server_num_ = IntToUint(worker_node_->server_num());
worker_num_ = IntToUint(worker_node_->worker_num());
MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_
<< ". Exit safemode.";
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));

View File

@ -25,7 +25,7 @@
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/context.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#include "utils/ms_context.h"
#endif
@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() {
if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
dynamic_shape_indices_ = true;
}
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) {
@ -637,7 +637,7 @@ Status GatherPInfo::InferBias() {
rank = rank % (params_strategy[0] * params_strategy[1]);
}
}
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
return SUCCESS;

View File

@ -28,7 +28,7 @@
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
@ -119,7 +119,7 @@ std::vector<StrategyPtr> UniqueInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector;
}
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph(attrs_);
if (gen_g.Init(cnode) != SUCCESS) {
@ -156,7 +156,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
#endif
ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
auto inputs = cnode->inputs();
if (inputs.empty()) {

View File

@ -47,7 +47,7 @@
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/util.h"
#include "ps/ps_context.h"
#endif

View File

@ -44,7 +44,7 @@
#include "vm/transform.h"
#include "parse/python_adapter.h"
#include "frontend/optimizer/py_pass_manager.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/parameter_server.h"
#include "ps/scheduler.h"
#include "ps/worker.h"
@ -478,7 +478,7 @@ bool OptInlineAction(const ResourcePtr &res) {
bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
bool VmOptimizeAction(const ResourcePtr &res) {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PSContext::instance()->is_ps_mode()) {
kVmPasses.push_back({"server_communication_op_fusion", ps::Util::FuseServerCommOps});
}
@ -633,8 +633,8 @@ bool ExecuteAction(const ResourcePtr &res) {
return true;
}
#if (ENABLE_CPU && !_WIN32)
bool StartPSWorkerAction(const ResourcePtr &res) {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
bool StartPSWorkerAction(const ResourcePtr &) {
ps::Worker::GetInstance().Run();
return true;
}
@ -695,7 +695,7 @@ bool StartServerAction(const ResourcePtr &res) {
return true;
}
bool StartPSSchedulerAction(const ResourcePtr &res) {
bool StartPSSchedulerAction(const ResourcePtr &) {
ps::Scheduler::GetInstance().Run();
return true;
}
@ -861,7 +861,7 @@ std::vector<ActionItem> VmPipeline() {
actions.emplace_back(std::make_pair("remove_monad_from_random_op", RemoveRandomOpMonadAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PSContext::instance()->is_worker()) {
std::string server_mode = ps::PSContext::instance()->server_mode();
if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {
@ -889,7 +889,7 @@ std::vector<ActionItem> BackendPipeline() {
return actions;
}
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
std::vector<ActionItem> ServerPipeline() {
auto actions = CommonPipeline();
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));

View File

@ -34,7 +34,7 @@
#else
#include "runtime/device/gpu/distribution/collective_fake_init.h"
#endif
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/util.h"
#endif
#include "ps/ps_context.h"

View File

@ -47,7 +47,7 @@
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/util.h"
#include "ps/ps_context.h"
#endif
@ -211,7 +211,7 @@ namespace {
bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); }
bool parallel_mode() {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
return false;
}
@ -556,7 +556,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
}
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PSContext::instance()->is_ps_mode()) {
return true;
}

View File

@ -148,7 +148,7 @@ std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::str
}
std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateTcpComm(const std::string &scheduler_ip,
std::int16_t scheduler_port, uint32_t worker_num,
uint16_t scheduler_port, uint32_t worker_num,
uint32_t server_num,
const std::shared_ptr<TaskExecutor> &task_executor) {
std::lock_guard<std::mutex> lock(communicator_mutex_);

View File

@ -61,7 +61,7 @@ class ServerNode : public AbstractNode {
std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, uint16_t port,
const std::shared_ptr<TaskExecutor> &task_executor);
std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port,
std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, uint16_t scheduler_port,
uint32_t worker_num, uint32_t server_num,
const std::shared_ptr<TaskExecutor> &task_executor);
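
Switching scheduler_port from std::int16_t to uint16_t fixes an actual range bug rather than a style issue: TCP ports run from 0 to 65535, while int16_t tops out at 32767, so larger port numbers would wrap negative. A quick demonstration:

#include <cstdint>
#include <cstdio>

int main() {
  std::int16_t bad = static_cast<std::int16_t>(40000);  // out of range: wraps to -25536 on two's-complement targets
  std::uint16_t good = 40000;                           // fits: uint16_t covers the full 0..65535 port range
  std::printf("%d %u\n", bad, good);
  return 0;
}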

View File

@ -18,7 +18,7 @@
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/kernel.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#endif
@ -63,7 +63,7 @@ void PSContext::SetPSEnable(bool enabled) {
}
bool PSContext::is_ps_mode() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
return true;
}
return ps_enabled_;
@ -74,7 +74,7 @@ void PSContext::Reset() {
is_worker_ = false;
is_pserver_ = false;
is_sched_ = false;
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps_cache_instance.Finalize();
set_cache_enable(false);
@ -83,7 +83,7 @@ void PSContext::Reset() {
}
std::string PSContext::ms_role() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
return role_;
}
if (is_worker_) {
@ -98,21 +98,21 @@ std::string PSContext::ms_role() const {
}
bool PSContext::is_worker() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
return role_ == kEnvRoleOfWorker;
}
return is_worker_;
}
bool PSContext::is_server() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
return role_ == kEnvRoleOfServer;
}
return is_pserver_;
}
bool PSContext::is_scheduler() const {
if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
return role_ == kEnvRoleOfScheduler;
}
return is_sched_;
@ -130,44 +130,44 @@ uint32_t PSContext::ps_rank_id() const { return rank_id_; }
void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
size_t vocab_size) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
#endif
}
void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
size_t cache_vocab_size, size_t embedding_size) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
#endif
}
void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
#endif
}
void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
#endif
}
void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
#endif
}
void PSContext::set_cache_enable(bool cache_enable) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
#endif
}
void PSContext::set_rank_id(uint32_t rank_id) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
ps_cache_instance.set_rank_id(rank_id);
#endif
}
@ -358,7 +358,7 @@ void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }
void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) {
if (reconstruct_secrets_threshold <= 0) {
if (reconstruct_secrets_threshold == 0) {
MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive.";
return;
}
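
The <= 0 to == 0 rewrite follows from the parameter's type: reconstruct_secrets_threshold is a uint64_t, an unsigned value can never be negative, so <= 0 can only ever mean == 0 and static checkers flag the misleading comparison. The same idea in miniature:

#include <cstdint>

// For unsigned types "x <= 0" is equivalent to "x == 0"; writing the
// equality makes the intent explicit and satisfies the checker.
bool IsValidThreshold(uint64_t threshold) { return threshold != 0; }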

View File

@ -136,7 +136,8 @@ bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) {
return true;
}
void Util::DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name) {
void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
const std::string &fused_cnode_name) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());

View File

@ -56,7 +56,8 @@ class Util {
static bool FuseServerCommOps(const pipeline::ResourcePtr &res);
private:
static void DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name);
static void DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
const std::string &fused_cnode_name);
static kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list);
static std::unordered_map<std::string, int64_t> optimizer_to_ids;

View File

@ -32,7 +32,7 @@
#include "utils/utils.h"
#include "frontend/parallel/context.h"
#include "debug/env_config_parser.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
@ -333,7 +333,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
}
add_need_alloc_nodes(input_node);
}
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
bool ps_cache_check = false;
#endif
for (auto &item : need_alloc_nodes) {
@ -346,7 +346,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
continue;
}
DeviceAddressPtr device_address = nullptr;
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
const std::string &param_name = item->fullname_with_scope();
if (ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(INFO) << "Parameter(" << param_name << ")"
@ -1087,7 +1087,7 @@ void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
}
}
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
AnfNodePtr *const first_cache_input_index,
size_t *const first_cache_size) {

View File

@ -817,6 +817,7 @@ def reset_ps_context():
"""
_reset_ps_context()
def set_fl_context(**kwargs):
"""
Set federated learning training mode context.

View File

@ -726,6 +726,7 @@ class Pull(PrimitiveWithInfer):
def infer_dtype(self, key_dtype, weight_dtype):
return mstype.float32
class PullWeight(PrimitiveWithInfer):
"""
Pull weight by its names from server.
@ -751,6 +752,7 @@ class PullWeight(PrimitiveWithInfer):
def infer_dtype(self, weight, name, index):
return mstype.float32
class PushWeight(PrimitiveWithInfer):
"""
Upload weight by its names to server.
@ -776,6 +778,7 @@ class PushWeight(PrimitiveWithInfer):
def infer_dtype(self, weight, ps_key, index):
return mstype.float32
class identity(Primitive):
"""
Makes an identity primitive, used for pynative mode.

View File

@ -31,6 +31,7 @@ _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]
_check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"]
def ps_context():
"""
Get the global _ps_context; if it is not created, create a new one.
@ -226,6 +227,7 @@ def _set_cache_enable(cache_enable):
def _set_rank_id(rank_id):
ps_context().set_rank_id(rank_id)
def _check_value(key, value):
"""
Validate the value for parameter server context keys.