!19689 Fix fl namespace issue.

Merge pull request !19689 from ZPaC/fix-namespace
This commit is contained in:
i-robot 2021-07-09 08:02:48 +00:00 committed by Gitee
commit d76bb99d8a
85 changed files with 575 additions and 566 deletions

View File

@ -47,13 +47,13 @@ class FusedPullWeightKernel : public CPUKernel {
return false; return false;
} }
std::shared_ptr<ps::FBBuilder> fbb = std::make_shared<ps::FBBuilder>(); std::shared_ptr<fl::FBBuilder> fbb = std::make_shared<fl::FBBuilder>();
MS_EXCEPTION_IF_NULL(fbb); MS_EXCEPTION_IF_NULL(fbb);
total_iteration_++; total_iteration_++;
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server. // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
if (total_iteration_ % ps::worker::FLWorker::GetInstance().worker_step_num_per_iteration() != if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
ps::kTrainBeginStepNum) { fl::kTrainBeginStepNum) {
return true; return true;
} }
@ -72,10 +72,10 @@ class FusedPullWeightKernel : public CPUKernel {
const schema::ResponsePullWeight *pull_weight_rsp = nullptr; const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
int retcode = schema::ResponseCode_SucNotReady; int retcode = schema::ResponseCode_SucNotReady;
while (retcode == schema::ResponseCode_SucNotReady) { while (retcode == schema::ResponseCode_SucNotReady) {
if (!ps::worker::FLWorker::GetInstance().SendToServer( if (!fl::worker::FLWorker::GetInstance().SendToServer(
0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) { 0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. This iteration is dropped."; MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. This iteration is dropped.";
ps::worker::FLWorker::GetInstance().SetIterationRunning(); fl::worker::FLWorker::GetInstance().SetIterationRunning();
return true; return true;
} }
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg); MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
@ -116,7 +116,7 @@ class FusedPullWeightKernel : public CPUKernel {
} }
} }
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_; MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
ps::worker::FLWorker::GetInstance().SetIterationRunning(); fl::worker::FLWorker::GetInstance().SetIterationRunning();
return true; return true;
} }
@ -154,7 +154,7 @@ class FusedPullWeightKernel : public CPUKernel {
void InitSizeLists() { return; } void InitSizeLists() { return; }
private: private:
bool BuildPullWeightReq(std::shared_ptr<ps::FBBuilder> fbb) { bool BuildPullWeightReq(std::shared_ptr<fl::FBBuilder> fbb) {
MS_EXCEPTION_IF_NULL(fbb); MS_EXCEPTION_IF_NULL(fbb);
std::vector<flatbuffers::Offset<flatbuffers::String>> fbs_weight_names; std::vector<flatbuffers::Offset<flatbuffers::String>> fbs_weight_names;
for (const std::string &weight_name : weight_full_names_) { for (const std::string &weight_name : weight_full_names_) {

View File

@ -45,13 +45,13 @@ class FusedPushWeightKernel : public CPUKernel {
return false; return false;
} }
std::shared_ptr<ps::FBBuilder> fbb = std::make_shared<ps::FBBuilder>(); std::shared_ptr<fl::FBBuilder> fbb = std::make_shared<fl::FBBuilder>();
MS_EXCEPTION_IF_NULL(fbb); MS_EXCEPTION_IF_NULL(fbb);
total_iteration_++; total_iteration_++;
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server. // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
if (total_iteration_ % ps::worker::FLWorker::GetInstance().worker_step_num_per_iteration() != if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
ps::kTrainBeginStepNum) { fl::kTrainBeginStepNum) {
return true; return true;
} }
@ -67,17 +67,17 @@ class FusedPushWeightKernel : public CPUKernel {
} }
// The server number may change after scaling in/out. // The server number may change after scaling in/out.
for (uint32_t i = 0; i < ps::worker::FLWorker::GetInstance().server_num(); i++) { for (uint32_t i = 0; i < fl::worker::FLWorker::GetInstance().server_num(); i++) {
std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr; std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr;
const schema::ResponsePushWeight *push_weight_rsp = nullptr; const schema::ResponsePushWeight *push_weight_rsp = nullptr;
int retcode = schema::ResponseCode_SucNotReady; int retcode = schema::ResponseCode_SucNotReady;
while (retcode == schema::ResponseCode_SucNotReady) { while (retcode == schema::ResponseCode_SucNotReady) {
if (!ps::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(), if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
ps::core::TcpUserCommand::kPushWeight, ps::core::TcpUserCommand::kPushWeight,
&push_weight_rsp_msg)) { &push_weight_rsp_msg)) {
MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i
<< " failed. This iteration is dropped."; << " failed. This iteration is dropped.";
ps::worker::FLWorker::GetInstance().SetIterationCompleted(); fl::worker::FLWorker::GetInstance().SetIterationCompleted();
return true; return true;
} }
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg); MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
@ -105,7 +105,7 @@ class FusedPushWeightKernel : public CPUKernel {
} }
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_; MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
ps::worker::FLWorker::GetInstance().SetIterationCompleted(); fl::worker::FLWorker::GetInstance().SetIterationCompleted();
return true; return true;
} }
@ -143,7 +143,7 @@ class FusedPushWeightKernel : public CPUKernel {
void InitSizeLists() { return; } void InitSizeLists() { return; }
private: private:
bool BuildPushWeightReq(std::shared_ptr<ps::FBBuilder> fbb, const std::vector<AddressPtr> &weights) { bool BuildPushWeightReq(std::shared_ptr<fl::FBBuilder> fbb, const std::vector<AddressPtr> &weights) {
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps; std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
for (size_t i = 0; i < weight_full_names_.size(); i++) { for (size_t i = 0; i < weight_full_names_.size(); i++) {
const std::string &weight_name = weight_full_names_[i]; const std::string &weight_name = weight_full_names_[i];

View File

@ -31,8 +31,8 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
int return_num = 0; int return_num = 0;
cipher_meta_storage_.RegisterClass(); cipher_meta_storage_.RegisterClass();
const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN); const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN);
cipher_meta_storage_.RegisterPrime(ps::server::kCtxCipherPrimer, new_prime); cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime);
if (!cipher_meta_storage_.GetPrimeFromServer(ps::server::kCtxCipherPrimer, publicparam_.prime)) { if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) {
MS_LOG(ERROR) << "Cipher Param Update is invalid."; MS_LOG(ERROR) << "Cipher Param Update is invalid.";
return false; return false;
} }
@ -45,7 +45,7 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
publicparam_.t = param.t; publicparam_.t = param.t;
secrets_minnums_ = param.t; secrets_minnums_ = param.t;
client_num_need_ = cipher_initial_client_cnt; client_num_need_ = cipher_initial_client_cnt;
featuremap_ = ps::server::ModelStore::GetInstance().model_size() / sizeof(float); featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float);
share_clients_num_need_ = cipher_share_secrets_cnt; share_clients_num_need_ = cipher_share_secrets_cnt;
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1; reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
get_model_num_need_ = cipher_get_clientlist_cnt; get_model_num_need_ = cipher_get_clientlist_cnt;

View File

@ -21,7 +21,7 @@ namespace mindspore {
namespace armour { namespace armour {
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time, bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req, const schema::GetExchangeKeys *get_exchange_keys_req,
std::shared_ptr<ps::server::FBBuilder> get_exchange_keys_resp_builder) { std::shared_ptr<fl::server::FBBuilder> get_exchange_keys_resp_builder) {
MS_LOG(INFO) << "CipherMgr::GetKeys START"; MS_LOG(INFO) << "CipherMgr::GetKeys START";
if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) { if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr."; MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
@ -32,7 +32,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
// get clientlist from memory server. // get clientlist from memory server.
std::vector<std::string> clients; std::vector<std::string> clients;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList, &clients); cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &clients);
size_t cur_clients_num = clients.size(); size_t cur_clients_num = clients.size();
std::string fl_id = get_exchange_keys_req->fl_id()->str(); std::string fl_id = get_exchange_keys_req->fl_id()->str();
@ -61,7 +61,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time, bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req, const schema::RequestExchangeKeys *exchange_keys_req,
std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder) { std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder) {
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START"; MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
// step 0: judge if the input param is legal. // step 0: judge if the input param is legal.
if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) { if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) {
@ -75,8 +75,8 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
// step 1: get clientlist and client keys from memory server. // step 1: get clientlist and client keys from memory server.
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys; std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
std::vector<std::string> client_list; std::vector<std::string> client_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList, &client_list); cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys); cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
// step2: process new item data. and update new item data to memory server. // step2: process new item data. and update new item data to memory server.
size_t cur_clients_num = client_list.size(); size_t cur_clients_num = client_list.size();
@ -131,9 +131,9 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
cur_public_key.push_back(spk); cur_public_key.push_back(spk);
bool retcode_key = bool retcode_key =
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(ps::server::kCtxClientsKeys, fl_id, cur_public_key); cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, fl_id, cur_public_key);
bool retcode_client = bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxExChangeKeysClientList, fl_id); cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
if (retcode_key && retcode_client) { if (retcode_key && retcode_client) {
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success"; MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
@ -147,7 +147,7 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
} }
} }
void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder, void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder,
const schema::ResponseCode retcode, const std::string &reason, const schema::ResponseCode retcode, const std::string &reason,
const std::string &next_req_time, const int iteration) { const std::string &next_req_time, const int iteration) {
auto rsp_reason = exchange_keys_resp_builder->CreateString(reason); auto rsp_reason = exchange_keys_resp_builder->CreateString(reason);
@ -162,7 +162,7 @@ void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exc
return; return;
} }
bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode, bool CipherKeys::BuildGetKeys(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good) { const int iteration, const std::string &next_req_time, bool is_good) {
bool flag = true; bool flag = true;
if (is_good) { if (is_good) {
@ -170,7 +170,7 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list; std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
MS_LOG(INFO) << "Get Keys: "; MS_LOG(INFO) << "Get Keys: ";
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys; std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys); cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
if (record_public_keys.size() < cipher_init_->client_num_need_) { if (record_public_keys.size() < cipher_init_->client_num_need_) {
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size() MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
<< "clients num: " << cipher_init_->client_num_need_; << "clients num: " << cipher_init_->client_num_need_;
@ -221,8 +221,8 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
} }
void CipherKeys::ClearKeys() { void CipherKeys::ClearKeys() {
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxExChangeKeysClientList); fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxExChangeKeysClientList);
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsKeys); fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsKeys);
} }
} // namespace armour } // namespace armour

View File

@ -45,18 +45,18 @@ class CipherKeys {
// handle the client's request of get keys. // handle the client's request of get keys.
bool GetKeys(const int cur_iterator, const std::string &next_req_time, bool GetKeys(const int cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req, const schema::GetExchangeKeys *get_exchange_keys_req,
std::shared_ptr<ps::server::FBBuilder> get_exchange_keys_resp_builder); std::shared_ptr<fl::server::FBBuilder> get_exchange_keys_resp_builder);
// handle the client's request of exchange keys. // handle the client's request of exchange keys.
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time, bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req, const schema::RequestExchangeKeys *exchange_keys_req,
std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder); std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder);
// build response code of get keys. // build response code of get keys.
bool BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode, const int iteration, bool BuildGetKeys(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode, const int iteration,
const std::string &next_req_time, bool is_good); const std::string &next_req_time, bool is_good);
// build response code of exchange keys. // build response code of exchange keys.
void BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder, void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder,
const schema::ResponseCode retcode, const std::string &reason, const schema::ResponseCode retcode, const std::string &reason,
const std::string &next_req_time, const int iteration); const std::string &next_req_time, const int iteration);
// clear the shared memory. // clear the shared memory.

View File

