Fix master static check

parent 633e1e49d6
commit a9a0f590e6
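The hunks below apply a handful of recurring static-check fixes: (void) casts on intentionally ignored return values, checked conversions (IntToSize, SizeToInt, LongToSize, IntToUint) in place of bare static_cast, std::shared_ptr parameters passed by const reference, named constants replacing magic numbers such as the 5/100/1000-millisecond sleeps, and #if ((defined ENABLE_CPU) && (!defined _WIN32)) so the preprocessor tests whether a macro is defined instead of evaluating an undefined identifier. A minimal sketch of these patterns, with hypothetical names (kSleepTimeMs, SendResponse, HandleRequest) that are not from this commit:

// Hypothetical sketch only: kSleepTimeMs, SendResponse and HandleRequest are
// illustrative names, not code from this commit.
#include <chrono>
#include <cstddef>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#if ((defined ENABLE_CPU) && (!defined _WIN32))
// "defined MACRO" tests definedness; "#if (ENABLE_CPU && !_WIN32)" would
// evaluate an undefined identifier as 0, which static checkers flag.
#endif

constexpr int kSleepTimeMs = 5;  // named constant instead of a magic number

// Stand-in for the checked converters used in the diff (IntToSize/SizeToInt).
inline size_t IntToSize(int v) { return v < 0 ? 0 : static_cast<size_t>(v); }

bool SendResponse(const std::string &msg) { return !msg.empty(); }

// shared_ptr taken by const reference: avoids a reference-count copy per call.
void HandleRequest(const std::shared_ptr<std::vector<int>> &payload, int iteration) {
  size_t iter = IntToSize(iteration);  // checked conversion, not a bare static_cast
  // Explicitly discard a return value the caller does not need, satisfying
  // the "ignored return value" check.
  (void)SendResponse("iteration=" + std::to_string(iter) +
                     ", payload size=" + std::to_string(payload->size()));
  std::this_thread::sleep_for(std::chrono::milliseconds(kSleepTimeMs));
}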
@@ -20,8 +20,8 @@ namespace mindspore {
 namespace kernel {
 namespace ps {
 void PServerKernel::Shard(std::vector<size_t> *shape, int axis) {
-  (*shape)[axis] =
-    LongToSize(Util::LocalShard(SizeToLong((*shape)[axis]), SizeToLong(rank_id_), SizeToLong(pserver_num_)));
+  (*shape)[IntToSize(axis)] =
+    LongToSize(Util::LocalShard(SizeToLong((*shape)[IntToSize(axis)]), SizeToLong(rank_id_), SizeToLong(pserver_num_)));
 }
 } // namespace ps
 } // namespace kernel

@@ -350,7 +350,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
       size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
     }
     if (AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
       const std::string &param_name = input_node->fullname_with_scope();
       if (ps::ps_cache_instance.IsHashTable(param_name)) {
         continue;

@@ -34,7 +34,7 @@
 #include "debug/anf_ir_dump.h"
 #include "debug/dump_proto.h"
 #include "debug/data_dump/dump_json_parser.h"
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/util.h"
 #include "ps/ps_context.h"
 #endif
@@ -77,7 +77,7 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos
 void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) {
@@ -193,7 +193,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap
   MS_LOG(INFO) << "Bind input output address";
   runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);

-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
   InitPSParamAndOptim(kernel_graph, inputs);
 #endif
 }

@@ -22,7 +22,7 @@
 #include "utils/comm_manager.h"
 #include "utils/scoped_long_running.h"
 #include "pybind_api/ir/tensor_py.h"
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/ps_cache/ps_cache_manager.h"
 #endif

@@ -44,7 +44,7 @@
 #include "debug/common.h"
 #include "utils/trace_base.h"
 #include "frontend/parallel/context.h"
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/ps_cache/ps_cache_manager.h"
 #include "ps/constants.h"
 #include "ps/util.h"
@@ -2483,7 +2483,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {

 void SessionBasic::UnifyMindIR(const KernelGraphPtr &graph) { opt::CommonUnifyMindIROptimization(graph); }

-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
 void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
   if (!ps::PSContext::instance()->is_worker()) {
     return;

@@ -157,10 +157,10 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
 }

 // reconstruct secrets
-bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
-                                           const schema::SendReconstructSecret *reconstruct_secret_req,
-                                           std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
-                                           const std::vector<std::string> &client_list) {
+bool CipherReconStruct::ReconstructSecrets(
+  const int cur_iterator, const std::string &next_req_time, const schema::SendReconstructSecret *reconstruct_secret_req,
+  const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
+  const std::vector<std::string> &client_list) {
   MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
   clock_t start_time = clock();
   if (reconstruct_secret_req == nullptr || reconstruct_secret_resp_builder == nullptr) {
@@ -285,7 +285,7 @@ void CipherReconStruct::ClearReconstructSecrets() {
   MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success";
 }

-void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb,
+void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb,
                                                    const schema::ResponseCode retcode, const std::string &reason,
                                                    const int iteration, const std::string &next_req_time) {
   auto fbs_reason = fbb->CreateString(reason);

@@ -44,11 +44,11 @@ class CipherReconStruct {
   // reconstruct secret mask
   bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
                           const schema::SendReconstructSecret *reconstruct_secret_req,
-                          std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
+                          const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
                           const std::vector<std::string> &client_list);

   // build response code of reconstruct secret.
-  void BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
+  void BuildReconstructSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
                                   const std::string &reason, const int iteration, const std::string &next_req_time);

   // clear the shared memory.

@@ -21,7 +21,7 @@
 namespace mindspore {
 namespace armour {
 bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
-                                std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
+                                const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder,
                                 const string next_req_time) {
   MS_LOG(INFO) << "CipherShares::ShareSecrets START";
   if (share_secrets_req == nullptr) {
@@ -95,7 +95,7 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
 }

 bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
-                              std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
+                              const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder,
                               const std::string &next_req_time) {
   MS_LOG(INFO) << "CipherShares::GetSecrets START";
   clock_t start_time = clock();
@@ -180,7 +180,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
 }

 void CipherShares::BuildGetSecretsRsp(
-  std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
+  const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
   std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) {
   int rsp_retcode = retcode;
   int rsp_iteration = iteration;
@@ -199,7 +199,7 @@ void CipherShares::BuildGetSecretsRsp(
   return;
 }

-void CipherShares::BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
+void CipherShares::BuildShareSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder,
                                         const schema::ResponseCode retcode, const string &reason,
                                         const string &next_req_time, const int iteration) {
   auto rsp_reason = share_secrets_resp_builder->CreateString(reason);

@@ -43,17 +43,19 @@ class CipherShares {

   // handle the client's request of share secrets.
   bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
-                    std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, const string next_req_time);
+                    const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder,
+                    const string next_req_time);
   // handle the client's request of get secrets.
   bool GetSecrets(const schema::GetShareSecrets *get_secrets_req,
-                  std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time);
+                  const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder,
+                  const std::string &next_req_time);

   // build response code of share secrets.
-  void BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
+  void BuildShareSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &share_secrets_resp_builder,
                             const schema::ResponseCode retcode, const string &reason, const string &next_req_time,
                             const int iteration);
   // build response code of get secrets.
-  void BuildGetSecretsRsp(std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
+  void BuildGetSecretsRsp(const std::shared_ptr<fl::server::FBBuilder> &get_secrets_resp_builder,
                           const schema::ResponseCode retcode, const int iteration, std::string next_req_time,
                           std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares);
   // clear the shared memory.

@@ -26,7 +26,7 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
   clock_t start_time = clock();
   std::vector<float> noise;

-  cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
+  (void)cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
   if (noise.size() != cipher_init_->featuremap_) {
     MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
     return false;

@@ -114,7 +114,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size

     std::shared_ptr<std::vector<unsigned char>> recv_str;
     auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
-
     if (!server_node_->CollectiveWait(recv_req_id)) {
       MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
       return false;

@@ -24,6 +24,7 @@
 namespace mindspore {
 namespace fl {
 namespace server {
+constexpr uint32_t kDefaultVirtualNodeNum = 32;
 // To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in
 // server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node.

@@ -104,7 +104,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
   }

   CountResponse count_rsp;
-  count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
+  (void)count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
   if (!count_rsp.result()) {
     MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
     if (reason != nullptr && count_rsp.reason().find(kNetworkError) != std::string::npos) {
@@ -138,8 +138,8 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
   }

   CountReachThresholdResponse count_reach_threshold_rsp;
-  count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
-                                           SizeToInt(query_cnt_enough_rsp_msg->size()));
+  (void)count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
+                                                 SizeToInt(query_cnt_enough_rsp_msg->size()));
   return count_reach_threshold_rsp.is_enough();
 }
 }
@@ -178,7 +178,7 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core:
   }

   CountRequest report_count_req;
-  report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
+  (void)report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
   const std::string &name = report_count_req.name();
   const std::string &id = report_count_req.id();

@@ -228,7 +228,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(
   }

   CountReachThresholdRequest count_reach_threshold_req;
-  count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
+  (void)count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
   const std::string &name = count_reach_threshold_req.name();

   std::unique_lock<std::mutex> lock(mutex_[name]);
@@ -256,7 +256,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core:
   communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message);

   CounterEvent counter_event;
-  counter_event.ParseFromArray(message->data(), SizeToInt(message->len()));
+  (void)counter_event.ParseFromArray(message->data(), SizeToInt(message->len()));
   const auto &type = counter_event.type();
   const auto &name = counter_event.name();

@@ -141,7 +141,7 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
       MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
       return get_metadata_rsp;
     }
-    get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
+    (void)get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
     return get_metadata_rsp;
   }
 }
@@ -165,7 +165,7 @@ bool DistributedMetadataStore::ReInitForScaling() {
 }

 void DistributedMetadataStore::InitHashRing() {
-  router_ = std::make_shared<ConsistentHashRing>(32);
+  router_ = std::make_shared<ConsistentHashRing>(kDefaultVirtualNodeNum);
   MS_EXCEPTION_IF_NULL(router_);
   for (uint32_t i = 0; i < server_num_; i++) {
     bool ret = router_->Insert(i);
@@ -184,7 +184,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
   }

   PBMetadataWithName meta_with_name;
-  meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
+  (void)meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
   const std::string &name = meta_with_name.name();
   MS_LOG(INFO) << "Update metadata for " << name;

@@ -195,7 +195,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
   } else {
     update_meta_rsp_msg = "Success";
   }
-  communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message);
+  (void)communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message);
   return;
 }

@@ -206,14 +206,14 @@ void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps
   }

   GetMetadataRequest get_metadata_req;
-  get_metadata_req.ParseFromArray(message->data(), message->len());
+  (void)get_metadata_req.ParseFromArray(message->data(), message->len());
   const std::string &name = get_metadata_req.name();
   MS_LOG(INFO) << "Getting metadata for " << name;

   std::unique_lock<std::mutex> lock(mutex_[name]);
   PBMetadata stored_meta = metadata_[name];
   std::string getting_meta_rsp_msg = stored_meta.SerializeAsString();
-  communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message);
+  (void)communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message);
   return;
 }

@@ -67,7 +67,7 @@ bool Executor::HandlePush(const std::string &param_name, const UploadData &uploa
   // Push operation needs to wait until the pulling process is done.
   while (!param_aggr->IsPullingDone()) {
     lock.unlock();
-    std::this_thread::sleep_for(std::chrono::milliseconds(5));
+    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
     lock.lock();
   }

@@ -192,7 +192,7 @@ AddressPtr Executor::HandlePull(const std::string &param_name) {
   // Pulling must wait until the optimizing process is done.
   while (!param_aggr->IsOptimizingDone()) {
     lock.unlock();
-    std::this_thread::sleep_for(std::chrono::milliseconds(5));
+    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
     lock.lock();
   }
   AddressPtr addr = param_aggr->Pull();
@@ -314,7 +314,10 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
     param_names_.push_back(param_name);
     param_aggrs_[param_name] = param_aggr;
     parameter_mutex_[param_name];
-    param_aggr->Init(cnode, aggregation_count_);
+    if (!param_aggr->Init(cnode, aggregation_count_)) {
+      MS_LOG(EXCEPTION) << "Initializing parameter aggregator failed for " << param_name;
+      return false;
+    }
     MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
   }
   return true;

@@ -33,6 +33,8 @@
 namespace mindspore {
 namespace fl {
 namespace server {
+constexpr int kThreadSleepTime = 5;
+
 // Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles
 // logics relevant to kernel launching.
 class Executor {

@@ -74,10 +74,10 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
   });

   // The time window for one iteration, which will be used in some round kernels.
-  size_t iteration_time_window =
-    std::accumulate(rounds_.begin(), rounds_.end(), 0, [](size_t total, const std::shared_ptr<Round> &round) {
-      return round->check_timeout() ? total + round->time_window() : total;
-    });
+  size_t iteration_time_window = std::accumulate(rounds_.begin(), rounds_.end(), IntToSize(0),
+                                                 [](size_t total, const std::shared_ptr<Round> &round) {
+                                                   return round->check_timeout() ? total + round->time_window() : total;
+                                                 });
   LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window);
   MS_LOG(INFO) << "Time window for one iteration is " << iteration_time_window;
   return;
@@ -162,7 +162,7 @@ bool Iteration::ReInitForScaling(uint32_t server_num, uint32_t server_rank) {
   return true;
 }

-const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; }
+const std::vector<std::shared_ptr<Round>> &Iteration::rounds() const { return rounds_; }

 bool Iteration::is_last_iteration_valid() const { return is_last_iteration_valid_; }

@@ -182,7 +182,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
   }

   SyncIterationResponse sync_iter_rsp;
-  sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), sync_iter_rsp_msg->size());
+  (void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
   iteration_num_ = sync_iter_rsp.iteration();
   MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is "
                << sync_iter_rsp.iteration();
@@ -196,14 +196,14 @@ void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::Messa
   }

   SyncIterationRequest sync_iter_req;
-  sync_iter_req.ParseFromArray(message->data(), message->len());
+  (void)sync_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
   uint32_t rank = sync_iter_req.rank();
   MS_LOG(INFO) << "Synchronizing iteration request from rank " << rank;

   SyncIterationResponse sync_iter_rsp;
   sync_iter_rsp.set_iteration(iteration_num_);
   std::string sync_iter_rsp_msg = sync_iter_rsp.SerializeAsString();
-  communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message);
+  (void)communicator_->SendResponse(sync_iter_rsp_msg.data(), sync_iter_rsp_msg.size(), message);
 }

 bool Iteration::IsMoveToNextIterRequestReentrant(uint64_t iteration_num) {
@@ -238,11 +238,11 @@ void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps

   NotifyLeaderMoveToNextIterResponse notify_leader_to_next_iter_rsp;
   notify_leader_to_next_iter_rsp.set_result("success");
-  communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(),
-                              notify_leader_to_next_iter_rsp.SerializeAsString().size(), message);
+  (void)communicator_->SendResponse(notify_leader_to_next_iter_rsp.SerializeAsString().data(),
+                                    notify_leader_to_next_iter_rsp.SerializeAsString().size(), message);

   NotifyLeaderMoveToNextIterRequest notify_leader_to_next_iter_req;
-  notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
+  (void)notify_leader_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
   const auto &rank = notify_leader_to_next_iter_req.rank();
   const auto &is_last_iter_valid = notify_leader_to_next_iter_req.is_last_iter_valid();
   const auto &iter_num = notify_leader_to_next_iter_req.iter_num();
@@ -296,7 +296,7 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
     }
     MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success.";
   });
-  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+  std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
   return true;
 }

@@ -306,15 +306,15 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::
   }

   PrepareForNextIterRequest prepare_next_iter_req;
-  prepare_next_iter_req.ParseFromArray(message->data(), message->len());
+  (void)prepare_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
   const auto &reason = prepare_next_iter_req.reason();
   MS_LOG(INFO) << "Prepare next iteration for this rank " << server_node_->rank_id() << ", reason: " << reason;
   PrepareForNextIter();

   PrepareForNextIterResponse prepare_next_iter_rsp;
   prepare_next_iter_rsp.set_result("success");
-  communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(),
-                              prepare_next_iter_rsp.SerializeAsString().size(), message);
+  (void)communicator_->SendResponse(prepare_next_iter_rsp.SerializeAsString().data(),
+                                    prepare_next_iter_rsp.SerializeAsString().size(), message);
 }

 void Iteration::PrepareForNextIter() {
@@ -347,11 +347,11 @@ void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::Mess

   MoveToNextIterResponse proceed_to_next_iter_rsp;
   proceed_to_next_iter_rsp.set_result("success");
-  communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
-                              proceed_to_next_iter_rsp.SerializeAsString().size(), message);
+  (void)communicator_->SendResponse(proceed_to_next_iter_rsp.SerializeAsString().data(),
+                                    proceed_to_next_iter_rsp.SerializeAsString().size(), message);

   MoveToNextIterRequest proceed_to_next_iter_req;
-  proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
+  (void)proceed_to_next_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
   const auto &is_last_iter_valid = proceed_to_next_iter_req.is_last_iter_valid();
   const auto &last_iter_num = proceed_to_next_iter_req.last_iter_num();
   const auto &reason = proceed_to_next_iter_req.reason();
@@ -370,12 +370,12 @@ void Iteration::Next(bool is_iteration_valid, const std::string &reason) {
   if (is_iteration_valid) {
     // Store the model which is successfully aggregated for this iteration.
     const auto &model = Executor::GetInstance().GetModel();
-    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
+    (void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
     MS_LOG(INFO) << "Iteration " << iteration_num_ << " is successfully finished.";
   } else {
     // Store last iteration's model because this iteration is considered as invalid.
     const auto &model = ModelStore::GetInstance().GetModelByIterNum(iteration_num_ - 1);
-    ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
+    (void)ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model);
     MS_LOG(WARNING) << "Iteration " << iteration_num_ << " is invalid. Reason: " << reason;
   }

@@ -405,7 +405,7 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message
   }

   EndLastIterRequest end_last_iter_req;
-  end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
+  (void)end_last_iter_req.ParseFromArray(message->data(), SizeToInt(message->len()));
   const auto &last_iter_num = end_last_iter_req.last_iter_num();
   // If the iteration number is not matched, return error.
   if (last_iter_num != iteration_num_) {
@@ -413,8 +413,8 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message
                          std::to_string(iteration_num_) + ", iteration to be ended is " + std::to_string(last_iter_num);
     EndLastIterResponse end_last_iter_rsp;
     end_last_iter_rsp.set_result(reason);
-    communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
-                                end_last_iter_rsp.SerializeAsString().size(), message);
+    (void)communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
+                                      end_last_iter_rsp.SerializeAsString().size(), message);
     return;
   }

@@ -422,8 +422,8 @@ void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::Message

   EndLastIterResponse end_last_iter_rsp;
   end_last_iter_rsp.set_result("success");
-  communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
-                              end_last_iter_rsp.SerializeAsString().size(), message);
+  (void)communicator_->SendResponse(end_last_iter_rsp.SerializeAsString().data(),
+                                    end_last_iter_rsp.SerializeAsString().size(), message);
 }

 void Iteration::EndLastIter() {

@@ -79,7 +79,7 @@ class Iteration {
   // The server number after scaling is required in some rounds.
   bool ReInitForScaling(uint32_t server_num, uint32_t server_rank);

-  const std::vector<std::shared_ptr<Round>> &rounds();
+  const std::vector<std::shared_ptr<Round>> &rounds() const;

   bool is_last_iteration_valid() const;

@@ -70,10 +70,9 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
   auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
   std::map<std::string, AddressPtr> feature_maps;
   size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
-  size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
+  size_t get_model_iter = IntToSize(get_model_req->iteration());
   const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
   size_t latest_iter_num = iter_to_model.rbegin()->first;

   // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
   if ((current_iter == get_model_iter && latest_iter_num != current_iter) || current_iter == get_model_iter - 1) {
     std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +

@@ -63,10 +63,11 @@ bool PullWeightKernel::Reset() {
   return true;
 }

-void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req) {
+void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
+                                  const schema::RequestPullWeight *pull_weight_req) {
   std::map<std::string, AddressPtr> feature_maps = {};
   size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
-  size_t pull_weight_iter = static_cast<size_t>(pull_weight_req->iteration());
+  size_t pull_weight_iter = IntToSize(pull_weight_req->iteration());
   // The iteration from worker should be the same as server's, otherwise return SucNotReady so that worker could retry.
   if (pull_weight_iter != current_iter) {
     std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) +
@@ -110,7 +111,7 @@ void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::
   return;
 }

-void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
+void PullWeightKernel::BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                                           const std::string &reason, size_t iteration,
                                           const std::map<std::string, AddressPtr> &feature_maps) {
   auto fbs_reason = fbb->CreateString(reason);
@@ -127,7 +128,7 @@ void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const
   schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get()));
   rsp_pull_weight_builder.add_retcode(retcode);
   rsp_pull_weight_builder.add_reason(fbs_reason);
-  rsp_pull_weight_builder.add_iteration(iteration);
+  rsp_pull_weight_builder.add_iteration(SizeToInt(iteration));
   rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector);
   auto rsp_pull_weight = rsp_pull_weight_builder.Finish();
   fbb->Finish(rsp_pull_weight);

@@ -42,9 +42,10 @@ class PullWeightKernel : public RoundKernel {
   bool Reset() override;

  private:
-  void PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req);
-  void BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
-                          size_t iteration, const std::map<std::string, AddressPtr> &feature_maps);
+  void PullWeight(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestPullWeight *pull_weight_req);
+  void BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
+                          const std::string &reason, size_t iteration,
+                          const std::map<std::string, AddressPtr> &feature_maps);

   Executor *executor_;

@@ -67,12 +67,12 @@ void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageH
   return;
 }

-ResultCode PushWeightKernel::PushWeight(std::shared_ptr<FBBuilder> fbb,
+ResultCode PushWeightKernel::PushWeight(const std::shared_ptr<FBBuilder> &fbb,
                                         const schema::RequestPushWeight *push_weight_req) {
   if (fbb == nullptr || push_weight_req == nullptr) {
     return ResultCode::kSuccessAndReturn;
   }
-  size_t iteration = static_cast<size_t>(push_weight_req->iteration());
+  size_t iteration = IntToSize(push_weight_req->iteration());
   size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
   if (iteration != current_iter) {
     std::string reason = "PushWeight iteration number is invalid:" + std::to_string(iteration) +
@@ -123,13 +123,13 @@ std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::R
   return upload_feature_map;
 }

-void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode,
+void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
                                           const std::string &reason, size_t iteration) {
   auto fbs_reason = fbb->CreateString(reason);
   schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
   rsp_push_weight_builder.add_retcode(retcode);
   rsp_push_weight_builder.add_reason(fbs_reason);
-  rsp_push_weight_builder.add_iteration(iteration);
+  rsp_push_weight_builder.add_iteration(SizeToInt(iteration));
   auto rsp_push_weight = rsp_push_weight_builder.Finish();
   fbb->Finish(rsp_push_weight);
   return;

@@ -42,10 +42,10 @@ class PushWeightKernel : public RoundKernel {
   void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

  private:
-  ResultCode PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
+  ResultCode PushWeight(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestPushWeight *push_weight_req);
   std::map<std::string, Address> ParseFeatureMap(const schema::RequestPushWeight *push_weight_req);
-  void BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const schema::ResponseCode retcode, const std::string &reason,
-                          size_t iteration);
+  void BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
+                          const std::string &reason, size_t iteration);

   Executor *executor_;
   uint32_t local_rank_;

@@ -34,7 +34,7 @@ RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), e
     // Detect whether there's any data needs to be released every 100 milliseconds.
     if (heap_data_to_release_.empty()) {
       release_lock.unlock();
-      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+      std::this_thread::sleep_for(std::chrono::milliseconds(kReleaseDuration));
      continue;
     }

@@ -61,9 +61,9 @@ RoundKernel::~RoundKernel() {
   }
 }

-void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
+void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { return; }

-void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
+void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { return; }

 void RoundKernel::StopTimer() const {
   if (stop_timer_cb_) {

@@ -38,6 +38,7 @@ namespace mindspore {
 namespace fl {
 namespace server {
 namespace kernel {
+constexpr uint64_t kReleaseDuration = 100;
 // RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
 // kernels to represent the process. They receive and parse messages from the server communication module. After
 // handling these messages, round kernels allocate response data and send it back.

@@ -118,7 +118,7 @@ bool StartFLJobKernel::Reset() {
 }

 void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
-  iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
+  iter_next_req_timestamp_ = LongToSize(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
   LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
   // The first startFLJob request means a new iteration starts running.
   Iteration::GetInstance().SetIterationRunning();
@@ -220,9 +220,9 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
   schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
   fl_plan_builder.add_fl_name(fbs_fl_name);
   fl_plan_builder.add_server_mode(fbs_server_mode);
-  fl_plan_builder.add_iterations(ps::PSContext::instance()->fl_iteration_num());
-  fl_plan_builder.add_epochs(ps::PSContext::instance()->client_epoch_num());
-  fl_plan_builder.add_mini_batch(ps::PSContext::instance()->client_batch_size());
+  fl_plan_builder.add_iterations(SizeToInt(ps::PSContext::instance()->fl_iteration_num()));
+  fl_plan_builder.add_epochs(SizeToInt(ps::PSContext::instance()->client_epoch_num()));
+  fl_plan_builder.add_mini_batch(SizeToInt(ps::PSContext::instance()->client_batch_size()));
   fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate());
 #ifdef ENABLE_ARMOUR
   fl_plan_builder.add_cipher(cipher_public_params);

@@ -90,7 +90,7 @@ bool UpdateModelKernel::Reset() {
   return true;
 }

-void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
+void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
   if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) {
     while (!executor_->IsAllWeightAggregationDone()) {
       std::this_thread::sleep_for(std::chrono::milliseconds(5));
@@ -120,7 +120,7 @@ ResultCode UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr
 ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
                                           const std::shared_ptr<FBBuilder> &fbb) {
   RETURN_IF_NULL(update_model_req, ResultCode::kSuccessAndReturn);
-  size_t iteration = static_cast<size_t>(update_model_req->iteration());
+  size_t iteration = IntToSize(update_model_req->iteration());
   if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
     std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
                          ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) +

@@ -281,16 +281,16 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<
   KernelParams aggr_params = {};

   const std::vector<std::string> &input_names = aggr_kernel->input_names();
-  std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
-                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+  (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs),
+                       [&](const std::string &name) { return memory_register->addresses()[name]; });

   const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names();
-  std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
-                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+  (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace),
+                       [&](const std::string &name) { return memory_register->addresses()[name]; });

   const std::vector<std::string> &output_names = aggr_kernel->output_names();
-  std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
-                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+  (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
+                       [&](const std::string &name) { return memory_register->addresses()[name]; });

   aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs);
   aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params));
@@ -304,16 +304,16 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
   KernelParams optimizer_params = {};

   const std::vector<std::string> &input_names = optimizer_kernel->input_names();
-  std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
-                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+  (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
+                       [&](const std::string &name) { return memory_register->addresses()[name]; });

   const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names();
-  std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
-                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+  (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
+                       [&](const std::string &name) { return memory_register->addresses()[name]; });

   const std::vector<std::string> &output_names = optimizer_kernel->output_names();
-  std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
-                 [&](const std::string &name) { return memory_register->addresses()[name]; });
+  (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
+                       [&](const std::string &name) { return memory_register->addresses()[name]; });

   optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params));
   return true;

@@ -34,8 +34,8 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo
       threshold_count_(threshold_count),
       server_num_as_threshold_(server_num_as_threshold) {}

-void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
-                       FinishIterCb finish_iteration_cb) {
+void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
+                       const FinishIterCb &finish_iteration_cb) {
   MS_EXCEPTION_IF_NULL(communicator);
   communicator_ = communicator;

@@ -50,7 +50,7 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
   };

   // Callback for finalizing the server. This can only be called once.
-  finalize_cb_ = [&](void) -> void { communicator_->Stop(); };
+  finalize_cb_ = [&](void) -> void { (void)communicator_->Stop(); };

   if (check_timeout_) {
     iter_timer_ = std::make_shared<IterationTimer>();
@@ -116,7 +116,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
   if (Server::GetInstance().IsSafeMode()) {
     MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
     std::string reason = "The cluster is in safemode.";
-    communicator_->SendResponse(reason.c_str(), reason.size(), message);
+    (void)communicator_->SendResponse(reason.c_str(), reason.size(), message);
     return;
   }

@@ -128,10 +128,10 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
   if (output->size == 0) {
     std::string reason = "The output of the round " + name_ + " is empty.";
     MS_LOG(WARNING) << reason;
-    communicator_->SendResponse(reason.c_str(), reason.size(), message);
+    (void)communicator_->SendResponse(reason.c_str(), reason.size(), message);
     return;
   }
-  communicator_->SendResponse(output->addr, output->size, message);
+  (void)communicator_->SendResponse(output->addr, output->size, message);
   kernel_->Release(output);

   // Must send response back no matter what value Launch method returns.
@@ -142,7 +142,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
   return;
 }

-void Round::Reset() { kernel_->Reset(); }
+void Round::Reset() { (void)kernel_->Reset(); }

 const std::string &Round::name() const { return name_; }

@@ -37,8 +37,8 @@ class Round {
         bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
   ~Round() = default;

-  void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
-                  FinishIterCb finish_iteration_cb);
+  void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
+                  const FinishIterCb &finish_iteration_cb);

   // Reinitialize count service and round kernel of this round after scaling operations are done.
   bool ReInitForScaling(uint32_t server_num);

@@ -102,7 +102,7 @@ void Server::CancelSafeMode() {
   safemode_ = false;
 }

-bool Server::IsSafeMode() { return safemode_.load(); }
+bool Server::IsSafeMode() const { return safemode_.load(); }

 void Server::InitServerContext() {
   ps::PSContext::instance()->GenerateResetterRound();
@@ -121,7 +121,7 @@ void Server::InitServerContext() {
 void Server::InitCluster() {
   server_node_ = std::make_shared<ps::core::ServerNode>();
   MS_EXCEPTION_IF_NULL(server_node_);
-  task_executor_ = std::make_shared<ps::core::TaskExecutor>(32);
+  task_executor_ = std::make_shared<ps::core::TaskExecutor>(kExecutorThreadPoolSize);
   MS_EXCEPTION_IF_NULL(task_executor_);
   if (!InitCommunicatorWithServer()) {
     MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
@@ -235,9 +235,9 @@ void Server::InitCipher() {
 #ifdef ENABLE_ARMOUR
   cipher_init_ = &armour::CipherInit::GetInstance();

-  int cipher_t = cipher_reconstruct_secrets_down_cnt_;
+  int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
   unsigned char cipher_p[SECRET_MAX_LEN] = {0};
-  int cipher_g = 1;
+  const int cipher_g = 1;
   unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
   float dp_eps = ps::PSContext::instance()->dp_eps();
   float dp_delta = ps::PSContext::instance()->dp_delta();
@@ -304,8 +304,8 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC
     MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
     safemode_ = true;
     std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                  [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
-    communicator_with_server_->Stop();
+                  [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
+    (void)communicator_with_server_->Stop();
   });

   communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
@@ -314,8 +314,8 @@ void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpC
                      "network building phase.";
     safemode_ = true;
     std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                  [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
-    communicator_with_server_->Stop();
+                  [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
+    (void)communicator_with_server_->Stop();
   });
 }

@@ -363,7 +363,10 @@ void Server::StartCommunicator() {
   }

   MS_LOG(INFO) << "Start communicator with server.";
-  communicator_with_server_->Start();
+  if (!communicator_with_server_->Start()) {
+    MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
+    return;
+  }
   DistributedMetadataStore::GetInstance().Initialize(server_node_);
   CollectiveOpsImpl::GetInstance().Initialize(server_node_);
   DistributedCountService::GetInstance().Initialize(server_node_, kLeaderServerRank);
@@ -371,7 +374,11 @@ void Server::StartCommunicator() {

   MS_LOG(INFO) << "Start communicator with worker.";
   std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Start(); });
+                [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
+                  if (!communicator->Start()) {
+                    MS_LOG(EXCEPTION) << "Starting communicator with worker failed.";
+                  }
+                });
 }

 void Server::ProcessBeforeScalingOut() {
@@ -405,7 +412,7 @@ void Server::ProcessAfterScalingOut() {
   if (!Executor::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "Executor reinitializing failed.";
   }
-  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+  std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
   safemode_ = false;
 }

@@ -418,7 +425,7 @@ void Server::ProcessAfterScalingIn() {
   if (server_node_->rank_id() == UINT32_MAX) {
     MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
     std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                  [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
+                  [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { (void)communicator->Stop(); });
     communicator_with_server_->Stop();
     return;
   }
@@ -439,7 +446,7 @@ void Server::ProcessAfterScalingIn() {
   if (!Executor::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "Executor reinitializing failed.";
   }
-  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+  std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
   safemode_ = false;
 }
 } // namespace server

@@ -33,6 +33,9 @@
 namespace mindspore {
 namespace fl {
 namespace server {
+// The sleeping time of the server thread before the networking is completed.
+constexpr uint32_t kServerSleepTimeForNetworking = 1000;
+
 // Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
 class Server {
  public:
@@ -51,7 +54,7 @@ class Server {

   void SwitchToSafeMode();
   void CancelSafeMode();
-  bool IsSafeMode();
+  bool IsSafeMode() const;

  private:
   Server()
@@ -162,8 +165,6 @@ class Server {
   uint32_t server_num_;
   uint32_t worker_num_;
   uint16_t fl_server_port_;
-  size_t start_fl_job_cnt_;
-  size_t update_model_cnt_;
   size_t cipher_initial_client_cnt_;
   size_t cipher_exchange_secrets_cnt_;
   size_t cipher_share_secrets_cnt_;
@@ -171,9 +172,6 @@ class Server {
   size_t cipher_reconstruct_secrets_up_cnt_;
   size_t cipher_reconstruct_secrets_down_cnt_;
   uint64_t cipher_time_window_;
-
-  float percent_for_update_model_;
-  float percent_for_get_model_;
 };
 } // namespace server
 } // namespace fl

@@ -70,8 +70,14 @@ void FLWorker::Run() {

 void FLWorker::Finalize() {
   MS_EXCEPTION_IF_NULL(worker_node_);
-  worker_node_->Finish();
-  worker_node_->Stop();
+  if (!worker_node_->Finish()) {
+    MS_LOG(ERROR) << "Worker node finishing failed.";
+    return;
+  }
+  if (!worker_node_->Stop()) {
+    MS_LOG(ERROR) << "Worker node stopping failed.";
+    return;
+  }
 }

 bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
@@ -201,8 +207,8 @@ void FLWorker::ProcessAfterScalingOut() {
   }

   MS_LOG(INFO) << "Cluster scaling out completed. Reinitialize for worker.";
-  server_num_ = worker_node_->server_num();
-  worker_num_ = worker_node_->worker_num();
+  server_num_ = IntToUint(worker_node_->server_num());
+  worker_num_ = IntToUint(worker_node_->worker_num());
   MS_LOG(INFO) << "After scheduler scaling out, worker number is " << worker_num_ << ", server number is "
                << server_num_ << ". Exit safemode.";
   std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));
@@ -215,8 +221,8 @@ void FLWorker::ProcessAfterScalingIn() {
   }

   MS_LOG(INFO) << "Cluster scaling in completed. Reinitialize for worker.";
-  server_num_ = worker_node_->server_num();
-  worker_num_ = worker_node_->worker_num();
+  server_num_ = IntToUint(worker_node_->server_num());
+  worker_num_ = IntToUint(worker_node_->worker_num());
   MS_LOG(INFO) << "After scheduler scaling in, worker number is " << worker_num_ << ", server number is " << server_num_
                << ". Exit safemode.";
   std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerSleepTimeForNetworking));

@@ -25,7 +25,7 @@
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/context.h"
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/ps_cache/ps_cache_manager.h"
 #include "utils/ms_context.h"
 #endif
@@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() {
   if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
     dynamic_shape_indices_ = true;
   }
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
   bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE);
   if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) {
@@ -637,7 +637,7 @@ Status GatherPInfo::InferBias() {
       rank = rank % (params_strategy[0] * params_strategy[1]);
     }
   }
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
     bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
     return SUCCESS;

@@ -28,7 +28,7 @@
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/context.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/ps_cache/ps_cache_manager.h"
 #endif

@@ -119,7 +119,7 @@ std::vector<StrategyPtr> UniqueInfo::GenerateOpStrategies(int64_t stage_id) {
   return sp_vector;
 }

-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
 Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   GenerateGraph gen_g = GenerateGraph(attrs_);
   if (gen_g.Init(cnode) != SUCCESS) {
@@ -156,7 +156,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 #endif

 ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
   if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
     auto inputs = cnode->inputs();
     if (inputs.empty()) {

@@ -47,7 +47,7 @@
 #include "utils/ms_context.h"
 #include "utils/symbolic.h"
 #include "mindspore/core/utils/parallel_node_check.h"
-#if (ENABLE_CPU && !_WIN32)
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
 #include "ps/util.h"
 #include "ps/ps_context.h"
 #endif

@@ -44,7 +44,7 @@
#include "vm/transform.h"
#include "parse/python_adapter.h"
#include "frontend/optimizer/py_pass_manager.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/parameter_server.h"
#include "ps/scheduler.h"
#include "ps/worker.h"

@@ -478,7 +478,7 @@ bool OptInlineAction(const ResourcePtr &res) {
bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }

bool VmOptimizeAction(const ResourcePtr &res) {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  if (ps::PSContext::instance()->is_ps_mode()) {
    kVmPasses.push_back({"server_communication_op_fusion", ps::Util::FuseServerCommOps});
  }

@@ -633,8 +633,8 @@ bool ExecuteAction(const ResourcePtr &res) {
  return true;
}

#if (ENABLE_CPU && !_WIN32)
bool StartPSWorkerAction(const ResourcePtr &res) {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
bool StartPSWorkerAction(const ResourcePtr &) {
  ps::Worker::GetInstance().Run();
  return true;
}

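Dropping the parameter name in `bool StartPSWorkerAction(const ResourcePtr &)` (and in `StartPSSchedulerAction` below) is the standard way to keep the signature the action pipeline expects while telling the compiler, and the static checker, that the argument is intentionally unused. A sketch of the pattern, with hypothetical types:

struct Resource {};
using Action = bool (*)(const Resource &);

// The unnamed parameter documents "required by the signature, not used here"
// and silences unused-parameter warnings without casts or attributes.
bool NoOpAction(const Resource &) { return true; }

Action action = NoOpAction;  // still matches the expected callback type
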
@@ -695,7 +695,7 @@ bool StartServerAction(const ResourcePtr &res) {
  return true;
}

bool StartPSSchedulerAction(const ResourcePtr &res) {
bool StartPSSchedulerAction(const ResourcePtr &) {
  ps::Scheduler::GetInstance().Run();
  return true;
}

@@ -861,7 +861,7 @@ std::vector<ActionItem> VmPipeline() {
  actions.emplace_back(std::make_pair("remove_monad_from_random_op", RemoveRandomOpMonadAction));

  actions.emplace_back(std::make_pair("validate", ValidateAction));
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  if (ps::PSContext::instance()->is_worker()) {
    std::string server_mode = ps::PSContext::instance()->server_mode();
    if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {

@@ -889,7 +889,7 @@ std::vector<ActionItem> BackendPipeline() {
  return actions;
}

#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
std::vector<ActionItem> ServerPipeline() {
  auto actions = CommonPipeline();
  actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));

@@ -34,7 +34,7 @@
#else
#include "runtime/device/gpu/distribution/collective_fake_init.h"
#endif
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/util.h"
#endif
#include "ps/ps_context.h"

@@ -47,7 +47,7 @@
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/util.h"
#include "ps/ps_context.h"
#endif

@@ -211,7 +211,7 @@ namespace {
bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); }

bool parallel_mode() {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
    return false;
  }

@@ -556,7 +556,7 @@ bool AddRecomputationPass(const ResourcePtr &res) {
}

bool AddCacheEmbeddingPass(const ResourcePtr &res) {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  if (ps::PSContext::instance()->is_ps_mode()) {
    return true;
  }

@@ -148,7 +148,7 @@ std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::str
}

std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateTcpComm(const std::string &scheduler_ip,
                                                                 std::int16_t scheduler_port, uint32_t worker_num,
                                                                 uint16_t scheduler_port, uint32_t worker_num,
                                                                 uint32_t server_num,
                                                                 const std::shared_ptr<TaskExecutor> &task_executor) {
  std::lock_guard<std::mutex> lock(communicator_mutex_);

@@ -61,7 +61,7 @@ class ServerNode : public AbstractNode {

  std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, uint16_t port,
                                                        const std::shared_ptr<TaskExecutor> &task_executor);
  std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port,
  std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, uint16_t scheduler_port,
                                                       uint32_t worker_num, uint32_t server_num,
                                                       const std::shared_ptr<TaskExecutor> &task_executor);

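The `std::int16_t` to `uint16_t` change for `scheduler_port`, applied to both the definition and the declaration above, is a correctness fix rather than style: TCP ports span 0-65535, and a signed 16-bit integer tops out at 32767, so any port in the upper half of the range would be mangled by the parameter type. A quick illustration:

#include <cstdint>
#include <iostream>

int main() {
  uint16_t port = 50051;  // fits: uint16_t covers the full 0-65535 port range
  // int16_t narrow = 50051;  // would not fit: int16_t stops at 32767, so the
  //                          // value would silently convert to a negative number
  std::cout << port << '\n';
  return 0;
}
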
@@ -18,7 +18,7 @@
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/kernel.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#endif

@@ -63,7 +63,7 @@ void PSContext::SetPSEnable(bool enabled) {
}

bool PSContext::is_ps_mode() const {
  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
  if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
    return true;
  }
  return ps_enabled_;

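The `&& ps_enabled_` conjunction added throughout `PSContext` makes the FL/Hybrid special case take effect only when the parameter server is actually enabled; previously, merely setting the server mode was enough to flip these accessors. (For `is_ps_mode()` itself the guarded branch now returns the same value as the final `return ps_enabled_;`, so that accessor reduces to the flag, but the branch keeps it symmetric with the role accessors below, where the guard does change the result.) A condensed sketch of the pattern, with hypothetical member names:

#include <string>

struct Ctx {
  std::string server_mode_;  // e.g. "FEDERATED_LEARNING" or "HYBRID_TRAINING"
  std::string role_;         // authoritative role in FL/Hybrid deployments
  bool ps_enabled_ = false;
  bool is_worker_ = false;   // legacy flag used outside FL/Hybrid

  bool fl_like() const {
    return server_mode_ == "FEDERATED_LEARNING" || server_mode_ == "HYBRID_TRAINING";
  }

  bool is_worker() const {
    // Only trust the FL/Hybrid role when PS is enabled; otherwise fall back
    // to the legacy flag, mirroring the change in this commit.
    if (fl_like() && ps_enabled_) {
      return role_ == "MS_WORKER";
    }
    return is_worker_;
  }
};
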
@@ -74,7 +74,7 @@ void PSContext::Reset() {
  is_worker_ = false;
  is_pserver_ = false;
  is_sched_ = false;
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    ps_cache_instance.Finalize();
    set_cache_enable(false);

@@ -83,7 +83,7 @@ void PSContext::Reset() {
}

std::string PSContext::ms_role() const {
  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
  if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
    return role_;
  }
  if (is_worker_) {

@@ -98,21 +98,21 @@ std::string PSContext::ms_role() const {
}

bool PSContext::is_worker() const {
  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
  if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
    return role_ == kEnvRoleOfWorker;
  }
  return is_worker_;
}

bool PSContext::is_server() const {
  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
  if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
    return role_ == kEnvRoleOfServer;
  }
  return is_pserver_;
}

bool PSContext::is_scheduler() const {
  if (server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) {
  if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
    return role_ == kEnvRoleOfScheduler;
  }
  return is_sched_;

@@ -130,44 +130,44 @@ uint32_t PSContext::ps_rank_id() const { return rank_id_; }

void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
                                    size_t vocab_size) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
#endif
}

void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
                                      size_t cache_vocab_size, size_t embedding_size) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
#endif
}

void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
#endif
}

void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
#endif
}

void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
#endif
}

void PSContext::set_cache_enable(bool cache_enable) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
#endif
}

void PSContext::set_rank_id(uint32_t rank_id) const {
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  ps_cache_instance.set_rank_id(rank_id);
#endif
}

@@ -358,7 +358,7 @@ void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }

void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) {
  if (reconstruct_secrets_threshold <= 0) {
  if (reconstruct_secrets_threshold == 0) {
    MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive.";
    return;
  }

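Since `reconstruct_secrets_threshold` is a `uint64_t`, the old `<= 0` test could only ever fire at exactly 0, and static checkers flag such comparisons as misleading; `== 0` states the real condition. For example:

#include <cstdint>

bool rejects(uint64_t threshold) {
  // For an unsigned type, `threshold <= 0` and `threshold == 0` are the same
  // predicate; the former merely looks like it also guards against negative
  // values, which cannot exist here.
  return threshold == 0;
}
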
@@ -136,7 +136,8 @@ bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) {
  return true;
}

void Util::DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name) {
void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
                    const std::string &fused_cnode_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());

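Taking `FuncGraphPtr` (a `std::shared_ptr`) by const reference, as the new `DoFusion` signature does, skips an atomic reference-count increment and decrement on every call; that is the right trade whenever the callee only reads the graph and does not need to extend its lifetime. A minimal sketch:

#include <memory>

struct FuncGraph {};
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

// By const reference: no refcount traffic on entry/exit. The null check
// mirrors the MS_EXCEPTION_IF_NULL guard above, since a reference to a
// shared_ptr can still hold nullptr.
void Inspect(const FuncGraphPtr &g) {
  if (g == nullptr) {
    return;
  }
  // ... walk the graph ...
}
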
@@ -56,7 +56,8 @@ class Util {
  static bool FuseServerCommOps(const pipeline::ResourcePtr &res);

 private:
  static void DoFusion(FuncGraphPtr func_graph, const std::string &cnode_name, const std::string &fused_cnode_name);
  static void DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
                       const std::string &fused_cnode_name);
  static kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list);

  static std::unordered_map<std::string, int64_t> optimizer_to_ids;

@@ -32,7 +32,7 @@
#include "utils/utils.h"
#include "frontend/parallel/context.h"
#include "debug/env_config_parser.h"
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#endif

@@ -333,7 +333,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
    }
    add_need_alloc_nodes(input_node);
  }
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  bool ps_cache_check = false;
#endif
  for (auto &item : need_alloc_nodes) {

@@ -346,7 +346,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
      continue;
    }
    DeviceAddressPtr device_address = nullptr;
#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
    const std::string &param_name = item->fullname_with_scope();
    if (ps::ps_cache_instance.IsHashTable(param_name)) {
      MS_LOG(INFO) << "Parameter(" << param_name << ")"

@@ -1087,7 +1087,7 @@ void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
  }
}

#if (ENABLE_CPU && !_WIN32)
#if ((defined ENABLE_CPU) && (!defined _WIN32))
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
                                             AnfNodePtr *const first_cache_input_index,
                                             size_t *const first_cache_size) {

@@ -817,6 +817,7 @@ def reset_ps_context():
    """
    _reset_ps_context()


def set_fl_context(**kwargs):
    """
    Set federated learning training mode context.

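The remaining hunks are Python-side formatting fixes: each adds the second blank line before a top-level `def` or `class`, satisfying the two-blank-line rule that PEP 8 style checkers (pycodestyle E302) enforce between top-level definitions.
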
@@ -726,6 +726,7 @@ class Pull(PrimitiveWithInfer):
    def infer_dtype(self, key_dtype, weight_dtype):
        return mstype.float32


class PullWeight(PrimitiveWithInfer):
    """
    Pull weight by its names from server.

@@ -751,6 +752,7 @@ class PullWeight(PrimitiveWithInfer):
    def infer_dtype(self, weight, name, index):
        return mstype.float32


class PushWeight(PrimitiveWithInfer):
    """
    Upload weight by its names to server.

@@ -776,6 +778,7 @@ class PushWeight(PrimitiveWithInfer):
    def infer_dtype(self, weight, ps_key, index):
        return mstype.float32


class identity(Primitive):
    """
    Makes a identify primitive, used for pynative mode.

@@ -31,6 +31,7 @@ _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]

_check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"]


def ps_context():
    """
    Get the global _ps_context, if it is not created, create a new one.

@@ -226,6 +227,7 @@ def _set_cache_enable(cache_enable):
def _set_rank_id(rank_id):
    ps_context().set_rank_id(rank_id)


def _check_value(key, value):
    """
    Validate the value for parameter server context keys.