@ -21,16 +21,16 @@ namespace armour {
void CipherMetaStorage::GetClientSharesFromServer( void CipherMetaStorage::GetClientSharesFromServer(
const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) { const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) {
const ps::PBMetadata &clients_shares_pb_out = const fl::PBMetadata &clients_shares_pb_out =
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const ps::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares(); const fl::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
auto iter = clients_shares_pb.client_secret_shares().begin(); auto iter = clients_shares_pb.client_secret_shares().begin();
for (; iter != clients_shares_pb.client_secret_shares().end(); ++iter) { for (; iter != clients_shares_pb.client_secret_shares().end(); ++iter) {
std::string fl_id = iter->first; std::string fl_id = iter->first;
const ps::SharesPb &shares_pb = iter->second; const fl::SharesPb &shares_pb = iter->second;
std::vector<clientshare_str> encrpted_shares_new; std::vector<clientshare_str> encrpted_shares_new;
for (int index_shares = 0; index_shares < shares_pb.clientsharestrs_size(); ++index_shares) { for (int index_shares = 0; index_shares < shares_pb.clientsharestrs_size(); ++index_shares) {
const ps::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares); const fl::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
clientshare_str new_clientshare; clientshare_str new_clientshare;
new_clientshare.fl_id = client_share_str_pb.fl_id(); new_clientshare.fl_id = client_share_str_pb.fl_id();
new_clientshare.index = client_share_str_pb.index(); new_clientshare.index = client_share_str_pb.index();
@ -42,8 +42,8 @@ void CipherMetaStorage::GetClientSharesFromServer(
} }
void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) { void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) {
const ps::PBMetadata &client_list_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); const fl::PBMetadata &client_list_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const ps::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list(); const fl::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) { for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
std::string fl_id = client_list_pb.fl_id(i); std::string fl_id = client_list_pb.fl_id(i);
clients_list->push_back(fl_id); clients_list->push_back(fl_id);
@ -52,14 +52,14 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect
void CipherMetaStorage::GetClientKeysFromServer( void CipherMetaStorage::GetClientKeysFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) { const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) {
const ps::PBMetadata &clients_keys_pb_out = const fl::PBMetadata &clients_keys_pb_out =
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const ps::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys(); const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) { for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
// const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i); // const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
std::string fl_id = iter->first; std::string fl_id = iter->first;
ps::KeysPb keys_pb = iter->second; fl::KeysPb keys_pb = iter->second;
std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end()); std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end()); std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
std::vector<std::vector<unsigned char>> cur_keys; std::vector<std::vector<unsigned char>> cur_keys;
@ -70,9 +70,9 @@ void CipherMetaStorage::GetClientKeysFromServer(
} }
bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) { bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) {
const ps::PBMetadata &clients_noises_pb_out = const fl::PBMetadata &clients_noises_pb_out =
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const ps::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises(); const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
while (clients_noises_pb.has_one_client_noises() == false) { while (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "GetClientNoisesFromServer NULL."; MS_LOG(INFO) << "GetClientNoisesFromServer NULL.";
std::this_thread::sleep_for(std::chrono::milliseconds(50)); std::this_thread::sleep_for(std::chrono::milliseconds(50));
@ -83,8 +83,8 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
} }
bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) { bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) {
const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name); const fl::PBMetadata &prime_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
ps::Prime prime_pb(prime_pb_out.prime()); fl::Prime prime_pb(prime_pb_out.prime());
std::string str = *(prime_pb.mutable_prime()); std::string str = *(prime_pb.mutable_prime());
MS_LOG(INFO) << "get prime from metastorage :" << str; MS_LOG(INFO) << "get prime from metastorage :" << str;
@ -99,20 +99,20 @@ bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) { bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
bool retcode = true; bool retcode = true;
ps::FLId fl_id_pb; fl::FLId fl_id_pb;
fl_id_pb.set_fl_id(fl_id); fl_id_pb.set_fl_id(fl_id);
ps::PBMetadata client_pb; fl::PBMetadata client_pb;
client_pb.mutable_fl_id()->MergeFrom(fl_id_pb); client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb); retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
return retcode; return retcode;
} }
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) { void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
MS_LOG(INFO) << "register prime: " << prime; MS_LOG(INFO) << "register prime: " << prime;
ps::Prime prime_id_pb; fl::Prime prime_id_pb;
prime_id_pb.set_prime(prime); prime_id_pb.set_prime(prime);
ps::PBMetadata prime_pb; fl::PBMetadata prime_pb;
prime_pb.mutable_prime()->MergeFrom(prime_id_pb); prime_pb.mutable_prime()->MergeFrom(prime_id_pb);
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb); fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
} }
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id, bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
@ -123,25 +123,25 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std
return false; return false;
} }
// update new item to memory server. // update new item to memory server.
ps::KeysPb keys; fl::KeysPb keys;
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end()); keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end()); keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
ps::PairClientKeys pair_client_keys_pb; fl::PairClientKeys pair_client_keys_pb;
pair_client_keys_pb.set_fl_id(fl_id); pair_client_keys_pb.set_fl_id(fl_id);
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys); pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
ps::PBMetadata client_and_keys_pb; fl::PBMetadata client_and_keys_pb;
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb); client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb); retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
return retcode; return retcode;
} }
bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) { bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) {
// update new item to memory server. // update new item to memory server.
ps::OneClientNoises noises_pb; fl::OneClientNoises noises_pb;
*noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()}; *noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
ps::PBMetadata client_noises_pb; fl::PBMetadata client_noises_pb;
client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb); client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
return ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb); return fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
} }
bool CipherMetaStorage::UpdateClientShareToServer( bool CipherMetaStorage::UpdateClientShareToServer(
@ -149,10 +149,10 @@ bool CipherMetaStorage::UpdateClientShareToServer(
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) { const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
bool retcode = true; bool retcode = true;
int size_shares = shares->size(); int size_shares = shares->size();
ps::SharesPb shares_pb; fl::SharesPb shares_pb;
for (int index = 0; index < size_shares; ++index) { for (int index = 0; index < size_shares; ++index) {
// new item // new item
ps::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs(); fl::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
std::string fl_id_new = (*shares)[index]->fl_id()->str(); std::string fl_id_new = (*shares)[index]->fl_id()->str();
int index_new = (*shares)[index]->index(); int index_new = (*shares)[index]->index();
auto share = (*shares)[index]->share(); auto share = (*shares)[index]->share();
@ -160,32 +160,32 @@ bool CipherMetaStorage::UpdateClientShareToServer(
client_share_str_new_p->set_fl_id(fl_id_new); client_share_str_new_p->set_fl_id(fl_id_new);
client_share_str_new_p->set_index(index_new); client_share_str_new_p->set_index(index_new);
} }
ps::PairClientShares pair_client_shares_pb; fl::PairClientShares pair_client_shares_pb;
pair_client_shares_pb.set_fl_id(fl_id); pair_client_shares_pb.set_fl_id(fl_id);
pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb); pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
ps::PBMetadata client_and_shares_pb; fl::PBMetadata client_and_shares_pb;
client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb); client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb); retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
return retcode; return retcode;
} }
// Registers a default-constructed PBMetadata with the DistributedMetadataStore singleton for each
// of six context keys, in this order: kCtxExChangeKeysClientList, kCtxClientsKeys,
// kCtxReconstructClientList, kCtxClientsReconstructShares, kCtxShareSecretsClientList and
// kCtxClientsEncryptedShares. Takes no arguments and returns nothing; the result of each
// RegisterMetadata call is not inspected here.
// NOTE(review): this text is a side-by-side diff rendering, so every line below holds the old
// `ps::`-qualified statement followed by the new `fl::`-qualified one.
void CipherMetaStorage::RegisterClass() { void CipherMetaStorage::RegisterClass() {
ps::PBMetadata exchange_kyes_client_list; fl::PBMetadata exchange_kyes_client_list;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxExChangeKeysClientList, fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
exchange_kyes_client_list); exchange_kyes_client_list);
ps::PBMetadata clients_keys; fl::PBMetadata clients_keys;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsKeys, clients_keys); fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
ps::PBMetadata reconstruct_client_list; fl::PBMetadata reconstruct_client_list;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxReconstructClientList, fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxReconstructClientList,
reconstruct_client_list); reconstruct_client_list);
ps::PBMetadata clients_reconstruct_shares; fl::PBMetadata clients_reconstruct_shares;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsReconstructShares, fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsReconstructShares,
clients_reconstruct_shares); clients_reconstruct_shares);
ps::PBMetadata share_secretes_client_list; fl::PBMetadata share_secretes_client_list;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxShareSecretsClientList, fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
share_secretes_client_list); share_secretes_client_list);
ps::PBMetadata clients_encrypt_shares; fl::PBMetadata clients_encrypt_shares;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsEncryptedShares, fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
clients_encrypt_shares); clients_encrypt_shares);
} }
} // namespace armour } // namespace armour

View File

@ -101,15 +101,15 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START"; MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
bool retcode = true; bool retcode = true;
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori; std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsReconstructShares, cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&reconstruct_secret_list_ori); &reconstruct_secret_list_ori);
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys; std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys); cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
std::vector<std::string> clients_reconstruct_list; std::vector<std::string> clients_reconstruct_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxReconstructClientList, cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
&clients_reconstruct_list); &clients_reconstruct_list);
std::vector<std::string> clients_share_list; std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList, cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list); &clients_share_list);
if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() || if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() ||
record_public_keys.size() < cipher_init_->client_num_need_ || record_public_keys.size() < cipher_init_->client_num_need_ ||
@ -146,7 +146,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
client_keys.clear(); client_keys.clear();
MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server"; MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";
if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(ps::server::kCtxClientNoises, noise) == false) if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise) == false)
return false; return false;
MS_LOG(INFO) << " ReconstructSecretsGenNoise Success"; MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
} else { } else {
@ -159,7 +159,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
// reconstruct secrets // reconstruct secrets
bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time, bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req, const schema::SendReconstructSecret *reconstruct_secret_req,
std::shared_ptr<ps::server::FBBuilder> reconstruct_secret_resp_builder, std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
const std::vector<std::string> &client_list) { const std::vector<std::string> &client_list) {
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START"; MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
clock_t start_time = clock(); clock_t start_time = clock();
@ -178,10 +178,10 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
return false; return false;
} }
std::vector<std::string> clients_reconstruct_list; std::vector<std::string> clients_reconstruct_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxReconstructClientList, cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
&clients_reconstruct_list); &clients_reconstruct_list);
std::map<std::string, std::vector<clientshare_str>> clients_shares_all; std::map<std::string, std::vector<clientshare_str>> clients_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsReconstructShares, cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&clients_shares_all); &clients_shares_all);
size_t count_client_num = clients_shares_all.size(); size_t count_client_num = clients_shares_all.size();
@ -215,9 +215,9 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
} }
auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares(); auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
bool retcode_client = bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxReconstructClientList, fl_id); cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id);
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer( bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
ps::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares); fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
if (!(retcode_client && retcode_share)) { if (!(retcode_client && retcode_share)) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime, BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
"reconstruct update shares or client failed.", cur_iterator, next_req_time); "reconstruct update shares or client failed.", cur_iterator, next_req_time);
@ -233,9 +233,9 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
return true; return true;
} else { } else {
bool retcode_result = true; bool retcode_result = true;
const ps::PBMetadata &clients_noises_pb_out = const fl::PBMetadata &clients_noises_pb_out =
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(ps::server::kCtxClientNoises); fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
const ps::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises(); const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
if (clients_noises_pb.has_one_client_noises() == false) { if (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "Success,the secret will be reconstructed."; MS_LOG(INFO) << "Success,the secret will be reconstructed.";
retcode_result = ReconstructSecretsGenNoise(client_list); retcode_result = ReconstructSecretsGenNoise(client_list);
@ -279,13 +279,13 @@ bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,
// Calls ResetMetadata on the DistributedMetadataStore singleton for three context keys:
// kCtxReconstructClientList, kCtxClientsReconstructShares and kCtxClientNoises.
// An INFO log line is emitted before the first reset and after the last one.
// NOTE(review): side-by-side diff rendering — each line holds the old `ps::` text followed by the
// new `fl::` text.
void CipherReconStruct::ClearReconstructSecrets() { void CipherReconStruct::ClearReconstructSecrets() {
MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets START"; MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets START";
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxReconstructClientList); fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxReconstructClientList);
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsReconstructShares); fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsReconstructShares);
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientNoises); fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientNoises);
MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success"; MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success";
} }
void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<ps::server::FBBuilder> fbb, void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb,
const schema::ResponseCode retcode, const std::string &reason, const schema::ResponseCode retcode, const std::string &reason,
const int iteration, const std::string &next_req_time) { const int iteration, const std::string &next_req_time) {
auto fbs_reason = fbb->CreateString(reason); auto fbs_reason = fbb->CreateString(reason);

View File

@ -44,11 +44,11 @@ class CipherReconStruct {
// reconstruct secret mask // reconstruct secret mask
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time, bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req, const schema::SendReconstructSecret *reconstruct_secret_req,
std::shared_ptr<ps::server::FBBuilder> reconstruct_secret_resp_builder, std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
const std::vector<std::string> &client_list); const std::vector<std::string> &client_list);
// build response code of reconstruct secret. // build response code of reconstruct secret.
void BuildReconstructSecretsRsp(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode, void BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason, const int iteration, const std::string &next_req_time); const std::string &reason, const int iteration, const std::string &next_req_time);
// clear the shared memory. // clear the shared memory.

View File

@ -21,7 +21,7 @@
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req, bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder, std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
const string next_req_time) { const string next_req_time) {
MS_LOG(INFO) << "CipherShares::ShareSecrets START"; MS_LOG(INFO) << "CipherShares::ShareSecrets START";
if (share_secrets_req == nullptr) { if (share_secrets_req == nullptr) {
@ -35,13 +35,13 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
// step 1: get client list and share secrets from memory server. // step 1: get client list and share secrets from memory server.
clock_t start_time = clock(); clock_t start_time = clock();
std::vector<std::string> clients_share_list; std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList, cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list); &clients_share_list);
std::vector<std::string> clients_exchange_list; std::vector<std::string> clients_exchange_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList, cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList,
&clients_exchange_list); &clients_exchange_list);
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all; std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsEncryptedShares, cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
&encrypted_shares_all); &encrypted_shares_all);
MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size() MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size()
@ -75,9 +75,9 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares = const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
(share_secrets_req->encrypted_shares()); (share_secrets_req->encrypted_shares());
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer( bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
ps::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares); fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
bool retcode_client = bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxShareSecretsClientList, fl_id_src); cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src);
bool retcode = retcode_share && retcode_client; bool retcode = retcode_share && retcode_client;
if (retcode) { if (retcode) {
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration); BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
@ -95,7 +95,7 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
} }
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
const std::string &next_req_time) { const std::string &next_req_time) {
MS_LOG(INFO) << "CipherShares::GetSecrets START"; MS_LOG(INFO) << "CipherShares::GetSecrets START";
clock_t start_time = clock(); clock_t start_time = clock();
@ -108,10 +108,10 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
// step 1: get client list and client shares list from memory server. // step 1: get client list and client shares list from memory server.
std::vector<std::string> clients_share_list; std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList, cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list); &clients_share_list);
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all; std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsEncryptedShares, cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
&encrypted_shares_all); &encrypted_shares_all);
int iteration = get_secrets_req->iteration(); int iteration = get_secrets_req->iteration();
size_t share_clients_num = clients_share_list.size(); size_t share_clients_num = clients_share_list.size();
@ -180,7 +180,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
} }
void CipherShares::BuildGetSecretsRsp( void CipherShares::BuildGetSecretsRsp(
std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration, std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) { std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) {
int rsp_retcode = retcode; int rsp_retcode = retcode;
int rsp_iteration = iteration; int rsp_iteration = iteration;
@ -199,7 +199,7 @@ void CipherShares::BuildGetSecretsRsp(
return; return;
} }
void CipherShares::BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder, void CipherShares::BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
const schema::ResponseCode retcode, const string &reason, const schema::ResponseCode retcode, const string &reason,
const string &next_req_time, const int iteration) { const string &next_req_time, const int iteration) {
auto rsp_reason = share_secrets_resp_builder->CreateString(reason); auto rsp_reason = share_secrets_resp_builder->CreateString(reason);
@ -211,8 +211,8 @@ void CipherShares::BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> s
} }
// Calls ResetMetadata on the DistributedMetadataStore singleton for the two context keys
// kCtxShareSecretsClientList and kCtxClientsEncryptedShares. Nothing is returned or logged.
// NOTE(review): side-by-side diff rendering — each line holds the old `ps::` text followed by the
// new `fl::` text.
void CipherShares::ClearShareSecrets() { void CipherShares::ClearShareSecrets() {
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxShareSecretsClientList); fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxShareSecretsClientList);
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsEncryptedShares); fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsEncryptedShares);
} }
} // namespace armour } // namespace armour

View File

@ -43,17 +43,17 @@ class CipherShares {
// handle the client's request of share secrets. // handle the client's request of share secrets.
bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req, bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder, const string next_req_time); std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, const string next_req_time);
// handle the client's request of get secrets. // handle the client's request of get secrets.
bool GetSecrets(const schema::GetShareSecrets *get_secrets_req, bool GetSecrets(const schema::GetShareSecrets *get_secrets_req,
std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time); std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time);
// build response code of share secrets. // build response code of share secrets.
void BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder, void BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
const schema::ResponseCode retcode, const string &reason, const string &next_req_time, const schema::ResponseCode retcode, const string &reason, const string &next_req_time,
const int iteration); const int iteration);
// build response code of get secrets. // build response code of get secrets.
void BuildGetSecretsRsp(std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, void BuildGetSecretsRsp(std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
const schema::ResponseCode retcode, const int iteration, std::string next_req_time, const schema::ResponseCode retcode, const int iteration, std::string next_req_time,
std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares); std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares);
// clear the shared memory. // clear the shared memory.

View File

@ -26,13 +26,13 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
clock_t start_time = clock(); clock_t start_time = clock();
std::vector<float> noise; std::vector<float> noise;
cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(ps::server::kCtxClientNoises, &noise); cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
if (noise.size() != cipher_init_->featuremap_) { if (noise.size() != cipher_init_->featuremap_) {
MS_LOG(ERROR) << " CipherMgr UnMask ERROR"; MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
return false; return false;
} }
size_t data_size = ps::server::LocalMetaStore::GetInstance().value<size_t>(ps::server::kCtxFedAvgTotalDataSize); size_t data_size = fl::server::LocalMetaStore::GetInstance().value<size_t>(fl::server::kCtxFedAvgTotalDataSize);
int sum_size = 0; int sum_size = 0;
for (auto iter = data.begin(); iter != data.end(); ++iter) { for (auto iter = data.begin(); iter != data.end(); ++iter) {
int size_data = iter->second->size / sizeof(float); int size_data = iter->second->size / sizeof(float);

View File

@ -17,13 +17,13 @@
#include "fl/server/collective_ops_impl.h" #include "fl/server/collective_ops_impl.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// Binds this object to a server node and caches the values the collective ops read later.
//   server_node: shared pointer to the server node; checked with MS_EXCEPTION_IF_NULL before use
//                (presumably raises on nullptr — confirm against the macro definition).
// After the check, the pointer is stored in server_node_, local_rank_ is set from
// server_node_->rank_id(), and server_num_ from PSContext::instance()->initial_server_num().
// NOTE(review): side-by-side diff rendering — each line holds the old text followed by the new
// text; the new side spells out the `ps::` qualifier because the enclosing namespace became `fl`.
void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) { void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
MS_EXCEPTION_IF_NULL(server_node); MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node; server_node_ = server_node;
local_rank_ = server_node_->rank_id(); local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num(); server_num_ = ps::PSContext::instance()->initial_server_num();
return; return;
} }
@ -66,7 +66,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
// Step 1: Async send data to next rank. // Step 1: Async send data to next rank.
size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size; size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
T *send_chunk = output_buff + chunk_offset[send_chunk_index]; T *send_chunk = output_buff + chunk_offset[send_chunk_index];
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk, auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
chunk_sizes[send_chunk_index] * sizeof(T)); chunk_sizes[send_chunk_index] * sizeof(T));
// Step 2: Async receive data to next rank and wait until it's done. // Step 2: Async receive data to next rank and wait until it's done.
size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size; size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;
@ -76,7 +76,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i; << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
std::shared_ptr<std::vector<unsigned char>> recv_str; std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str); auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) { if (!server_node_->CollectiveWait(recv_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false; return false;
@ -104,7 +104,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
for (size_t i = 0; i < rank_size - 1; i++) { for (size_t i = 0; i < rank_size - 1; i++) {
size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size; size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
T *send_chunk = output_buff + chunk_offset[send_chunk_index]; T *send_chunk = output_buff + chunk_offset[send_chunk_index];
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk, auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
chunk_sizes[send_chunk_index] * sizeof(T)); chunk_sizes[send_chunk_index] * sizeof(T));
size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size; size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index]; T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
@ -113,7 +113,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i; << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
std::shared_ptr<std::vector<unsigned char>> recv_str; std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str); auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) { if (!server_node_->CollectiveWait(recv_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
@ -151,7 +151,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
for (uint32_t i = 1; i < rank_size; i++) { for (uint32_t i = 1; i < rank_size; i++) {
std::shared_ptr<std::vector<unsigned char>> recv_str; std::shared_ptr<std::vector<unsigned char>> recv_str;
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i; MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str); auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) { if (!server_node_->CollectiveWait(recv_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false; return false;
@ -167,7 +167,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
} }
} else { } else {
MS_LOG(DEBUG) << "Reduce send data to rank 0 process."; MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T)); auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
if (!server_node_->Wait(send_req_id)) { if (!server_node_->Wait(send_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false; return false;
@ -180,7 +180,8 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
if (local_rank_ == 0) { if (local_rank_ == 0) {
for (uint32_t i = 1; i < rank_size; i++) { for (uint32_t i = 1; i < rank_size; i++) {
MS_LOG(DEBUG) << "Broadcast data to process " << i; MS_LOG(DEBUG) << "Broadcast data to process " << i;
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T)); auto send_req_id =
server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
if (!server_node_->Wait(send_req_id)) { if (!server_node_->Wait(send_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false; return false;
@ -189,7 +190,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
} else { } else {
MS_LOG(DEBUG) << "Broadcast receive from rank 0."; MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
std::shared_ptr<std::vector<unsigned char>> recv_str; std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str); auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) { if (!server_node_->CollectiveWait(recv_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false; return false;
@ -247,5 +248,5 @@ template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *re
template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count); template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count); template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
#define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ #define MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -27,7 +27,7 @@
#include "fl/server/common.h" #include "fl/server/common.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// CollectiveOpsImpl is the collective communication API of the server. // CollectiveOpsImpl is the collective communication API of the server.
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also // For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
@ -39,7 +39,7 @@ class CollectiveOpsImpl {
return instance; return instance;
} }
void Initialize(const std::shared_ptr<core::ServerNode> &server_node); void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);
template <typename T> template <typename T>
bool AllReduce(const void *sendbuff, void *recvbuff, size_t count); bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);
@ -48,7 +48,7 @@ class CollectiveOpsImpl {
bool ReInitForScaling(); bool ReInitForScaling();
private: private:
CollectiveOpsImpl() = default; CollectiveOpsImpl() : server_node_(nullptr), local_rank_(0), server_num_(0) {}
~CollectiveOpsImpl() = default; ~CollectiveOpsImpl() = default;
CollectiveOpsImpl(const CollectiveOpsImpl &) = delete; CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete; CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;
@ -61,7 +61,7 @@ class CollectiveOpsImpl {
template <typename T> template <typename T>
bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count); bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);
std::shared_ptr<core::ServerNode> server_node_; std::shared_ptr<ps::core::ServerNode> server_node_;
uint32_t local_rank_; uint32_t local_rank_;
uint32_t server_num_; uint32_t server_num_;
@ -69,6 +69,6 @@ class CollectiveOpsImpl {
std::mutex mtx_; std::mutex mtx_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_COMMON_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
#define MINDSPORE_CCSRC_PS_SERVER_COMMON_H_ #define MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
#include <map> #include <map>
#include <string> #include <string>
@ -37,7 +37,7 @@
#include "ps/core/communicator/message_handler.h" #include "ps/core/communicator/message_handler.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// Definitions for the server framework. // Definitions for the server framework.
enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER }; enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
@ -73,7 +73,7 @@ using TimeOutCb = std::function<void(bool, const std::string &)>;
using StopTimerCb = std::function<void(void)>; using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(bool, const std::string &)>; using FinishIterCb = std::function<void(bool, const std::string &)>;
using FinalizeCb = std::function<void(void)>; using FinalizeCb = std::function<void(void)>;
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>; using MessageCallback = std::function<void(const std::shared_ptr<ps::core::MessageHandler> &)>;
// Information about whether server kernel will reuse kernel node memory from the front end. // Information about whether server kernel will reuse kernel node memory from the front end.
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate". // Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
@ -237,6 +237,6 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size
// Definitions for Parameter Server. // Definitions for Parameter Server.
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_COMMON_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_COMMON_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/consistent_hash_ring.h" #include "fl/server/consistent_hash_ring.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
bool ConsistentHashRing::Insert(uint32_t rank) { bool ConsistentHashRing::Insert(uint32_t rank) {
for (uint32_t i = 0; i < virtual_node_num_; i++) { for (uint32_t i = 0; i < virtual_node_num_; i++) {
@ -53,5 +53,5 @@ uint32_t ConsistentHashRing::Find(const std::string &key) {
return iterator->second; return iterator->second;
} }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,15 +14,15 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_
#define MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_ #define MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_
#include <map> #include <map>
#include <string> #include <string>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in // To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in
// server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node. // server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node.
@ -59,6 +59,6 @@ class ConsistentHashRing {
std::map<size_t, uint32_t> ring_; std::map<size_t, uint32_t> ring_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_

View File

@ -20,19 +20,19 @@
#include <vector> #include <vector>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node, void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node,
uint32_t counting_server_rank) { uint32_t counting_server_rank) {
server_node_ = server_node; server_node_ = server_node;
MS_EXCEPTION_IF_NULL(server_node_); MS_EXCEPTION_IF_NULL(server_node_);
local_rank_ = server_node_->rank_id(); local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num(); server_num_ = ps::PSContext::instance()->initial_server_num();
counting_server_rank_ = counting_server_rank; counting_server_rank_ = counting_server_rank;
return; return;
} }
void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) { void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
communicator_ = communicator; communicator_ = communicator;
MS_EXCEPTION_IF_NULL(communicator_); MS_EXCEPTION_IF_NULL(communicator_);
communicator_->RegisterMsgCallBack( communicator_->RegisterMsgCallBack(
@ -94,7 +94,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
report_count_req.set_id(id); report_count_req.set_id(id);
std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr; std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount, if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
&report_cnt_rsp_msg)) { &report_cnt_rsp_msg)) {
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name; MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
return false; return false;
@ -126,7 +126,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr; std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_, if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) { ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name; MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
return false; return false;
} }
@ -165,7 +165,7 @@ bool DistributedCountService::ReInitForScaling() {
return true; return true;
} }
void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) { void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr."; MS_LOG(ERROR) << "Message is nullptr.";
return; return;
@ -214,7 +214,8 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::Mes
return; return;
} }
void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) { void DistributedCountService::HandleCountReachThresholdRequest(
const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr."; MS_LOG(ERROR) << "Message is nullptr.";
return; return;
@ -237,7 +238,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
return; return;
} }
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) { void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr."; MS_LOG(ERROR) << "Message is nullptr.";
return; return;
@ -290,7 +291,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
// Broadcast to all follower servers. // Broadcast to all follower servers.
for (uint32_t i = 1; i < server_num_; i++) { for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) { if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed."; MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
return false; return false;
} }
@ -308,7 +309,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
// Broadcast to all follower servers. // Broadcast to all follower servers.
for (uint32_t i = 1; i < server_num_; i++) { for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) { if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed."; MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
return false; return false;
} }
@ -318,5 +319,5 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
return true; return true;
} }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ #define MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#include <set> #include <set>
#include <string> #include <string>
@ -27,7 +27,7 @@
#include "ps/core/communicator/tcp_communicator.h" #include "ps/core/communicator/tcp_communicator.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
constexpr uint32_t kDefaultCountingServerRank = 0; constexpr uint32_t kDefaultCountingServerRank = 0;
constexpr auto kModuleDistributedCountService = "DistributedCountService"; constexpr auto kModuleDistributedCountService = "DistributedCountService";
@ -54,10 +54,10 @@ class DistributedCountService {
} }
// Initialize counter service with the server node because communication is needed. // Initialize counter service with the server node because communication is needed.
void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank); void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node, uint32_t counting_server_rank);
// Register message callbacks of the counting server to handle messages sent by the other servers. // Register message callbacks of the counting server to handle messages sent by the other servers.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator); void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
// Register counter to the counting server for the name with its threshold count in server cluster dimension and // Register counter to the counting server for the name with its threshold count in server cluster dimension and
// first/last count event callbacks. // first/last count event callbacks.
@ -87,15 +87,15 @@ class DistributedCountService {
DistributedCountService &operator=(const DistributedCountService &) = delete; DistributedCountService &operator=(const DistributedCountService &) = delete;
// Callback for the reporting count message from other servers. Only counting server will call this method. // Callback for the reporting count message from other servers. Only counting server will call this method.
void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message); void HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Callback for the querying whether threshold count is reached message from other servers. Only counting // Callback for the querying whether threshold count is reached message from other servers. Only counting
// server will call this method. // server will call this method.
void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message); void HandleCountReachThresholdRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Callback for the first/last event message from the counting server. Only other servers will call this // Callback for the first/last event message from the counting server. Only other servers will call this
// method. // method.
void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message); void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
// Call the callbacks when the first/last count event is triggered. // Call the callbacks when the first/last count event is triggered.
bool TriggerCounterEvent(const std::string &name); bool TriggerCounterEvent(const std::string &name);
@ -103,8 +103,8 @@ class DistributedCountService {
bool TriggerLastCountEvent(const std::string &name); bool TriggerLastCountEvent(const std::string &name);
// Members for the communication between counting server and other servers. // Members for the communication between counting server and other servers.
std::shared_ptr<core::ServerNode> server_node_; std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_; std::shared_ptr<ps::core::TcpCommunicator> communicator_;
uint32_t local_rank_; uint32_t local_rank_;
uint32_t server_num_; uint32_t server_num_;
@ -126,6 +126,6 @@ class DistributedCountService {
std::unordered_map<std::string, std::mutex> mutex_; std::unordered_map<std::string, std::mutex> mutex_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_

View File

@ -20,18 +20,18 @@
#include <vector> #include <vector>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) { void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
server_node_ = server_node; server_node_ = server_node;
MS_EXCEPTION_IF_NULL(server_node); MS_EXCEPTION_IF_NULL(server_node);
local_rank_ = server_node_->rank_id(); local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num(); server_num_ = ps::PSContext::instance()->initial_server_num();
InitHashRing(); InitHashRing();
return; return;
} }
void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) { void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
communicator_ = communicator; communicator_ = communicator;
MS_EXCEPTION_IF_NULL(communicator_); MS_EXCEPTION_IF_NULL(communicator_);
communicator_->RegisterMsgCallBack( communicator_->RegisterMsgCallBack(
@ -100,7 +100,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
metadata_with_name.set_name(name); metadata_with_name.set_name(name);
*metadata_with_name.mutable_metadata() = meta; *metadata_with_name.mutable_metadata() = meta;
std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr; std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata, if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
&update_meta_rsp_msg)) { &update_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed."; MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
return false; return false;
@ -133,7 +133,7 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
PBMetadata get_metadata_rsp; PBMetadata get_metadata_rsp;
std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr; std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata, if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetMetadata,
&get_meta_rsp_msg)) { &get_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed."; MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
return get_metadata_rsp; return get_metadata_rsp;
@ -174,7 +174,7 @@ void DistributedMetadataStore::InitHashRing() {
return; return;
} }
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) { void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr."; MS_LOG(ERROR) << "Message is nullptr.";
return; return;
@ -196,7 +196,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
return; return;
} }
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) { void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr."; MS_LOG(ERROR) << "Message is nullptr.";
return; return;
@ -267,7 +267,7 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares(); auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
auto &fl_id = meta.pair_client_shares().fl_id(); auto &fl_id = meta.pair_client_shares().fl_id();
auto &client_shares = meta.pair_client_shares().client_shares(); auto &client_shares = meta.pair_client_shares().client_shares();
// google::protobuf::Map< std::string, mindspore::ps::core::SharesPb >::const_iterator iter; // google::protobuf::Map< std::string, mindspore::ps::core::SharesPb >::const_iterator iter;
// Check whether the new item already exists. // Check whether the new item already exists.
bool add_flag = true; bool add_flag = true;
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); iter++) { for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); iter++) {
@ -299,5 +299,5 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
return true; return true;
} }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ #define MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_
#include <string> #include <string>
#include <memory> #include <memory>
@ -27,7 +27,7 @@
#include "fl/server/consistent_hash_ring.h" #include "fl/server/consistent_hash_ring.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
constexpr auto kModuleDistributedMetadataStore = "DistributedMetadataStore"; constexpr auto kModuleDistributedMetadataStore = "DistributedMetadataStore";
// This class is used for distributed metadata storage using consistent hash. All metadata is distributedly // This class is used for distributed metadata storage using consistent hash. All metadata is distributedly
@ -44,10 +44,10 @@ class DistributedMetadataStore {
} }
// Initialize metadata storage with the server node because communication is needed. // Initialize metadata storage with the server node because communication is needed.
void Initialize(const std::shared_ptr<core::ServerNode> &server_node); void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);
// Register callbacks for the server to handle update/get metadata messages from other servers. // Register callbacks for the server to handle update/get metadata messages from other servers.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator); void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
// Register metadata for the name with the initial value. This method should be only called once for each name. // Register metadata for the name with the initial value. This method should be only called once for each name.
void RegisterMetadata(const std::string &name, const PBMetadata &meta); void RegisterMetadata(const std::string &name, const PBMetadata &meta);
@ -65,7 +65,13 @@ class DistributedMetadataStore {
bool ReInitForScaling(); bool ReInitForScaling();
private: private:
DistributedMetadataStore() = default; DistributedMetadataStore()
: server_node_(nullptr),
communicator_(nullptr),
local_rank_(0),
server_num_(0),
router_(nullptr),
metadata_({}) {}
~DistributedMetadataStore() = default; ~DistributedMetadataStore() = default;
DistributedMetadataStore(const DistributedMetadataStore &) = delete; DistributedMetadataStore(const DistributedMetadataStore &) = delete;
DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete; DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete;
@ -74,17 +80,17 @@ class DistributedMetadataStore {
void InitHashRing(); void InitHashRing();
// Callback for updating metadata request sent to the server. // Callback for updating metadata request sent to the server.
void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message); void HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Callback for getting metadata request sent to the server. // Callback for getting metadata request sent to the server.
void HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message); void HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Do updating metadata in the server where the metadata for the name is stored. // Do updating metadata in the server where the metadata for the name is stored.
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta); bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
// Members for the communication between servers. // Members for the communication between servers.
std::shared_ptr<core::ServerNode> server_node_; std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_; std::shared_ptr<ps::core::TcpCommunicator> communicator_;
uint32_t local_rank_; uint32_t local_rank_;
uint32_t server_num_; uint32_t server_num_;
@ -100,6 +106,6 @@ class DistributedMetadataStore {
std::unordered_map<std::string, std::mutex> mutex_; std::unordered_map<std::string, std::mutex> mutex_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_

View File

@ -21,7 +21,7 @@
#include <vector> #include <vector>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) { void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -320,5 +320,5 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
return true; return true;
} }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_ #define MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
#include <map> #include <map>
#include <set> #include <set>
@ -31,7 +31,7 @@
#endif #endif
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles // Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles
// logics relevant to kernel launching. // logics relevant to kernel launching.
@ -94,7 +94,7 @@ class Executor {
bool Unmask(); bool Unmask();
private: private:
Executor() {} Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}) {}
~Executor() = default; ~Executor() = default;
Executor(const Executor &) = delete; Executor(const Executor &) = delete;
Executor &operator=(const Executor &) = delete; Executor &operator=(const Executor &) = delete;
@ -126,6 +126,6 @@ class Executor {
#endif #endif
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_

View File

@ -23,10 +23,10 @@
#include "fl/server/server.h" #include "fl/server/server.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
class Server; class Server;
void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) { void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator); MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator; communicator_ = communicator;
communicator_->RegisterMsgCallBack("syncIteration", communicator_->RegisterMsgCallBack("syncIteration",
@ -42,12 +42,12 @@ void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunica
std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1)); std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1));
} }
void Iteration::RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node) { void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) {
MS_EXCEPTION_IF_NULL(server_node); MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node; server_node_ = server_node;
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning), server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
std::bind(&Iteration::HandleIterationRunningEvent, this)); std::bind(&Iteration::HandleIterationRunningEvent, this));
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted), server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
std::bind(&Iteration::HandleIterationCompletedEvent, this)); std::bind(&Iteration::HandleIterationCompletedEvent, this));
} }
@ -56,7 +56,7 @@ void Iteration::AddRound(const std::shared_ptr<Round> &round) {
rounds_.push_back(round); rounds_.push_back(round);
} }
void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) { const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
if (communicators.empty()) { if (communicators.empty()) {
MS_LOG(EXCEPTION) << "Communicators for rounds is empty."; MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";
@ -64,7 +64,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
} }
std::for_each(communicators.begin(), communicators.end(), std::for_each(communicators.begin(), communicators.end(),
[&](const std::shared_ptr<core::CommunicatorBase> &communicator) { [&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
for (auto &round : rounds_) { for (auto &round : rounds_) {
if (round == nullptr) { if (round == nullptr) {
continue; continue;
@ -120,7 +120,7 @@ void Iteration::SetIterationRunning() {
} }
if (server_node_->rank_id() == kLeaderServerRank) { if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state. // This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationRunning)); server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning));
} }
iteration_state_ = IterationState::kRunning; iteration_state_ = IterationState::kRunning;
} }
@ -133,7 +133,7 @@ void Iteration::SetIterationCompleted() {
} }
if (server_node_->rank_id() == kLeaderServerRank) { if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state. // This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationCompleted)); server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted));
} }
iteration_state_ = IterationState::kCompleted; iteration_state_ = IterationState::kCompleted;
} }
@ -171,7 +171,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
sync_iter_req.set_rank(rank); sync_iter_req.set_rank(rank);
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr; std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, core::TcpUserCommand::kSyncIteration, if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration,
&sync_iter_rsp_msg)) { &sync_iter_rsp_msg)) {
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed."; MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
return false; return false;
@ -189,7 +189,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
return true; return true;
} }
void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message) { void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr."; MS_LOG(ERROR) << "Message is nullptr.";
return; return;
@ -224,14 +224,14 @@ bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const s
notify_leader_to_next_iter_req.set_iter_num(iteration_num_); notify_leader_to_next_iter_req.set_iter_num(iteration_num_);
notify_leader_to_next_iter_req.set_reason(reason); notify_leader_to_next_iter_req.set_reason(reason);
if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank, if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank,
core::TcpUserCommand::kNotifyLeaderToNextIter)) { ps::core::TcpUserCommand::kNotifyLeaderToNextIter)) {
MS_LOG(WARNING) << "Sending notify leader server to proceed next iteration request to leader server 0 failed."; MS_LOG(WARNING) << "Sending notify leader server to proceed next iteration request to leader server 0 failed.";
return false; return false;
} }
return true; return true;
} }
void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) { void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
return; return;
} }
@ -278,7 +278,7 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
std::vector<uint32_t> offline_servers = {}; std::vector<uint32_t> offline_servers = {};
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) { for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(prepare_next_iter_req, i, core::TcpUserCommand::kPrepareForNextIter)) { if (!communicator_->SendPbRequest(prepare_next_iter_req, i, ps::core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Sending prepare for next iteration request to server " << i << " failed. Retry later."; MS_LOG(WARNING) << "Sending prepare for next iteration request to server " << i << " failed. Retry later.";
offline_servers.push_back(i); offline_servers.push_back(i);
continue; continue;
@ -289,17 +289,18 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) { std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
// Should avoid endless loop if the server communicator is stopped. // Should avoid endless loop if the server communicator is stopped.
while (communicator_->running() && while (communicator_->running() &&
!communicator_->SendPbRequest(prepare_next_iter_req, rank, core::TcpUserCommand::kPrepareForNextIter)) { !communicator_->SendPbRequest(prepare_next_iter_req, rank, ps::core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank
<< " failed. The server has not recovered yet."; << " failed. The server has not recovered yet.";
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter)); std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));
} }
MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success."; MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success.";
}); });
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
return true; return true;
} }
void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) { void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
return; return;
} }
@ -329,7 +330,7 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
proceed_to_next_iter_req.set_last_iter_num(iteration_num_); proceed_to_next_iter_req.set_last_iter_num(iteration_num_);
proceed_to_next_iter_req.set_reason(reason); proceed_to_next_iter_req.set_reason(reason);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) { for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, core::TcpUserCommand::kProceedToNextIter)) { if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, ps::core::TcpUserCommand::kProceedToNextIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed."; MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
continue; continue;
} }
@ -339,7 +340,7 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
return true; return true;
} }
void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) { void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
return; return;
} }
@ -388,7 +389,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
EndLastIterRequest end_last_iter_req; EndLastIterRequest end_last_iter_req;
end_last_iter_req.set_last_iter_num(last_iter_num); end_last_iter_req.set_last_iter_num(last_iter_num);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) { for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, core::TcpUserCommand::kEndLastIter)) { if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed."; MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
continue; continue;
} }
@ -398,7 +399,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
return true; return true;
} }
void Iteration::HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message) { void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
return; return;
} }
@ -429,9 +430,9 @@ void Iteration::EndLastIter() {
MS_LOG(INFO) << "End the last iteration " << iteration_num_; MS_LOG(INFO) << "End the last iteration " << iteration_num_;
iteration_num_++; iteration_num_++;
// After the job is done, reset the iteration to the initial number and reset ModelStore. // After the job is done, reset the iteration to the initial number and reset ModelStore.
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) { if (iteration_num_ > ps::PSContext::instance()->fl_iteration_num()) {
MS_LOG(INFO) << "Iteration loop " << iteration_loop_count_ MS_LOG(INFO) << "Iteration loop " << iteration_loop_count_
<< " is completed. Iteration number: " << PSContext::instance()->fl_iteration_num(); << " is completed. Iteration number: " << ps::PSContext::instance()->fl_iteration_num();
iteration_num_ = 1; iteration_num_ = 1;
iteration_loop_count_++; iteration_loop_count_++;
ModelStore::GetInstance().Reset(); ModelStore::GetInstance().Reset();
@ -444,5 +445,5 @@ void Iteration::EndLastIter() {
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n"; MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
} }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ #define MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -26,7 +26,7 @@
#include "fl/server/local_meta_store.h" #include "fl/server/local_meta_store.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
enum class IterationState { enum class IterationState {
// This iteration is still in process. // This iteration is still in process.
@ -48,16 +48,16 @@ class Iteration {
} }
// Register callbacks for other servers to synchronize iteration information from leader server. // Register callbacks for other servers to synchronize iteration information from leader server.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator); void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
// Register event callbacks for iteration state synchronization. // Register event callbacks for iteration state synchronization.
void RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node); void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node);
// Add a round for the iteration. This method will be called multiple times for each round. // Add a round for the iteration. This method will be called multiple times for each round.
void AddRound(const std::shared_ptr<Round> &round); void AddRound(const std::shared_ptr<Round> &round);
// Initialize all the rounds in the iteration. // Initialize all the rounds in the iteration.
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, void InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb); const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
// This method will control servers to proceed to next iteration. // This method will control servers to proceed to next iteration.
@ -104,7 +104,7 @@ class Iteration {
// Synchronize iteration from the leader server (Rank 0). // Synchronize iteration from the leader server (Rank 0).
bool SyncIteration(uint32_t rank); bool SyncIteration(uint32_t rank);
void HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message); void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The request for moving to next iteration is not reentrant. // The request for moving to next iteration is not reentrant.
bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num); bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num);
@ -112,28 +112,28 @@ class Iteration {
// The methods for moving to next iteration for all the servers. // The methods for moving to next iteration for all the servers.
// Step 1: follower servers notify leader server that they need to move to next iteration. // Step 1: follower servers notify leader server that they need to move to next iteration.
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason); bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message); void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Step 2: leader server broadcasts to all follower servers to prepare for next iteration and switch to safemode. // Step 2: leader server broadcasts to all follower servers to prepare for next iteration and switch to safemode.
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason); bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message); void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The server prepares for the next iteration. This method will switch the server to safemode. // The server prepares for the next iteration. This method will switch the server to safemode.
void PrepareForNextIter(); void PrepareForNextIter();
// Step 3: leader server broadcasts to all follower servers to move to next iteration. // Step 3: leader server broadcasts to all follower servers to move to next iteration.
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason); bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message); void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Move to next iteration. Store the last iteration's model and reset all the rounds. // Move to next iteration. Store the last iteration's model and reset all the rounds.
void Next(bool is_iteration_valid, const std::string &reason); void Next(bool is_iteration_valid, const std::string &reason);
// Step 4: leader server broadcasts to all follower servers to end last iteration and cancel the safemode. // Step 4: leader server broadcasts to all follower servers to end last iteration and cancel the safemode.
bool BroadcastEndLastIterRequest(uint64_t iteration_num); bool BroadcastEndLastIterRequest(uint64_t iteration_num);
void HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message); void HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The server ends the last iteration. This method will increase the iteration number and cancel the safemode. // The server ends the last iteration. This method will increase the iteration number and cancel the safemode.
void EndLastIter(); void EndLastIter();
std::shared_ptr<core::ServerNode> server_node_; std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_; std::shared_ptr<ps::core::TcpCommunicator> communicator_;
// All the rounds in the server. // All the rounds in the server.
std::vector<std::shared_ptr<Round>> rounds_; std::vector<std::shared_ptr<Round>> rounds_;
@ -155,6 +155,6 @@ class Iteration {
std::mutex pinned_mtx_; std::mutex pinned_mtx_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/iteration_timer.h" #include "fl/server/iteration_timer.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
void IterationTimer::Start(const std::chrono::milliseconds &duration) { void IterationTimer::Start(const std::chrono::milliseconds &duration) {
if (running_.load()) { if (running_.load()) {
@ -52,5 +52,5 @@ bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) const
bool IterationTimer::IsRunning() const { return running_; } bool IterationTimer::IsRunning() const { return running_; }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_ #define MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_
#include <chrono> #include <chrono>
#include <atomic> #include <atomic>
@ -24,7 +24,7 @@
#include "fl/server/common.h" #include "fl/server/common.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// IterationTimer controls the time window for the purpose of eliminating trailing time of each iteration. // IterationTimer controls the time window for the purpose of eliminating trailing time of each iteration.
class IterationTimer { class IterationTimer {
@ -59,6 +59,6 @@ class IterationTimer {
TimeOutCb timeout_callback_; TimeOutCb timeout_callback_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -26,7 +26,7 @@
#include "fl/server/kernel/params_info.h" #include "fl/server/kernel/params_info.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// AggregationKernel is the kernel for weight, grad or other kinds of parameters' aggregation. // AggregationKernel is the kernel for weight, grad or other kinds of parameters' aggregation.
@ -99,6 +99,6 @@ class AggregationKernel : public CPUKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_

View File

@ -18,7 +18,7 @@
#include <utility> #include <utility>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) { bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
@ -67,5 +67,5 @@ bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNod
} }
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -24,7 +24,7 @@
#include "fl/server/kernel/aggregation_kernel.h" #include "fl/server/kernel/aggregation_kernel.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>; using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>;
@ -51,6 +51,7 @@ class AggregationKernelRegister {
AggregationKernelCreator &&creator) { AggregationKernelCreator &&creator) {
AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator)); AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
} }
~AggregationKernelRegister() = default;
}; };
// Register aggregation kernel with one template type T. // Register aggregation kernel with one template type T.
@ -66,6 +67,6 @@ class AggregationKernelRegister {
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); }); #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); });
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/apply_momentum_kernel.h" #include "fl/server/kernel/apply_momentum_kernel.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
REG_OPTIMIZER_KERNEL(ApplyMomentum, REG_OPTIMIZER_KERNEL(ApplyMomentum,
@ -30,5 +30,5 @@ REG_OPTIMIZER_KERNEL(ApplyMomentum,
ApplyMomentumKernel, float) ApplyMomentumKernel, float)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#include <vector> #include <vector>
#include <memory> #include <memory>
@ -25,7 +25,7 @@
#include "fl/server/kernel/optimizer_kernel_factory.h" #include "fl/server/kernel/optimizer_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
using mindspore::kernel::ApplyMomentumCPUKernel; using mindspore::kernel::ApplyMomentumCPUKernel;
@ -57,6 +57,6 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/dense_grad_accum_kernel.h" #include "fl/server/kernel/dense_grad_accum_kernel.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
REG_AGGREGATION_KERNEL( REG_AGGREGATION_KERNEL(
@ -26,5 +26,5 @@ REG_AGGREGATION_KERNEL(
DenseGradAccumKernel, float) DenseGradAccumKernel, float)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -26,7 +26,7 @@
#include "fl/server/kernel/aggregation_kernel_factory.h" #include "fl/server/kernel/aggregation_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
template <typename T> template <typename T>
@ -90,6 +90,6 @@ class DenseGradAccumKernel : public AggregationKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/fed_avg_kernel.h" #include "fl/server/kernel/fed_avg_kernel.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
REG_AGGREGATION_KERNEL_TWO(FedAvg, REG_AGGREGATION_KERNEL_TWO(FedAvg,
@ -29,5 +29,5 @@ REG_AGGREGATION_KERNEL_TWO(FedAvg,
FedAvgKernel, float, size_t) FedAvgKernel, float, size_t)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -31,7 +31,7 @@
#include "fl/server/kernel/aggregation_kernel_factory.h" #include "fl/server/kernel/aggregation_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// The implementation for the federated average. We do weighted average for the weights. The uploaded weights from // The implementation for the federated average. We do weighted average for the weights. The uploaded weights from
@ -42,7 +42,13 @@ namespace kernel {
template <typename T, typename S> template <typename T, typename S>
class FedAvgKernel : public AggregationKernel { class FedAvgKernel : public AggregationKernel {
public: public:
FedAvgKernel() : participated_(false) {} FedAvgKernel()
: cnode_weight_idx_(0),
weight_addr_(nullptr),
data_size_addr_(nullptr),
new_weight_addr_(nullptr),
new_data_size_addr_(nullptr),
participated_(false) {}
~FedAvgKernel() override = default; ~FedAvgKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override { void InitKernel(const CNodePtr &kernel_node) override {
@ -68,13 +74,13 @@ class FedAvgKernel : public AggregationKernel {
AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first; AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
MS_EXCEPTION_IF_NULL(weight_node); MS_EXCEPTION_IF_NULL(weight_node);
name_ = cnode_name + "." + weight_node->fullname_with_scope(); name_ = cnode_name + "." + weight_node->fullname_with_scope();
first_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) { first_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
std::unique_lock<std::mutex> lock(weight_mutex_); std::unique_lock<std::mutex> lock(weight_mutex_);
if (!participated_) { if (!participated_) {
ClearWeightAndDataSize(); ClearWeightAndDataSize();
} }
}; };
last_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) { last_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr); T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
size_t weight_size = weight_addr_->size; size_t weight_size = weight_addr_->size;
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr); S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
@ -193,7 +199,7 @@ class FedAvgKernel : public AggregationKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -26,7 +26,7 @@
#include "fl/server/kernel/params_info.h" #include "fl/server/kernel/params_info.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// KernelFactory is used to select and build kernels in server. It's the base class of OptimizerKernelFactory // KernelFactory is used to select and build kernels in server. It's the base class of OptimizerKernelFactory
@ -87,6 +87,6 @@ class KernelFactory {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -28,7 +28,7 @@
#include "fl/server/kernel/params_info.h" #include "fl/server/kernel/params_info.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
using mindspore::kernel::IsSameShape; using mindspore::kernel::IsSameShape;
@ -92,6 +92,6 @@ class OptimizerKernel : public CPUKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_

View File

@ -18,7 +18,7 @@
#include <utility> #include <utility>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) { bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
@ -66,5 +66,5 @@ bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodeP
} }
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -24,7 +24,7 @@
#include "fl/server/kernel/optimizer_kernel.h" #include "fl/server/kernel/optimizer_kernel.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>; using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>;
@ -50,6 +50,7 @@ class OptimizerKernelRegister {
OptimizerKernelRegister(const std::string &name, const ParamsInfo &params_info, OptimizerKernelCreator &&creator) { OptimizerKernelRegister(const std::string &name, const ParamsInfo &params_info, OptimizerKernelCreator &&creator) {
OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator)); OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
} }
~OptimizerKernelRegister() = default;
}; };
// Register optimizer kernel with one template type T. // Register optimizer kernel with one template type T.
@ -59,6 +60,6 @@ class OptimizerKernelRegister {
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); }); #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_

View File

@ -18,7 +18,7 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
ParamsInfo &ParamsInfo::AddInputNameType(const std::string &name, TypeId type) { ParamsInfo &ParamsInfo::AddInputNameType(const std::string &name, TypeId type) {
@ -64,5 +64,5 @@ const std::vector<std::string> &ParamsInfo::workspace_names() const { return wor
const std::vector<std::string> &ParamsInfo::outputs_names() const { return outputs_names_; } const std::vector<std::string> &ParamsInfo::outputs_names() const { return outputs_names_; }
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_
#include <utility> #include <utility>
#include <string> #include <string>
@ -23,7 +23,7 @@
#include "ir/dtype/type_id.h" #include "ir/dtype/type_id.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// ParamsInfo is used for server computation kernel's register, e.g, ApplyMomentumKernel, FedAvgKernel, etc. // ParamsInfo is used for server computation kernel's register, e.g, ApplyMomentumKernel, FedAvgKernel, etc.
@ -65,6 +65,6 @@ class ParamsInfo {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_

View File

@ -22,7 +22,7 @@
#include "schema/cipher_generated.h" #include "schema/cipher_generated.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void ClientListKernel::InitKernel(size_t) { void ClientListKernel::InitKernel(size_t) {
@ -150,7 +150,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "client_list_kernel success time is : " << duration; MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
return true; return true;
} // namespace ps } // namespace fl
bool ClientListKernel::Reset() { bool ClientListKernel::Reset() {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
@ -196,5 +196,5 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli
REG_ROUND_KERNEL(getClientList, ClientListKernel) REG_ROUND_KERNEL(getClientList, ClientListKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
@ -26,7 +26,7 @@
#include "fl/server/executor.h" #include "fl/server/executor.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
class ClientListKernel : public RoundKernel { class ClientListKernel : public RoundKernel {
@ -50,6 +50,6 @@ class ClientListKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H

View File

@ -20,7 +20,7 @@
#include <memory> #include <memory>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void ExchangeKeysKernel::InitKernel(size_t) { void ExchangeKeysKernel::InitKernel(size_t) {
@ -100,5 +100,5 @@ bool ExchangeKeysKernel::Reset() {
REG_ROUND_KERNEL(exchangeKeys, ExchangeKeysKernel) REG_ROUND_KERNEL(exchangeKeys, ExchangeKeysKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#include <vector> #include <vector>
#include "fl/server/common.h" #include "fl/server/common.h"
@ -25,7 +25,7 @@
#include "fl/armour/cipher/cipher_keys.h" #include "fl/armour/cipher/cipher_keys.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
class ExchangeKeysKernel : public RoundKernel { class ExchangeKeysKernel : public RoundKernel {
@ -44,7 +44,7 @@ class ExchangeKeysKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H

View File

@ -19,7 +19,7 @@
#include <memory> #include <memory>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void GetKeysKernel::InitKernel(size_t) { void GetKeysKernel::InitKernel(size_t) {
@ -99,5 +99,5 @@ bool GetKeysKernel::Reset() {
REG_ROUND_KERNEL(getKeys, GetKeysKernel) REG_ROUND_KERNEL(getKeys, GetKeysKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
#include <vector> #include <vector>
#include "fl/server/common.h" #include "fl/server/common.h"
@ -25,7 +25,7 @@
#include "fl/armour/cipher/cipher_keys.h" #include "fl/armour/cipher/cipher_keys.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
class GetKeysKernel : public RoundKernel { class GetKeysKernel : public RoundKernel {
@ -44,7 +44,7 @@ class GetKeysKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H

View File

@ -23,7 +23,7 @@
#include "fl/server/model_store.h" #include "fl/server/model_store.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void GetModelKernel::InitKernel(size_t) { void GetModelKernel::InitKernel(size_t) {
@ -133,5 +133,5 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con
REG_ROUND_KERNEL(getModel, GetModelKernel) REG_ROUND_KERNEL(getModel, GetModelKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -27,13 +27,13 @@
#include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/kernel/round/round_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
constexpr uint32_t kPrintGetModelForEveryRetryTime = 50; constexpr uint32_t kPrintGetModelForEveryRetryTime = 50;
class GetModelKernel : public RoundKernel { class GetModelKernel : public RoundKernel {
public: public:
GetModelKernel() = default; GetModelKernel() : executor_(nullptr), iteration_time_window_(0), retry_count_(0) {}
~GetModelKernel() override = default; ~GetModelKernel() override = default;
void InitKernel(size_t) override; void InitKernel(size_t) override;
@ -58,6 +58,6 @@ class GetModelKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

View File

@ -21,7 +21,7 @@
#include "fl/armour/cipher/cipher_shares.h" #include "fl/armour/cipher/cipher_shares.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void GetSecretsKernel::InitKernel(size_t) { void GetSecretsKernel::InitKernel(size_t) {
@ -102,5 +102,5 @@ bool GetSecretsKernel::Reset() {
REG_ROUND_KERNEL(getSecrets, GetSecretsKernel) REG_ROUND_KERNEL(getSecrets, GetSecretsKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#include <vector> #include <vector>
#include "fl/server/common.h" #include "fl/server/common.h"
@ -25,7 +25,7 @@
#include "fl/server/executor.h" #include "fl/server/executor.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
class GetSecretsKernel : public RoundKernel { class GetSecretsKernel : public RoundKernel {
@ -44,7 +44,7 @@ class GetSecretsKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H

View File

@ -22,7 +22,7 @@
#include "fl/server/model_store.h" #include "fl/server/model_store.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void PullWeightKernel::InitKernel(size_t) { void PullWeightKernel::InitKernel(size_t) {
@ -137,5 +137,5 @@ void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const
REG_ROUND_KERNEL(pullWeight, PullWeightKernel) REG_ROUND_KERNEL(pullWeight, PullWeightKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -27,7 +27,7 @@
#include "fl/server/executor.h" #include "fl/server/executor.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
constexpr uint32_t kPrintPullWeightForEveryRetryTime = 500; constexpr uint32_t kPrintPullWeightForEveryRetryTime = 500;
@ -53,6 +53,6 @@ class PullWeightKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/round/push_weight_kernel.h" #include "fl/server/kernel/round/push_weight_kernel.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void PushWeightKernel::InitKernel(size_t) { void PushWeightKernel::InitKernel(size_t) {
@ -60,8 +60,8 @@ bool PushWeightKernel::Reset() {
return true; return true;
} }
void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) { void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
if (PSContext::instance()->resetter_round() == ResetterRound::kPushWeight) { if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushWeight) {
FinishIteration(); FinishIteration();
} }
return; return;
@ -136,5 +136,5 @@ void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const
REG_ROUND_KERNEL(pushWeight, PushWeightKernel) REG_ROUND_KERNEL(pushWeight, PushWeightKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -27,7 +27,7 @@
#include "fl/server/executor.h" #include "fl/server/executor.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
class PushWeightKernel : public RoundKernel { class PushWeightKernel : public RoundKernel {
@ -39,7 +39,7 @@ class PushWeightKernel : public RoundKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs); const std::vector<AddressPtr> &outputs);
bool Reset() override; bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override; void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private: private:
bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req); bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
@ -52,6 +52,6 @@ class PushWeightKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_

View File

@ -20,7 +20,7 @@
#include <memory> #include <memory>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void ReconstructSecretsKernel::InitKernel(size_t required_cnt) { void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
@ -34,17 +34,17 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return; return;
} }
auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) { auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
MS_LOG(INFO) << "start FinishIteration"; MS_LOG(INFO) << "start FinishIteration";
FinishIteration(); FinishIteration();
MS_LOG(INFO) << "end FinishIteration"; MS_LOG(INFO) << "end FinishIteration";
return; return;
}; };
auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) { return; }; auto first_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) { return; };
name_unmask_ = "UnMaskKernel"; name_unmask_ = "UnMaskKernel";
MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : " MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : "
<< LocalMetaStore::GetInstance().curr_iter_num(); << LocalMetaStore::GetInstance().curr_iter_num();
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, PSContext::instance()->initial_server_num(), DistributedCountService::GetInstance().RegisterCounter(name_unmask_, ps::PSContext::instance()->initial_server_num(),
{first_cnt_handler, last_cnt_handler}); {first_cnt_handler, last_cnt_handler});
} }
@ -134,9 +134,9 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
return true; return true;
} }
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
if (PSContext::instance()->encrypt_type() == kPWEncryptType) { if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) {
while (!Executor::GetInstance().IsAllWeightAggregationDone()) { while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5)); std::this_thread::sleep_for(std::chrono::milliseconds(5));
} }
@ -164,5 +164,5 @@ bool ReconstructSecretsKernel::Reset() {
REG_ROUND_KERNEL(reconstructSecrets, ReconstructSecretsKernel) REG_ROUND_KERNEL(reconstructSecrets, ReconstructSecretsKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#include <vector> #include <vector>
#include <memory> #include <memory>
@ -27,7 +27,7 @@
#include "fl/server/executor.h" #include "fl/server/executor.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
class ReconstructSecretsKernel : public RoundKernel { class ReconstructSecretsKernel : public RoundKernel {
@ -39,7 +39,7 @@ class ReconstructSecretsKernel : public RoundKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
bool Reset() override; bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override; void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private: private:
std::string name_unmask_; std::string name_unmask_;
@ -49,6 +49,6 @@ class ReconstructSecretsKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_

View File

@ -24,7 +24,7 @@
#include <vector> #include <vector>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) { RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) {
@ -61,9 +61,9 @@ RoundKernel::~RoundKernel() {
} }
} }
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) { return; } void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
void RoundKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) { return; } void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
void RoundKernel::StopTimer() const { void RoundKernel::StopTimer() const {
if (stop_timer_cb_) { if (stop_timer_cb_) {
@ -129,5 +129,5 @@ void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const v
} }
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -35,7 +35,7 @@
#include "fl/server/distributed_metadata_store.h" #include "fl/server/distributed_metadata_store.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round // RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
@ -67,8 +67,8 @@ class RoundKernel : virtual public CPUKernel {
// The counter event handlers for DistributedCountService. // The counter event handlers for DistributedCountService.
// The callbacks when first message and last message for this round kernel is received. // The callbacks when first message and last message for this round kernel is received.
// These methods is called by class DistributedCountService and triggered by counting server. // These methods is called by class DistributedCountService and triggered by counting server.
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message); virtual void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message); virtual void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
// Called when this round is finished. This round timer's Stop method will be called. // Called when this round is finished. This round timer's Stop method will be called.
void StopTimer() const; void StopTimer() const;
@ -123,6 +123,6 @@ class RoundKernel : virtual public CPUKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/kernel/round/round_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
RoundKernelFactory &RoundKernelFactory::GetInstance() { RoundKernelFactory &RoundKernelFactory::GetInstance() {
@ -40,5 +40,5 @@ std::shared_ptr<RoundKernel> RoundKernelFactory::Create(const std::string &name)
} }
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -25,7 +25,7 @@
#include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>; using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>;
@ -50,6 +50,7 @@ class RoundKernelRegister {
RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) { RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) {
RoundKernelFactory::GetInstance().Register(name, std::move(creator)); RoundKernelFactory::GetInstance().Register(name, std::move(creator));
} }
~RoundKernelRegister() = default;
}; };
#define REG_ROUND_KERNEL(NAME, CLASS) \ #define REG_ROUND_KERNEL(NAME, CLASS) \
@ -57,6 +58,6 @@ class RoundKernelRegister {
static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); }); static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); });
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_

View File

@ -19,7 +19,7 @@
#include <memory> #include <memory>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void ShareSecretsKernel::InitKernel(size_t) { void ShareSecretsKernel::InitKernel(size_t) {
@ -101,5 +101,5 @@ bool ShareSecretsKernel::Reset() {
REG_ROUND_KERNEL(shareSecrets, ShareSecretsKernel) REG_ROUND_KERNEL(shareSecrets, ShareSecretsKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#include <vector> #include <vector>
#include "fl/server/common.h" #include "fl/server/common.h"
@ -25,7 +25,7 @@
#include "fl/armour/cipher/cipher_shares.h" #include "fl/armour/cipher/cipher_shares.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
class ShareSecretsKernel : public RoundKernel { class ShareSecretsKernel : public RoundKernel {
@ -44,7 +44,7 @@ class ShareSecretsKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H

View File

@ -26,7 +26,7 @@
#endif #endif
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void StartFLJobKernel::InitKernel(size_t) { void StartFLJobKernel::InitKernel(size_t) {
@ -113,8 +113,8 @@ bool StartFLJobKernel::Reset() {
return true; return true;
} }
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) { void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_; iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_); LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
// The first startFLJob request means a new iteration starts running. // The first startFLJob request means a new iteration starts running.
Iteration::GetInstance().SetIterationRunning(); Iteration::GetInstance().SetIterationRunning();
@ -194,8 +194,8 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
std::map<std::string, AddressPtr> feature_maps) { std::map<std::string, AddressPtr> feature_maps) {
auto fbs_reason = fbb->CreateString(reason); auto fbs_reason = fbb->CreateString(reason);
auto fbs_next_req_time = fbb->CreateString(next_req_time); auto fbs_next_req_time = fbb->CreateString(next_req_time);
auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode()); auto fbs_server_mode = fbb->CreateString(ps::PSContext::instance()->server_mode());
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name()); auto fbs_fl_name = fbb->CreateString(ps::PSContext::instance()->fl_name());
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
auto *param = armour::CipherInit::GetInstance().GetPublicParams(); auto *param = armour::CipherInit::GetInstance().GetPublicParams();
@ -206,7 +206,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
float dp_eps = param->dp_eps; float dp_eps = param->dp_eps;
float dp_delta = param->dp_delta; float dp_delta = param->dp_delta;
float dp_norm_clip = param->dp_norm_clip; float dp_norm_clip = param->dp_norm_clip;
auto encrypt_type = fbb->CreateString(PSContext::instance()->encrypt_type()); auto encrypt_type = fbb->CreateString(ps::PSContext::instance()->encrypt_type());
auto cipher_public_params = auto cipher_public_params =
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type); schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);
@ -215,10 +215,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
fl_plan_builder.add_fl_name(fbs_fl_name); fl_plan_builder.add_fl_name(fbs_fl_name);
fl_plan_builder.add_server_mode(fbs_server_mode); fl_plan_builder.add_server_mode(fbs_server_mode);
fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num()); fl_plan_builder.add_iterations(ps::PSContext::instance()->fl_iteration_num());
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num()); fl_plan_builder.add_epochs(ps::PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size()); fl_plan_builder.add_mini_batch(ps::PSContext::instance()->client_batch_size());
fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate()); fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate());
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
fl_plan_builder.add_cipher(cipher_public_params); fl_plan_builder.add_cipher(cipher_public_params);
#endif #endif
@ -250,5 +250,5 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
REG_ROUND_KERNEL(startFLJob, StartFLJobKernel) REG_ROUND_KERNEL(startFLJob, StartFLJobKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -27,7 +27,7 @@
#include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/kernel/round/round_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
class StartFLJobKernel : public RoundKernel { class StartFLJobKernel : public RoundKernel {
@ -40,7 +40,7 @@ class StartFLJobKernel : public RoundKernel {
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
bool Reset() override; bool Reset() override;
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) override; void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private: private:
// Returns whether the startFLJob count of this iteration has reached the threshold. // Returns whether the startFLJob count of this iteration has reached the threshold.
@ -74,6 +74,6 @@ class StartFLJobKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_

View File

@ -21,7 +21,7 @@
#include "fl/server/kernel/round/update_model_kernel.h" #include "fl/server/kernel/round/update_model_kernel.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
void UpdateModelKernel::InitKernel(size_t threshold_count) { void UpdateModelKernel::InitKernel(size_t threshold_count) {
@ -87,8 +87,8 @@ bool UpdateModelKernel::Reset() {
return true; return true;
} }
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) { void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (PSContext::instance()->resetter_round() == ResetterRound::kUpdateModel) { if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) {
while (!executor_->IsAllWeightAggregationDone()) { while (!executor_->IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5)); std::this_thread::sleep_for(std::chrono::milliseconds(5));
} }
@ -96,7 +96,7 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize); size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is " MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
<< total_data_size; << total_data_size;
if (PSContext::instance()->encrypt_type() != kPWEncryptType) { if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
FinishIteration(); FinishIteration();
} }
} }
@ -226,5 +226,5 @@ void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fb
REG_ROUND_KERNEL(updateModel, UpdateModelKernel) REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -27,7 +27,7 @@
#include "fl/server/executor.h" #include "fl/server/executor.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round. // The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
@ -44,7 +44,7 @@ class UpdateModelKernel : public RoundKernel {
bool Reset() override; bool Reset() override;
// In some cases, the last updateModel message means this server iteration is finished. // In some cases, the last updateModel message means this server iteration is finished.
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override; void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private: private:
bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb); bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
@ -62,6 +62,6 @@ class UpdateModelKernel : public RoundKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/local_meta_store.h" #include "fl/server/local_meta_store.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
void LocalMetaStore::remove_value(const std::string &name) { void LocalMetaStore::remove_value(const std::string &name) {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
@ -41,5 +41,5 @@ const size_t LocalMetaStore::curr_iter_num() {
return curr_iter_num_; return curr_iter_num_;
} }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ #define MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_
#include <any> #include <any>
#include <mutex> #include <mutex>
@ -24,7 +24,7 @@
#include "fl/server/common.h" #include "fl/server/common.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// LocalMetaStore class is used for metadata storage of this server process. // LocalMetaStore class is used for metadata storage of this server process.
// For example, the current iteration number, time windows for round kernels, etc. // For example, the current iteration number, time windows for round kernels, etc.
@ -71,7 +71,7 @@ class LocalMetaStore {
const size_t curr_iter_num(); const size_t curr_iter_num();
private: private:
LocalMetaStore() = default; LocalMetaStore() : key_to_meta_({}), curr_iter_num_(0) {}
~LocalMetaStore() = default; ~LocalMetaStore() = default;
LocalMetaStore(const LocalMetaStore &) = delete; LocalMetaStore(const LocalMetaStore &) = delete;
LocalMetaStore &operator=(const LocalMetaStore &) = delete; LocalMetaStore &operator=(const LocalMetaStore &) = delete;
@ -83,6 +83,6 @@ class LocalMetaStore {
size_t curr_iter_num_; size_t curr_iter_num_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_

View File

@ -18,7 +18,7 @@
#include <utility> #include <utility>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) { void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) {
addresses_.try_emplace(name, address); addresses_.try_emplace(name, address);
@ -32,5 +32,5 @@ void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64
void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); } void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_
#define MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_ #define MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_
#include <map> #include <map>
#include <string> #include <string>
@ -26,7 +26,7 @@
#include "fl/server/common.h" #include "fl/server/common.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// Memory allocated in server is normally trainable parameters, hyperparameters, gradients, etc. // Memory allocated in server is normally trainable parameters, hyperparameters, gradients, etc.
// MemoryRegister registers the Memory with key-value format where key refers to address's name("grad", "weights", // MemoryRegister registers the Memory with key-value format where key refers to address's name("grad", "weights",
@ -88,6 +88,6 @@ class MemoryRegister {
} }
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_

View File

@ -21,7 +21,7 @@
#include "fl/server/executor.h" #include "fl/server/executor.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
void ModelStore::Initialize(uint32_t max_count) { void ModelStore::Initialize(uint32_t max_count) {
if (!Executor::GetInstance().initialized()) { if (!Executor::GetInstance().initialized()) {
@ -155,5 +155,5 @@ size_t ModelStore::ComputeModelSize() {
return model_size; return model_size;
} }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ #define MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -25,7 +25,7 @@
#include "fl/server/executor.h" #include "fl/server/executor.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// The initial iteration number is 0 in server. // The initial iteration number is 0 in server.
constexpr size_t kInitIterationNum = 0; constexpr size_t kInitIterationNum = 0;
@ -84,6 +84,6 @@ class ModelStore {
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_; std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_

View File

@ -23,7 +23,7 @@
#include <algorithm> #include <algorithm>
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) { bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
@ -199,8 +199,8 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
} }
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
if (PSContext::instance()->server_mode() == kServerModeFL || if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
PSContext::instance()->server_mode() == kServerModeHybrid) { ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
return true; return true;
} }
@ -321,13 +321,13 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) { std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
std::vector<std::string> aggregation_algorithm = {}; std::vector<std::string> aggregation_algorithm = {};
if (PSContext::instance()->server_mode() == kServerModeFL || if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
PSContext::instance()->server_mode() == kServerModeHybrid) { ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
aggregation_algorithm.push_back("FedAvg"); aggregation_algorithm.push_back("FedAvg");
} else if (PSContext::instance()->server_mode() == kServerModePS) { } else if (ps::PSContext::instance()->server_mode() == ps::kServerModePS) {
aggregation_algorithm.push_back("DenseGradAccum"); aggregation_algorithm.push_back("DenseGradAccum");
} else { } else {
MS_LOG(ERROR) << "Server doesn't support mode " << PSContext::instance()->server_mode(); MS_LOG(ERROR) << "Server doesn't support mode " << ps::PSContext::instance()->server_mode();
} }
MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
@ -344,5 +344,5 @@ template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::Aggregat
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
std::shared_ptr<MemoryRegister> memory_register); std::shared_ptr<MemoryRegister> memory_register);
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
#define MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_ #define MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
#include <map> #include <map>
#include <memory> #include <memory>
@ -28,7 +28,7 @@
#include "fl/server/kernel/optimizer_kernel_factory.h" #include "fl/server/kernel/optimizer_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server // Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server
// kernels. // kernels.
@ -137,6 +137,6 @@ class ParameterAggregator {
std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_; std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_

View File

@ -21,7 +21,7 @@
#include "fl/server/iteration.h" #include "fl/server/iteration.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
class Server; class Server;
class Iteration; class Iteration;
@ -34,14 +34,14 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo
threshold_count_(threshold_count), threshold_count_(threshold_count),
server_num_as_threshold_(server_num_as_threshold) {} server_num_as_threshold_(server_num_as_threshold) {}
void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb, void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb) { FinishIterCb finish_iteration_cb) {
MS_EXCEPTION_IF_NULL(communicator); MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator; communicator_ = communicator;
// Register callback for round kernel. // Register callback for round kernel.
communicator_->RegisterMsgCallBack( communicator_->RegisterMsgCallBack(
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); }); name_, [&](std::shared_ptr<ps::core::MessageHandler> message) { LaunchRoundKernel(message); });
// Callback when the iteration is finished. // Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void { finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
@ -106,7 +106,7 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
return; return;
} }
void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) { void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) { if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr."; MS_LOG(ERROR) << "Message is nullptr.";
return; return;
@ -152,7 +152,7 @@ bool Round::check_timeout() const { return check_timeout_; }
size_t Round::time_window() const { return time_window_; } size_t Round::time_window() const { return time_window_; }
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) { void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered."; MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
// The timer starts only after the first count event is triggered by DistributedCountService. // The timer starts only after the first count event is triggered by DistributedCountService.
if (check_timeout_) { if (check_timeout_) {
@ -164,7 +164,7 @@ void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &messa
return; return;
} }
void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "Round " << name_ << " last count event is triggered."; MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
// Same as the first count event, the timer must be stopped by DistributedCountService. // Same as the first count event, the timer must be stopped by DistributedCountService.
if (check_timeout_) { if (check_timeout_) {
@ -176,5 +176,5 @@ void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &messag
return; return;
} }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
#define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ #define MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -26,7 +26,7 @@
#include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// Round helps server to handle network round messages and launch round kernels. One iteration in server consists of // Round helps server to handle network round messages and launch round kernels. One iteration in server consists of
// multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting // multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting
@ -37,7 +37,7 @@ class Round {
bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false); bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
~Round() = default; ~Round() = default;
void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb, void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb); FinishIterCb finish_iteration_cb);
// Reinitialize count service and round kernel of this round after scaling operations are done. // Reinitialize count service and round kernel of this round after scaling operations are done.
@ -48,7 +48,7 @@ class Round {
// This method is the callback which will be set to the communicator and called after the corresponding round message // This method is the callback which will be set to the communicator and called after the corresponding round message
// is sent to the server. // is sent to the server.
void LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message); void LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message);
// Round needs to be reset after each iteration is finished or its timer expires. // Round needs to be reset after each iteration is finished or its timer expires.
void Reset(); void Reset();
@ -60,8 +60,8 @@ class Round {
private: private:
// The callbacks which will be set to DistributedCounterService. // The callbacks which will be set to DistributedCounterService.
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message); void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message); void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
std::string name_; std::string name_;
@ -83,7 +83,7 @@ class Round {
// Whether this round uses the server number as its threshold count. // Whether this round uses the server number as its threshold count.
bool server_num_as_threshold_; bool server_num_as_threshold_;
std::shared_ptr<core::CommunicatorBase> communicator_; std::shared_ptr<ps::core::CommunicatorBase> communicator_;
// The round kernel for this Round. // The round kernel for this Round.
std::shared_ptr<kernel::RoundKernel> kernel_; std::shared_ptr<kernel::RoundKernel> kernel_;
@ -97,6 +97,6 @@ class Round {
FinalizeCb finalize_cb_; FinalizeCb finalize_cb_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_ROUND_H_

View File

@ -30,17 +30,8 @@
#include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/kernel/round/round_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
static std::vector<std::shared_ptr<core::CommunicatorBase>> global_worker_server_comms = {};
// This function is for the exit of server process when an interrupt signal is captured.
void SignalHandler(int signal) {
MS_LOG(INFO) << "Interrupt signal captured: " << signal;
std::for_each(global_worker_server_comms.begin(), global_worker_server_comms.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
return;
}
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config, void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) { const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -76,7 +67,6 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization: // Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
// InitCipher---->InitExecutor // InitCipher---->InitExecutor
void Server::Run() { void Server::Run() {
signal(SIGINT, SignalHandler);
std::unique_lock<std::mutex> lock(scaling_mtx_); std::unique_lock<std::mutex> lock(scaling_mtx_);
InitServerContext(); InitServerContext();
InitCluster(); InitCluster();
@ -84,8 +74,8 @@ void Server::Run() {
RegisterCommCallbacks(); RegisterCommCallbacks();
StartCommunicator(); StartCommunicator();
InitExecutor(); InitExecutor();
std::string encrypt_type = PSContext::instance()->encrypt_type(); std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type != kNotEncryptType) { if (encrypt_type != ps::kNotEncryptType) {
InitCipher(); InitCipher();
MS_LOG(INFO) << "Parameters for secure aggregation have been initiated."; MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
} }
@ -96,7 +86,7 @@ void Server::Run() {
// Wait communicators to stop so the main thread is blocked. // Wait communicators to stop so the main thread is blocked.
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Join(); }); [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Join(); });
communicator_with_server_->Join(); communicator_with_server_->Join();
MsException::Instance().CheckException(); MsException::Instance().CheckException();
return; return;
@ -115,18 +105,18 @@ void Server::CancelSafeMode() {
bool Server::IsSafeMode() { return safemode_.load(); } bool Server::IsSafeMode() { return safemode_.load(); }
void Server::InitServerContext() { void Server::InitServerContext() {
PSContext::instance()->GenerateResetterRound(); ps::PSContext::instance()->GenerateResetterRound();
scheduler_ip_ = PSContext::instance()->scheduler_host(); scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
scheduler_port_ = PSContext::instance()->scheduler_port(); scheduler_port_ = ps::PSContext::instance()->scheduler_port();
worker_num_ = PSContext::instance()->initial_worker_num(); worker_num_ = ps::PSContext::instance()->initial_worker_num();
server_num_ = PSContext::instance()->initial_server_num(); server_num_ = ps::PSContext::instance()->initial_server_num();
return; return;
} }
void Server::InitCluster() { void Server::InitCluster() {
server_node_ = std::make_shared<core::ServerNode>(); server_node_ = std::make_shared<ps::core::ServerNode>();
MS_EXCEPTION_IF_NULL(server_node_); MS_EXCEPTION_IF_NULL(server_node_);
task_executor_ = std::make_shared<core::TaskExecutor>(32); task_executor_ = std::make_shared<ps::core::TaskExecutor>(32);
MS_EXCEPTION_IF_NULL(task_executor_); MS_EXCEPTION_IF_NULL(task_executor_);
if (!InitCommunicatorWithServer()) { if (!InitCommunicatorWithServer()) {
MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed."; MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
@ -136,7 +126,6 @@ void Server::InitCluster() {
MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed."; MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
return; return;
} }
global_worker_server_comms = communicators_with_worker_;
return; return;
} }
@ -187,8 +176,8 @@ void Server::InitIteration() {
} }
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
std::string encrypt_type = PSContext::instance()->encrypt_type(); std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == kPWEncryptType) { if (encrypt_type == ps::kPWEncryptType) {
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count; cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0; cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio; cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
@ -245,10 +234,10 @@ void Server::InitCipher() {
unsigned char cipher_p[SECRET_MAX_LEN] = {0}; unsigned char cipher_p[SECRET_MAX_LEN] = {0};
int cipher_g = 1; int cipher_g = 1;
unsigned char cipher_prime[PRIME_MAX_LEN] = {0}; unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
float dp_eps = PSContext::instance()->dp_eps(); float dp_eps = ps::PSContext::instance()->dp_eps();
float dp_delta = PSContext::instance()->dp_delta(); float dp_delta = ps::PSContext::instance()->dp_delta();
float dp_norm_clip = PSContext::instance()->dp_norm_clip(); float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
std::string encrypt_type = PSContext::instance()->encrypt_type(); std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
mpz_t prim; mpz_t prim;
mpz_init(prim); mpz_init(prim);
@ -276,7 +265,7 @@ void Server::RegisterCommCallbacks() {
// The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
// rounds' callbacks. // rounds' callbacks.
auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
MS_EXCEPTION_IF_NULL(tcp_comm); MS_EXCEPTION_IF_NULL(tcp_comm);
// Set message callbacks for server-to-server communication. // Set message callbacks for server-to-server communication.
@ -304,23 +293,23 @@ void Server::RegisterCommCallbacks() {
std::bind(&Server::ProcessAfterScalingIn, this)); std::bind(&Server::ProcessAfterScalingIn, this));
} }
void Server::RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) { void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator); MS_EXCEPTION_IF_NULL(communicator);
communicator->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() { communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed."; MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
safemode_ = true; safemode_ = true;
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); }); [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop(); communicator_with_server_->Stop();
}); });
communicator->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [&]() { communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
MS_LOG(ERROR) MS_LOG(ERROR)
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the " << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
"network building phase."; "network building phase.";
safemode_ = true; safemode_ = true;
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); }); [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop(); communicator_with_server_->Stop();
}); });
} }
@ -377,7 +366,7 @@ void Server::StartCommunicator() {
MS_LOG(INFO) << "Start communicator with worker."; MS_LOG(INFO) << "Start communicator with worker.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Start(); }); [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Start(); });
} }
void Server::ProcessBeforeScalingOut() { void Server::ProcessBeforeScalingOut() {
@ -424,7 +413,7 @@ void Server::ProcessAfterScalingIn() {
if (server_node_->rank_id() == UINT32_MAX) { if (server_node_->rank_id() == UINT32_MAX) {
MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting."; MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(), std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); }); [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop(); communicator_with_server_->Stop();
return; return;
} }
@ -449,5 +438,5 @@ void Server::ProcessAfterScalingIn() {
safemode_ = false; safemode_ = false;
} }
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_ #ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
#define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_ #define MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -31,7 +31,7 @@
#endif #endif
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace server { namespace server {
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning. // Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
class Server { class Server {
@ -90,7 +90,7 @@ class Server {
void RegisterCommCallbacks(); void RegisterCommCallbacks();
// Register cluster exception callbacks. This method is called in RegisterCommCallbacks. // Register cluster exception callbacks. This method is called in RegisterCommCallbacks.
void RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator); void RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
// Initialize executor according to the server mode. // Initialize executor according to the server mode.
void InitExecutor(); void InitExecutor();
@ -113,11 +113,11 @@ class Server {
void ProcessAfterScalingIn(); void ProcessAfterScalingIn();
// The server node is initialized in Server. // The server node is initialized in Server.
std::shared_ptr<core::ServerNode> server_node_; std::shared_ptr<ps::core::ServerNode> server_node_;
// The task executor of the communicators. This helps server to handle network message concurrently. The tasks // The task executor of the communicators. This helps server to handle network message concurrently. The tasks
// submitted to this task executor is asynchronous. // submitted to this task executor is asynchronous.
std::shared_ptr<core::TaskExecutor> task_executor_; std::shared_ptr<ps::core::TaskExecutor> task_executor_;
// Which protocol should communicators use. // Which protocol should communicators use.
bool use_tcp_; bool use_tcp_;
@ -136,12 +136,12 @@ class Server {
// Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective // Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective
// operations, etc. // operations, etc.
std::shared_ptr<core::CommunicatorBase> communicator_with_server_; std::shared_ptr<ps::core::CommunicatorBase> communicator_with_server_;
// The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP. // The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP.
// In some cases, both types should be supported in one distributed training job. So here we may have multiple // In some cases, both types should be supported in one distributed training job. So here we may have multiple
// communicators. // communicators.
std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_; std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
// Mutex for scaling operations. We must wait server's initialization done before handle scaling events. // Mutex for scaling operations. We must wait server's initialization done before handle scaling events.
std::mutex scaling_mtx_; std::mutex scaling_mtx_;
@ -176,6 +176,6 @@ class Server {
float percent_for_get_model_; float percent_for_get_model_;
}; };
} // namespace server } // namespace server
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_ #endif // MINDSPORE_CCSRC_FL_SERVER_SERVER_H_

View File

@ -22,27 +22,27 @@
#include "utils/ms_exception.h" #include "utils/ms_exception.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
namespace worker { namespace worker {
void FLWorker::Run() { void FLWorker::Run() {
worker_num_ = PSContext::instance()->worker_num(); worker_num_ = ps::PSContext::instance()->worker_num();
server_num_ = PSContext::instance()->server_num(); server_num_ = ps::PSContext::instance()->server_num();
scheduler_ip_ = PSContext::instance()->scheduler_ip(); scheduler_ip_ = ps::PSContext::instance()->scheduler_ip();
scheduler_port_ = PSContext::instance()->scheduler_port(); scheduler_port_ = ps::PSContext::instance()->scheduler_port();
worker_step_num_per_iteration_ = PSContext::instance()->worker_step_num_per_iteration(); worker_step_num_per_iteration_ = ps::PSContext::instance()->worker_step_num_per_iteration();
PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_; ps::PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_;
PSContext::instance()->cluster_config().scheduler_port = scheduler_port_; ps::PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
PSContext::instance()->cluster_config().initial_worker_num = worker_num_; ps::PSContext::instance()->cluster_config().initial_worker_num = worker_num_;
PSContext::instance()->cluster_config().initial_server_num = server_num_; ps::PSContext::instance()->cluster_config().initial_server_num = server_num_;
MS_LOG(INFO) << "Initialize cluster config for worker. Worker number:" << worker_num_ MS_LOG(INFO) << "Initialize cluster config for worker. Worker number:" << worker_num_
<< ", Server number:" << server_num_ << ", Scheduler ip:" << scheduler_ip_ << ", Server number:" << server_num_ << ", Scheduler ip:" << scheduler_ip_
<< ", Scheduler port:" << scheduler_port_ << ", Scheduler port:" << scheduler_port_
<< ", Worker training step per iteration:" << worker_step_num_per_iteration_; << ", Worker training step per iteration:" << worker_step_num_per_iteration_;
worker_node_ = std::make_shared<core::WorkerNode>(); worker_node_ = std::make_shared<ps::core::WorkerNode>();
MS_EXCEPTION_IF_NULL(worker_node_); MS_EXCEPTION_IF_NULL(worker_node_);
worker_node_->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() { worker_node_->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
Finalize(); Finalize();
try { try {
MS_LOG(EXCEPTION) MS_LOG(EXCEPTION)
@ -51,7 +51,7 @@ void FLWorker::Run() {
MsException::Instance().SetException(); MsException::Instance().SetException();
} }
}); });
worker_node_->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() { worker_node_->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
Finalize(); Finalize();
try { try {
MS_LOG(EXCEPTION) MS_LOG(EXCEPTION)
@ -74,7 +74,7 @@ void FLWorker::Finalize() {
worker_node_->Stop(); worker_node_->Stop();
} }
bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, core::TcpUserCommand command, bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output) { std::shared_ptr<std::vector<unsigned char>> *output) {
// If the worker is in safemode, do not communicate with server. // If the worker is in safemode, do not communicate with server.
while (safemode_.load()) { while (safemode_.load()) {
@ -97,7 +97,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
if (output != nullptr) { if (output != nullptr) {
while (true) { while (true) {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output)) { if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
output)) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed."; MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false; return false;
} }
@ -106,7 +107,7 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
return false; return false;
} }
if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode) { if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == ps::kClusterSafeMode) {
MS_LOG(INFO) << "The server " << server_rank << " is in safemode."; MS_LOG(INFO) << "The server " << server_rank << " is in safemode.";
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode)); std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode));
} else { } else {
@ -114,7 +115,7 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
} }
} }
} else { } else {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) { if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed."; MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false; return false;
} }
@ -155,9 +156,9 @@ void FLWorker::InitializeFollowerScaler() {
std::bind(&FLWorker::ProcessAfterScalingOut, this)); std::bind(&FLWorker::ProcessAfterScalingOut, this));
worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline", worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
std::bind(&FLWorker::ProcessAfterScalingIn, this)); std::bind(&FLWorker::ProcessAfterScalingIn, this));
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning), worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
std::bind(&FLWorker::HandleIterationRunningEvent, this)); std::bind(&FLWorker::HandleIterationRunningEvent, this));
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted), worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
std::bind(&FLWorker::HandleIterationCompletedEvent, this)); std::bind(&FLWorker::HandleIterationCompletedEvent, this));
} }
@ -222,5 +223,5 @@ void FLWorker::ProcessAfterScalingIn() {
safemode_ = false; safemode_ = false;
} }
} // namespace worker } // namespace worker
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_ #ifndef MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
#define MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_ #define MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -28,7 +28,7 @@
#include "ps/core/communicator/tcp_communicator.h" #include "ps/core/communicator/tcp_communicator.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace fl {
using FBBuilder = flatbuffers::FlatBufferBuilder; using FBBuilder = flatbuffers::FlatBufferBuilder;
// The step number for worker to judge whether to communicate with server. // The step number for worker to judge whether to communicate with server.
@ -59,7 +59,7 @@ class FLWorker {
} }
void Run(); void Run();
void Finalize(); void Finalize();
bool SendToServer(uint32_t server_rank, const void *data, size_t size, core::TcpUserCommand command, bool SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output = nullptr); std::shared_ptr<std::vector<unsigned char>> *output = nullptr);
uint32_t server_num() const; uint32_t server_num() const;
@ -104,7 +104,7 @@ class FLWorker {
uint32_t worker_num_; uint32_t worker_num_;
std::string scheduler_ip_; std::string scheduler_ip_;
uint16_t scheduler_port_; uint16_t scheduler_port_;
std::shared_ptr<core::WorkerNode> worker_node_; std::shared_ptr<ps::core::WorkerNode> worker_node_;
// The worker standalone training step number before communicating with server. This used in hybrid training mode. // The worker standalone training step number before communicating with server. This used in hybrid training mode.
uint64_t worker_step_num_per_iteration_; uint64_t worker_step_num_per_iteration_;
@ -121,6 +121,6 @@ class FLWorker {
std::atomic_bool safemode_; std::atomic_bool safemode_;
}; };
} // namespace worker } // namespace worker
} // namespace ps } // namespace fl
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_ #endif // MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_

View File

@ -639,7 +639,7 @@ bool StartPSWorkerAction(const ResourcePtr &res) {
return true; return true;
} }
bool StartFLWorkerAction(const ResourcePtr &) { bool StartFLWorkerAction(const ResourcePtr &) {
ps::worker::FLWorker::GetInstance().Run(); fl::worker::FLWorker::GetInstance().Run();
return true; return true;
} }
@ -665,7 +665,7 @@ bool StartServerAction(const ResourcePtr &res) {
uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window(); uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window(); uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
std::vector<ps::server::RoundConfig> rounds_config = { std::vector<fl::server::RoundConfig> rounds_config = {
{"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold}, {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
{"updateModel", true, update_model_time_window, true, update_model_threshold}, {"updateModel", true, update_model_time_window, true, update_model_threshold},
{"getModel"}, {"getModel"},
@ -676,22 +676,22 @@ bool StartServerAction(const ResourcePtr &res) {
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window(); uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold(); size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
ps::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold}; fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
size_t executor_threshold = 0; size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
executor_threshold = update_model_threshold; executor_threshold = update_model_threshold;
ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph, fl::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
executor_threshold); executor_threshold);
} else if (server_mode_ == ps::kServerModePS) { } else if (server_mode_ == ps::kServerModePS) {
executor_threshold = worker_num; executor_threshold = worker_num;
ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph, fl::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
executor_threshold); executor_threshold);
} else { } else {
MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported."; MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
return false; return false;
} }
ps::server::Server::GetInstance().Run(); fl::server::Server::GetInstance().Run();
return true; return true;
} }

View File

@ -1293,7 +1293,7 @@ void ClearResAtexit() {
MS_LOG(INFO) << "Start finalizing worker."; MS_LOG(INFO) << "Start finalizing worker.";
const std::string &server_mode = ps::PSContext::instance()->server_mode(); const std::string &server_mode = ps::PSContext::instance()->server_mode();
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid)) { if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid)) {
ps::worker::FLWorker::GetInstance().Finalize(); fl::worker::FLWorker::GetInstance().Finalize();
} else { } else {
ps::Worker::GetInstance().Finalize(); ps::Worker::GetInstance().Finalize();
} }

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
syntax = "proto3"; syntax = "proto3";
package mindspore.ps; package mindspore.fl;
message CollectiveData { message CollectiveData {
bytes data = 1; bytes data = 1;

View File

@ -286,8 +286,9 @@ void PSContext::GenerateResetterRound() {
return; return;
} }
binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) | binary_server_context = ((unsigned int)is_parameter_server_mode << 0) |
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3); ((unsigned int)is_federated_learning_mode << 1) |
((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)secure_aggregation_ << 3);
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) { if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
resetter_round_ = ResetterRound::kNoNeedToReset; resetter_round_ = ResetterRound::kNoNeedToReset;
} else { } else {